diff options
267 files changed, 14084 insertions, 1871 deletions
diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..a2a260538 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,19 @@ +language: minimal +sudo: required +dist: xenial +cache: + directories: + - /home/travis/.cache/bazel/ +services: + - docker +matrix: + include: + - os: linux + arch: amd64 + env: RUNSC_PATH=./bazel-bin/runsc/linux_amd64_pure_stripped/runsc + - os: linux + arch: arm64 + env: RUNSC_PATH=./bazel-bin/runsc/linux_arm64_pure_stripped/runsc +script: + - uname -a + - make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls diff --git a/Dockerfile b/Dockerfile index 738623023..2bfdfec6c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,9 @@ -FROM ubuntu:bionic +FROM fedora:31 -RUN apt-get update && apt-get install -y curl gnupg2 git python python3 python3-distutils python3-pip -RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ - curl https://bazel.build/bazel-release.pub.gpg | apt-key add - -RUN apt-get update && apt-get install -y bazel && apt-get clean +RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel + +RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static + +RUN pip install pycparser WORKDIR /gvisor @@ -2,6 +2,9 @@ UID := $(shell id -u ${USER}) GID := $(shell id -g ${USER}) GVISOR_BAZEL_CACHE := $(shell readlink -f ~/.cache/bazel/) +# The --privileged is required to run tests. +DOCKER_RUN_OPTIONS ?= --privileged + all: runsc docker-build: @@ -19,7 +22,7 @@ bazel-server-start: docker-build -v "$(CURDIR):$(CURDIR)" \ --workdir "$(CURDIR)" \ --tmpfs /tmp:rw,exec \ - --privileged \ + $(DOCKER_RUN_OPTIONS) \ gvisor-bazel \ sh -c "while :; do sleep 100; done" && \ docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --non-unique --gid $(GID) -d $(HOME) gvisor" @@ -33,6 +33,20 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") gazelle_dependencies() +# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories" +# block below once 1876 is fixed. +# +# The com_google_protobuf repository below would trigger downloading a older +# version of org_golang_x_sys. If putting this repository statment in a place +# after that of the com_google_protobuf, this statement will not work as +# expectd to download a new version of org_golang_x_sys. +go_repository( + name = "org_golang_x_sys", + importpath = "golang.org/x/sys", + sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=", + version = "v0.0.0-20191220220014-0732a990476f", +) + # Load C++ rules. http_archive( name = "rules_cc", @@ -257,13 +271,6 @@ go_repository( ) go_repository( - name = "org_golang_x_sys", - importpath = "golang.org/x/sys", - sum = "h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=", - version = "v0.0.0-20190215142949-d0b11bdaac8a", -) - -go_repository( name = "org_golang_x_time", commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e", importpath = "golang.org/x/time", @@ -330,3 +337,13 @@ http_archive( "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", ], ) + +http_archive( + name = "com_google_benchmark", + sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", + strip_prefix = "benchmark-1.5.0", + urls = [ + "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", + "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", + ], +) diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py index 3d32d3dda..5bdc4aa85 100644 --- a/benchmarks/harness/machine.py +++ b/benchmarks/harness/machine.py @@ -43,6 +43,8 @@ from benchmarks.harness import machine_mocks from benchmarks.harness import ssh_connection from benchmarks.harness import tunnel_dispatcher +log = logging.getLogger(__name__) + class Machine(object): """The machine object is the primary object for benchmarks. @@ -236,9 +238,10 @@ class RemoteMachine(Machine): archive=archive, dir=harness.REMOTE_INSTALLERS_PATH)) self._has_installers = True - # Execute the remote installer. - self.run("sudo {dir}/{file}".format( - dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) + # Execute the remote installer. + self.run("sudo {dir}/{file}".format( + dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) + if results: results[index] = True diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py index a50e34293..b8c8e42d4 100644 --- a/benchmarks/harness/ssh_connection.py +++ b/benchmarks/harness/ssh_connection.py @@ -13,7 +13,7 @@ # limitations under the License. """SSHConnection handles the details of SSH connections.""" - +import logging import os import warnings @@ -24,6 +24,8 @@ from benchmarks import harness # Get rid of paramiko Cryptography Warnings. warnings.filterwarnings(action="ignore", module=".*paramiko.*") +log = logging.getLogger(__name__) + def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str) -> str: @@ -94,10 +96,13 @@ class SSHConnection: The contents of stdout and stderr. """ with self._client() as client: + log.info("running command: %s", cmd) _, stdout, stderr = client.exec_command(command=cmd) - stdout.channel.recv_exit_status() + log.info("returned status: %d", stdout.channel.recv_exit_status()) stdout = stdout.read().decode("utf-8") stderr = stderr.read().decode("utf-8") + log.info("stdout: %s", stdout) + log.info("stderr: %s", stderr) return stdout, stderr def send_workload(self, name: str) -> str: diff --git a/kokoro/benchmark_tests.cfg b/kokoro/benchmark_tests.cfg new file mode 100644 index 000000000..c48518a05 --- /dev/null +++ b/kokoro/benchmark_tests.cfg @@ -0,0 +1,26 @@ +build_file : 'repo/scripts/benchmark.sh' + + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id : 73898 + keyname : 'kokoro-rbe-service-account' + }, + } +} + +env_vars { + key : 'PROJECT' + value : 'gvisor-kokoro-testing' +} + +env_vars { + key : 'ZONE' + value : 'us-central1-b' +} + +env_vars { + key : 'KOKORO_SERVICE_ACCOUNT' + value : '73898_kokoro-rbe-service-account' +} diff --git a/kokoro/runtime_tests/go1.12.cfg b/kokoro/runtime_tests/go1.12.cfg index 164ddc18f..fd4911e88 100644 --- a/kokoro/runtime_tests/go1.12.cfg +++ b/kokoro/runtime_tests/go1.12.cfg @@ -4,3 +4,13 @@ env_vars { key: "RUNTIME_TEST_NAME" value: "go1.12" } + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + regex: "**/runsc" + regex: "**/runsc.*" + } +}
\ No newline at end of file diff --git a/kokoro/runtime_tests/java11.cfg b/kokoro/runtime_tests/java11.cfg index 4957d4794..7f8611a08 100644 --- a/kokoro/runtime_tests/java11.cfg +++ b/kokoro/runtime_tests/java11.cfg @@ -4,3 +4,13 @@ env_vars { key: "RUNTIME_TEST_NAME" value: "java11" } + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + regex: "**/runsc" + regex: "**/runsc.*" + } +}
\ No newline at end of file diff --git a/kokoro/runtime_tests/nodejs12.4.0.cfg b/kokoro/runtime_tests/nodejs12.4.0.cfg index 1df343f95..c67ad5567 100644 --- a/kokoro/runtime_tests/nodejs12.4.0.cfg +++ b/kokoro/runtime_tests/nodejs12.4.0.cfg @@ -4,3 +4,13 @@ env_vars { key: "RUNTIME_TEST_NAME" value: "nodejs12.4.0" } + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + regex: "**/runsc" + regex: "**/runsc.*" + } +}
\ No newline at end of file diff --git a/kokoro/runtime_tests/php7.3.6.cfg b/kokoro/runtime_tests/php7.3.6.cfg index 8e3667125..f266c5e26 100644 --- a/kokoro/runtime_tests/php7.3.6.cfg +++ b/kokoro/runtime_tests/php7.3.6.cfg @@ -4,3 +4,13 @@ env_vars { key: "RUNTIME_TEST_NAME" value: "php7.3.6" } + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + regex: "**/runsc" + regex: "**/runsc.*" + } +}
\ No newline at end of file diff --git a/kokoro/runtime_tests/python3.7.3.cfg b/kokoro/runtime_tests/python3.7.3.cfg index 0ca70d5bb..574add152 100644 --- a/kokoro/runtime_tests/python3.7.3.cfg +++ b/kokoro/runtime_tests/python3.7.3.cfg @@ -4,3 +4,13 @@ env_vars { key: "RUNTIME_TEST_NAME" value: "python3.7.3" } + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + regex: "**/runsc" + regex: "**/runsc.*" + } +}
\ No newline at end of file diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a89f34d4b..322d1ccc4 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -30,6 +30,7 @@ go_library( "futex.go", "inotify.go", "ioctl.go", + "ioctl_tun.go", "ip.go", "ipc.go", "limits.go", diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go index 6e4de69da..1121a1a92 100644 --- a/pkg/abi/linux/epoll.go +++ b/pkg/abi/linux/epoll.go @@ -14,6 +14,10 @@ package linux +import ( + "gvisor.dev/gvisor/pkg/binary" +) + // Event masks. const ( EPOLLIN = 0x1 @@ -53,3 +57,6 @@ const ( EPOLL_CTL_DEL = 0x2 EPOLL_CTL_MOD = 0x3 ) + +// SizeOfEpollEvent is the size of EpollEvent struct. +var SizeOfEpollEvent = int(binary.Size(EpollEvent{})) diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go index 57041491c..34ff18009 100644 --- a/pkg/abi/linux/epoll_amd64.go +++ b/pkg/abi/linux/epoll_amd64.go @@ -15,6 +15,8 @@ package linux // EpollEvent is equivalent to struct epoll_event from epoll(2). +// +// +marshal type EpollEvent struct { Events uint32 // Linux makes struct epoll_event::data a __u64. We represent it as diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go index 62ef5821e..f86c35329 100644 --- a/pkg/abi/linux/epoll_arm64.go +++ b/pkg/abi/linux/epoll_arm64.go @@ -15,6 +15,8 @@ package linux // EpollEvent is equivalent to struct epoll_event from epoll(2). +// +// +marshal type EpollEvent struct { Events uint32 // Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index c3ab15a4f..e229ac21c 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -241,6 +241,8 @@ const ( ) // Statx represents struct statx. +// +// +marshal type Statx struct { Mask uint32 Blksize uint32 diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index 2c652baa2..158d2db5b 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -38,6 +38,8 @@ const ( ) // Statfs is struct statfs, from uapi/asm-generic/statfs.h. +// +// +marshal type Statfs struct { // Type is one of the filesystem magic values, defined above. Type uint64 diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 0e18db9ef..2062e6a4b 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -72,3 +72,29 @@ const ( SIOCGMIIPHY = 0x8947 SIOCGMIIREG = 0x8948 ) + +// ioctl(2) directions. Used to calculate requests number. +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NONE = 0 + _IOC_WRITE = 1 + _IOC_READ = 2 +) + +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NRBITS = 8 + _IOC_TYPEBITS = 8 + _IOC_SIZEBITS = 14 + _IOC_DIRBITS = 2 + + _IOC_NRSHIFT = 0 + _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS + _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS + _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS +) + +// IOC outputs the result of _IOC macro in asm-generic/ioctl.h. +func IOC(dir, typ, nr, size uint32) uint32 { + return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT +} diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/abi/linux/ioctl_tun.go index 876783e78..c59c9c136 100644 --- a/pkg/usermem/usermem_unsafe.go +++ b/pkg/abi/linux/ioctl_tun.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package usermem +package linux -import ( - "unsafe" +// ioctl(2) request numbers from linux/if_tun.h +var ( + TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4) + TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4) ) -// stringFromImmutableBytes is equivalent to string(bs), except that it never -// copies even if escape analysis can't prove that bs does not escape. This is -// only valid if bs is never mutated after stringFromImmutableBytes returns. -func stringFromImmutableBytes(bs []byte) string { - // Compare strings.Builder.String(). - return *(*string)(unsafe.Pointer(&bs)) -} +// Flags from net/if_tun.h +const ( + IFF_TUN = 0x0001 + IFF_TAP = 0x0002 + IFF_NO_PI = 0x1000 + IFF_NOFILTER = 0x1000 +) diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 2179ac995..314c318b6 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -225,11 +225,14 @@ type XTEntryTarget struct { // SizeOfXTEntryTarget is the size of an XTEntryTarget. const SizeOfXTEntryTarget = 32 -// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or -// RETURN. It corresponds to struct xt_standard_target in +// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, +// RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. type XTStandardTarget struct { - Target XTEntryTarget + Target XTEntryTarget + // A positive verdict indicates a jump, and is the offset from the + // start of the table to jump to. A negative value means one of the + // other built-in targets. Verdict int32 _ [4]byte } diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index c69b04ea9..1c330e763 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -115,6 +115,8 @@ const ( ) // SignalSet is a signal mask with a bit corresponding to each signal. +// +// +marshal type SignalSet uint64 // SignalSetSize is the size in bytes of a SignalSet. diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go index e562b46d9..e6860ed49 100644 --- a/pkg/abi/linux/time.go +++ b/pkg/abi/linux/time.go @@ -157,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec { const SizeOfTimeval = 16 // Timeval represents struct timeval in <time.h>. +// +// +marshal type Timeval struct { Sec int64 Usec int64 @@ -230,6 +232,8 @@ type Tms struct { type TimerID int32 // StatxTimestamp represents struct statx_timestamp. +// +// +marshal type StatxTimestamp struct { Sec int64 Nsec uint32 @@ -258,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) { } // Utime represents struct utimbuf used by utimes(2). +// +// +marshal type Utime struct { Actime int64 Modtime int64 diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index a3b6406fa..99180b208 100644 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go @@ -18,6 +18,7 @@ package linux const ( XATTR_NAME_MAX = 255 XATTR_SIZE_MAX = 65536 + XATTR_LIST_MAX = 65536 XATTR_CREATE = 1 XATTR_REPLACE = 2 diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 3948074ba..1a30f6967 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -5,10 +5,10 @@ package(licenses = ["notice"]) go_library( name = "atomicbitops", srcs = [ - "atomic_bitops.go", - "atomic_bitops_amd64.s", - "atomic_bitops_arm64.s", - "atomic_bitops_common.go", + "atomicbitops.go", + "atomicbitops_amd64.s", + "atomicbitops_arm64.s", + "atomicbitops_noasm.go", ], visibility = ["//:sandbox"], ) @@ -16,7 +16,7 @@ go_library( go_test( name = "atomicbitops_test", size = "small", - srcs = ["atomic_bitops_test.go"], + srcs = ["atomicbitops_test.go"], library = ":atomicbitops", deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomicbitops.go index fcc41a9ea..1be081719 100644 --- a/pkg/atomicbitops/atomic_bitops.go +++ b/pkg/atomicbitops/atomicbitops.go @@ -14,47 +14,34 @@ // +build amd64 arm64 -// Package atomicbitops provides basic bitwise operations in an atomic way. -// The implementation on amd64 leverages the LOCK prefix directly instead of -// relying on the generic cas primitives, and the arm64 leverages the LDAXR -// and STLXR pair primitives. +// Package atomicbitops provides extensions to the sync/atomic package. // -// WARNING: the bitwise ops provided in this package doesn't imply any memory -// ordering. Using them to construct locks must employ proper memory barriers. +// All read-modify-write operations implemented by this package have +// acquire-release memory ordering (like sync/atomic). package atomicbitops -// AndUint32 atomically applies bitwise and operation to *addr with val. +// AndUint32 atomically applies bitwise AND operation to *addr with val. func AndUint32(addr *uint32, val uint32) -// OrUint32 atomically applies bitwise or operation to *addr with val. +// OrUint32 atomically applies bitwise OR operation to *addr with val. func OrUint32(addr *uint32, val uint32) -// XorUint32 atomically applies bitwise xor operation to *addr with val. +// XorUint32 atomically applies bitwise XOR operation to *addr with val. func XorUint32(addr *uint32, val uint32) // CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns // the value previously stored at addr. func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32 -// AndUint64 atomically applies bitwise and operation to *addr with val. +// AndUint64 atomically applies bitwise AND operation to *addr with val. func AndUint64(addr *uint64, val uint64) -// OrUint64 atomically applies bitwise or operation to *addr with val. +// OrUint64 atomically applies bitwise OR operation to *addr with val. func OrUint64(addr *uint64, val uint64) -// XorUint64 atomically applies bitwise xor operation to *addr with val. +// XorUint64 atomically applies bitwise XOR operation to *addr with val. func XorUint64(addr *uint64, val uint64) // CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns // the value previously stored at addr. func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64 - -// IncUnlessZeroInt32 increments the value stored at the given address and -// returns true; unless the value stored in the pointer is zero, in which case -// it is left unmodified and false is returned. -func IncUnlessZeroInt32(addr *int32) bool - -// DecUnlessOneInt32 decrements the value stored at the given address and -// returns true; unless the value stored in the pointer is 1, in which case it -// is left unmodified and false is returned. -func DecUnlessOneInt32(addr *int32) bool diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s index db0972001..54c887ee5 100644 --- a/pkg/atomicbitops/atomic_bitops_amd64.s +++ b/pkg/atomicbitops/atomicbitops_amd64.s @@ -75,41 +75,3 @@ TEXT ·CompareAndSwapUint64(SB),$0-32 CMPXCHGQ DX, 0(DI) MOVQ AX, ret+24(FP) RET - -TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9 - MOVQ addr+0(FP), DI - MOVL 0(DI), AX - -retry: - TESTL AX, AX - JZ fail - LEAL 1(AX), DX - LOCK - CMPXCHGL DX, 0(DI) - JNZ retry - - SETEQ ret+8(FP) - RET - -fail: - MOVB AX, ret+8(FP) - RET - -TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9 - MOVQ addr+0(FP), DI - MOVL 0(DI), AX - -retry: - LEAL -1(AX), DX - TESTL DX, DX - JZ fail - LOCK - CMPXCHGL DX, 0(DI) - JNZ retry - - SETEQ ret+8(FP) - RET - -fail: - MOVB DX, ret+8(FP) - RET diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s index 97f8808c1..5c780851b 100644 --- a/pkg/atomicbitops/atomic_bitops_arm64.s +++ b/pkg/atomicbitops/atomicbitops_arm64.s @@ -50,7 +50,6 @@ TEXT ·CompareAndSwapUint32(SB),$0-20 MOVD addr+0(FP), R0 MOVW old+8(FP), R1 MOVW new+12(FP), R2 - again: LDAXRW (R0), R3 CMPW R1, R3 @@ -95,7 +94,6 @@ TEXT ·CompareAndSwapUint64(SB),$0-32 MOVD addr+0(FP), R0 MOVD old+8(FP), R1 MOVD new+16(FP), R2 - again: LDAXR (R0), R3 CMP R1, R3 @@ -105,35 +103,3 @@ again: done: MOVD R3, prev+24(FP) RET - -TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9 - MOVD addr+0(FP), R0 - -again: - LDAXRW (R0), R1 - CBZ R1, fail - ADDW $1, R1 - STLXRW R1, (R0), R2 - CBNZ R2, again - MOVW $1, R2 - MOVB R2, ret+8(FP) - RET -fail: - MOVB ZR, ret+8(FP) - RET - -TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9 - MOVD addr+0(FP), R0 - -again: - LDAXRW (R0), R1 - SUBSW $1, R1, R1 - BEQ fail - STLXRW R1, (R0), R2 - CBNZ R2, again - MOVW $1, R2 - MOVB R2, ret+8(FP) - RET -fail: - MOVB ZR, ret+8(FP) - RET diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomicbitops_noasm.go index 85163ad62..3b2898256 100644 --- a/pkg/atomicbitops/atomic_bitops_common.go +++ b/pkg/atomicbitops/atomicbitops_noasm.go @@ -20,7 +20,6 @@ import ( "sync/atomic" ) -// AndUint32 atomically applies bitwise and operation to *addr with val. func AndUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) { } } -// OrUint32 atomically applies bitwise or operation to *addr with val. func OrUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) { } } -// XorUint32 atomically applies bitwise xor operation to *addr with val. func XorUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -53,8 +50,6 @@ func XorUint32(addr *uint32, val uint32) { } } -// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns -// the value previously stored at addr. func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) { for { prev = atomic.LoadUint32(addr) @@ -67,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) { } } -// AndUint64 atomically applies bitwise and operation to *addr with val. func AndUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -78,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) { } } -// OrUint64 atomically applies bitwise or operation to *addr with val. func OrUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -89,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) { } } -// XorUint64 atomically applies bitwise xor operation to *addr with val. func XorUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -100,8 +92,6 @@ func XorUint64(addr *uint64, val uint64) { } } -// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns -// the value previously stored at addr. func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) { for { prev = atomic.LoadUint64(addr) @@ -113,35 +103,3 @@ func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) { } } } - -// IncUnlessZeroInt32 increments the value stored at the given address and -// returns true; unless the value stored in the pointer is zero, in which case -// it is left unmodified and false is returned. -func IncUnlessZeroInt32(addr *int32) bool { - for { - v := atomic.LoadInt32(addr) - if v == 0 { - return false - } - - if atomic.CompareAndSwapInt32(addr, v, v+1) { - return true - } - } -} - -// DecUnlessOneInt32 decrements the value stored at the given address and -// returns true; unless the value stored in the pointer is 1, in which case it -// is left unmodified and false is returned. -func DecUnlessOneInt32(addr *int32) bool { - for { - v := atomic.LoadInt32(addr) - if v == 1 { - return false - } - - if atomic.CompareAndSwapInt32(addr, v, v-1) { - return true - } - } -} diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomicbitops_test.go index 9466d3e23..73af71bb4 100644 --- a/pkg/atomicbitops/atomic_bitops_test.go +++ b/pkg/atomicbitops/atomicbitops_test.go @@ -196,67 +196,3 @@ func TestCompareAndSwapUint64(t *testing.T) { } } } - -func TestIncUnlessZeroInt32(t *testing.T) { - for _, test := range []struct { - initial int32 - final int32 - ret bool - }{ - { - initial: 0, - final: 0, - ret: false, - }, - { - initial: 1, - final: 2, - ret: true, - }, - { - initial: 2, - final: 3, - ret: true, - }, - } { - val := test.initial - if got, want := IncUnlessZeroInt32(&val), test.ret; got != want { - t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want) - } - if got, want := val, test.final; got != want { - t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want) - } - } -} - -func TestDecUnlessOneInt32(t *testing.T) { - for _, test := range []struct { - initial int32 - final int32 - ret bool - }{ - { - initial: 0, - final: -1, - ret: true, - }, - { - initial: 1, - final: 1, - ret: false, - }, - { - initial: 2, - final: 1, - ret: true, - }, - } { - val := test.initial - if got, want := DecUnlessOneInt32(&val), test.ret; got != want { - t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want) - } - if got, want := val, test.final; got != want { - t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want) - } - } -} diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index 333ca0a04..a0bc55ea1 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -725,6 +725,18 @@ func vendorIDFromRegs(bx, cx, dx uint32) string { return string(bytes) } +var maxXsaveSize = func() uint32 { + // Leaf 0 of xsaveinfo function returns the size for currently + // enabled xsave features in ebx, the maximum size if all valid + // features are saved with xsave in ecx, and valid XCR0 bits in + // edx:eax. + // + // If xSaveInfo isn't supported, cpuid will not fault but will + // return bogus values. + _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0) + return maxXsaveSize +}() + // ExtendedStateSize returns the number of bytes needed to save the "extended // state" for this processor and the boundary it must be aligned to. Extended // state includes floating point registers, and other cpu state that's not @@ -736,12 +748,7 @@ func vendorIDFromRegs(bx, cx, dx uint32) string { // about 2.5K worst case, with avx512). func (fs *FeatureSet) ExtendedStateSize() (size, align uint) { if fs.UseXsave() { - // Leaf 0 of xsaveinfo function returns the size for currently - // enabled xsave features in ebx, the maximum size if all valid - // features are saved with xsave in ecx, and valid XCR0 bits in - // edx:eax. - _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0) - return uint(maxSize), 64 + return uint(maxXsaveSize), 64 } // If we don't support xsave, we fall back to fxsave, which requires diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index ee84471b2..67dd1e225 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -8,9 +8,11 @@ go_library( name = "fspath", srcs = [ "builder.go", - "builder_unsafe.go", "fspath.go", ], + deps = [ + "//pkg/gohacks", + ], ) go_test( diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go index 7ddb36826..6318d3874 100644 --- a/pkg/fspath/builder.go +++ b/pkg/fspath/builder.go @@ -16,6 +16,8 @@ package fspath import ( "fmt" + + "gvisor.dev/gvisor/pkg/gohacks" ) // Builder is similar to strings.Builder, but is used to produce pathnames @@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) { copy(b.buf[b.start:], b.buf[oldStart:]) copy(b.buf[len(b.buf)-len(str):], str) } + +// String returns the accumulated string. No other methods should be called +// after String. +func (b *Builder) String() string { + return gohacks.StringFromImmutableBytes(b.buf[b.start:]) +} diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go index 9fb3fee24..4c983d5fd 100644 --- a/pkg/fspath/fspath.go +++ b/pkg/fspath/fspath.go @@ -67,7 +67,8 @@ func Parse(pathname string) Path { // Path contains the information contained in a pathname string. // -// Path is copyable by value. +// Path is copyable by value. The zero value for Path is equivalent to +// fspath.Parse(""), i.e. the empty path. type Path struct { // Begin is an iterator to the first path component in the relative part of // the path. diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD new file mode 100644 index 000000000..798a65eca --- /dev/null +++ b/pkg/gohacks/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "gohacks", + srcs = [ + "gohacks_unsafe.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go new file mode 100644 index 000000000..aad675172 --- /dev/null +++ b/pkg/gohacks/gohacks_unsafe.go @@ -0,0 +1,57 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gohacks contains utilities for subverting the Go compiler. +package gohacks + +import ( + "reflect" + "unsafe" +) + +// Noescape hides a pointer from escape analysis. Noescape is the identity +// function but escape analysis doesn't think the output depends on the input. +// Noescape is inlined and currently compiles down to zero instructions. +// USE CAREFULLY! +// +// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().) +// +//go:nosplit +func Noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the +// same memory backing s instead of making a heap-allocated copy. This is only +// valid if the returned slice is never mutated. +func ImmutableBytesFromString(s string) []byte { + shdr := (*reflect.StringHeader)(unsafe.Pointer(&s)) + var bs []byte + bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + bshdr.Data = shdr.Data + bshdr.Len = shdr.Len + bshdr.Cap = shdr.Len + return bs +} + +// StringFromImmutableBytes is equivalent to string(bs), except that it uses +// the same memory backing bs instead of making a heap-allocated copy. This is +// only valid if bs is never mutated after StringFromImmutableBytes returns. +func StringFromImmutableBytes(bs []byte) string { + // This is cheaper than messing with reflect.StringHeader and + // reflect.SliceHeader, which as of this writing produces many dead stores + // of zeroes. Compare strings.Builder.String(). + return *(*string)(unsafe.Pointer(&bs)) +} diff --git a/pkg/log/glog.go b/pkg/log/glog.go index cab5fae55..b4f7bb5a4 100644 --- a/pkg/log/glog.go +++ b/pkg/log/glog.go @@ -46,7 +46,7 @@ var pid = os.Getpid() // line The line number // msg The user-supplied message // -func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) { +func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) { // Log level. prefix := byte('?') switch level { @@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar microsecond := int(timestamp.Nanosecond() / 1000) // 0 = this frame. - // 1 = Debugf, etc. - // 2 = Caller. - _, file, line, ok := runtime.Caller(2) + _, file, line, ok := runtime.Caller(depth + 1) if ok { // Trim any directory path from the file. slash := strings.LastIndexByte(file, byte('/')) diff --git a/pkg/log/json.go b/pkg/log/json.go index a278c8fc8..0943db1cc 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -62,7 +62,7 @@ type JSONEmitter struct { } // Emit implements Emitter.Emit. -func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { j := jsonLog{ Msg: fmt.Sprintf(format, v...), Level: level, diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go index cee6eb514..6c6fc8b6f 100644 --- a/pkg/log/json_k8s.go +++ b/pkg/log/json_k8s.go @@ -33,7 +33,7 @@ type K8sJSONEmitter struct { } // Emit implements Emitter.Emit. -func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { j := k8sJSONLog{ Log: fmt.Sprintf(format, v...), Level: level, diff --git a/pkg/log/log.go b/pkg/log/log.go index 5056f17e6..a794da1aa 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -79,7 +79,7 @@ func (l Level) String() string { type Emitter interface { // Emit emits the given log statement. This allows for control over the // timestamp used for logging. - Emit(level Level, timestamp time.Time, format string, v ...interface{}) + Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) } // Writer writes the output to the given writer. @@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) { } // Emit emits the message. -func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) { +func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) { fmt.Fprintf(l, format, args...) } @@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i type MultiEmitter []Emitter // Emit emits to all emitters. -func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) { for _, e := range *m { - e.Emit(level, timestamp, format, v...) + e.Emit(1+depth, level, timestamp, format, v...) } } @@ -167,7 +167,7 @@ type TestEmitter struct { } // Emit emits to the TestLogger. -func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { t.Logf(format, v...) } @@ -198,22 +198,37 @@ type BasicLogger struct { // Debugf implements logger.Debugf. func (l *BasicLogger) Debugf(format string, v ...interface{}) { - if l.IsLogging(Debug) { - l.Emit(Debug, time.Now(), format, v...) - } + l.DebugfAtDepth(1, format, v...) } // Infof implements logger.Infof. func (l *BasicLogger) Infof(format string, v ...interface{}) { - if l.IsLogging(Info) { - l.Emit(Info, time.Now(), format, v...) - } + l.InfofAtDepth(1, format, v...) } // Warningf implements logger.Warningf. func (l *BasicLogger) Warningf(format string, v ...interface{}) { + l.WarningfAtDepth(1, format, v...) +} + +// DebugfAtDepth logs at a specific depth. +func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) { + if l.IsLogging(Debug) { + l.Emit(1+depth, Debug, time.Now(), format, v...) + } +} + +// InfofAtDepth logs at a specific depth. +func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) { + if l.IsLogging(Info) { + l.Emit(1+depth, Info, time.Now(), format, v...) + } +} + +// WarningfAtDepth logs at a specific depth. +func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) { if l.IsLogging(Warning) { - l.Emit(Warning, time.Now(), format, v...) + l.Emit(1+depth, Warning, time.Now(), format, v...) } } @@ -257,17 +272,32 @@ func SetLevel(newLevel Level) { // Debugf logs to the global logger. func Debugf(format string, v ...interface{}) { - Log().Debugf(format, v...) + Log().DebugfAtDepth(1, format, v...) } // Infof logs to the global logger. func Infof(format string, v ...interface{}) { - Log().Infof(format, v...) + Log().InfofAtDepth(1, format, v...) } // Warningf logs to the global logger. func Warningf(format string, v ...interface{}) { - Log().Warningf(format, v...) + Log().WarningfAtDepth(1, format, v...) +} + +// DebugfAtDepth logs to the global logger. +func DebugfAtDepth(depth int, format string, v ...interface{}) { + Log().DebugfAtDepth(1+depth, format, v...) +} + +// InfofAtDepth logs to the global logger. +func InfofAtDepth(depth int, format string, v ...interface{}) { + Log().InfofAtDepth(1+depth, format, v...) +} + +// WarningfAtDepth logs to the global logger. +func WarningfAtDepth(depth int, format string, v ...interface{}) { + Log().WarningfAtDepth(1+depth, format, v...) } // defaultStackSize is the default buffer size to allocate for stack traces. diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 3b6987665..d7794db63 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -34,6 +34,9 @@ const ( SyscallWidth = 4 ) +// ARMTrapFlag is the mask for the trap flag. +const ARMTrapFlag = uint64(1) << 21 + // aarch64FPState is aarch64 floating point state. type aarch64FPState []byte diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 4c4b7d5cc..9b6bb26d0 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -9,6 +9,7 @@ go_library( "device.go", "fs.go", "full.go", + "net_tun.go", "null.go", "random.go", "tty.go", @@ -19,15 +20,19 @@ go_library( "//pkg/context", "//pkg/rand", "//pkg/safemem", + "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", + "//pkg/sentry/socket/netstack", "//pkg/syserror", + "//pkg/tcpip/link/tun", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 35bd23991..7e66c29b0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo }) } -func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)) +func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode { + iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), @@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. Place an empty directory there for // the devpts to be mounted over. - "pts": newDirectory(ctx, msrc), + "pts": newDirectory(ctx, nil, msrc), // Similarly, applications expect a ptmx device at /dev/ptmx // connected to the terminals provided by /dev/pts/. Rather // than creating a device directly (which requires a hairy @@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + + "net": newDirectory(ctx, map[string]*fs.Inode{ + "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), + }, msrc), } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go new file mode 100644 index 000000000..755644488 --- /dev/null +++ b/pkg/sentry/fs/dev/net_tun.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// +stateify savable +type netTunInodeOperations struct { + fsutil.InodeGenericChecker `state:"nosave"` + fsutil.InodeNoExtendedAttributes `state:"nosave"` + fsutil.InodeNoopAllocate `state:"nosave"` + fsutil.InodeNoopRelease `state:"nosave"` + fsutil.InodeNoopTruncate `state:"nosave"` + fsutil.InodeNoopWriteOut `state:"nosave"` + fsutil.InodeNotDirectory `state:"nosave"` + fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotSocket `state:"nosave"` + fsutil.InodeNotSymlink `state:"nosave"` + fsutil.InodeVirtual `state:"nosave"` + + fsutil.InodeSimpleAttributes +} + +var _ fs.InodeOperations = (*netTunInodeOperations)(nil) + +func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations { + return &netTunInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC), + } +} + +// GetFile implements fs.InodeOperations.GetFile. +func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil +} + +// +stateify savable +type netTunFileOperations struct { + fsutil.FileNoSeek `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + device tun.Device +} + +var _ fs.FileOperations = (*netTunFileOperations)(nil) + +// Release implements fs.FileOperations.Release. +func (fops *netTunFileOperations) Release() { + fops.device.Release() +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fops.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fops.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Write implements fs.FileOperations.Write. +func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fops.device.Write(data) +} + +// Read implements fs.FileOperations.Read. +func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := fops.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// Readiness implements watier.Waitable.Readiness. +func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return fops.device.Readiness(mask) +} + +// EventRegister implements watier.Waitable.EventRegister. +func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fops.device.EventRegister(e, mask) +} + +// EventUnregister implements watier.Waitable.EventUnregister. +func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { + fops.device.EventUnregister(e) +} diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go index e672a438c..a3d10770b 100644 --- a/pkg/sentry/fs/mount_test.go +++ b/pkg/sentry/fs/mount_test.go @@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error { gotPaths := make(map[string]struct{}, len(got)) gotStr := make([]string, len(got)) for i, g := range got { - groot := g.Root() - name, _ := groot.FullName(root) - groot.DecRef() - gotStr[i] = name - gotPaths[name] = struct{}{} + if groot := g.Root(); groot != nil { + name, _ := groot.FullName(root) + groot.DecRef() + gotStr[i] = name + gotPaths[name] = struct{}{} + } } if len(got) != len(want) { return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 574a2cc91..c7981f66e 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount { } } -// Root returns the root dirent of this mount. Callers must call DecRef on the -// returned dirent. +// Root returns the root dirent of this mount. +// +// This may return nil if the mount has already been free. Callers must handle this +// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent. func (m *Mount) Root() *Dirent { - m.root.IncRef() + if !m.root.TryIncRef() { + return nil + } return m.root } diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index c10888100..94deb553b 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { }) for _, m := range ms { mroot := m.Root() + if mroot == nil { + continue // No longer valid. + } mountPath, desc := mroot.FullName(rootDir) mroot.DecRef() if !desc { // MountSources that are not descendants of the chroot jail are ignored. continue } - fn(mountPath, m) } } @@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se var buf bytes.Buffer forEachMount(mif.t, func(mountPath string, m *fs.Mount) { + mroot := m.Root() + if mroot == nil { + return // No longer valid. + } + defer mroot.DecRef() + // Format: // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) @@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se // (3) Major:Minor device ID. We don't have a superblock, so we // just use the root inode device number. - mroot := m.Root() - defer mroot.DecRef() - sa := mroot.Inode.StableAttr fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor) @@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan // // The "needs dump"and fsck flags are always 0, which is allowed. root := m.Root() + if root == nil { + return // No longer valid. + } defer root.DecRef() flags := root.Inode.MountSource.Flags diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 6f2775344..95d5817ff 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -43,7 +43,10 @@ import ( // newNet creates a new proc net entry. func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode - if s := p.k.NetworkStack(); s != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. We should make this per-process, + // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. + if s := p.k.RootNetworkNamespace().Stack(); s != nil { contents = map[string]*fs.Inode{ "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 0772d4ae4..d4c4b533d 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode - if s := p.k.NetworkStack(); s != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. + if s := p.k.RootNetworkNamespace().Stack(); s != nil { contents = map[string]*fs.Inode{ "ipv4": p.newSysNetIPv4Dir(ctx, msrc, s), "core": p.newSysNetCore(ctx, msrc, s), diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index e657c39bc..6aa17bfc1 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -117,15 +117,19 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) // default anyways. // // TODO(gvisor.dev/issue/1623): Check mount has read and exec permission. -func (l *vfsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { +func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem() creds := auth.CredentialsFromContext(ctx) + path := fspath.Parse(pathname) pop := &vfs.PathOperation{ Root: l.root, - Start: l.root, - Path: fspath.Parse(path), + Start: l.workingDir, + Path: path, FollowFinalSymlink: resolveFinal, } + if path.Absolute { + pop.Start = l.root + } fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts) if err != nil { return nil, err diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index ce08a7d53..10c08fa90 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}), "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"), "net": newNetDir(root, inoGen, k), - "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}), + "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}), "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}), - "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}), + "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}), } inode := &tasksInode{ diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go index 608fec017..d4e1812d8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_net.go +++ b/pkg/sentry/fsimpl/proc/tasks_net.go @@ -39,7 +39,10 @@ import ( func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry - if stack := k.NetworkStack(); stack != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. We should make this per-process, + // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. + if stack := k.RootNetworkNamespace().Stack(); stack != nil { const ( arp = "IP address HW type Flags HW address Mask Device\n" netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n" diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index c7ce74883..3d5dc463c 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry - if stack := k.NetworkStack(); stack != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. + if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ "tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}), diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index d0be32e72..488478e29 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -128,6 +128,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns ThreadGroup: tc, TaskContext: &kernel.TaskContext{Name: name}, Credentials: auth.CredentialsFromContext(ctx), + NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), UTSNamespace: kernel.UTSNamespaceFromContext(ctx), IPCNamespace: kernel.IPCNamespaceFromContext(ctx), diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 7f7b791c4..e1b551422 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -16,7 +16,6 @@ package tmpfs import ( "fmt" - "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return nil, err } if opts.Flags&linux.O_TRUNC != 0 { - impl.mu.Lock() - impl.data.Truncate(0, impl.memFile) - atomic.StoreUint64(&impl.size, 0) - impl.mu.Unlock() + if _, err := impl.truncate(0); err != nil { + return nil, err + } } return &fd.vfsfd, nil case *directory: diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index dab346a41..711442424 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -15,6 +15,7 @@ package tmpfs import ( + "fmt" "io" "math" "sync/atomic" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -34,25 +36,53 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// regularFile is a regular (=S_IFREG) tmpfs file. type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. memFile *pgalloc.MemoryFile - // mu protects the fields below. - mu sync.RWMutex + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // mappings tracks mappings of the file into memmap.MappingSpaces. + // + // Protected by mapsMu. + mappings memmap.MappingSet + + // writableMappingPages tracks how many pages of virtual memory are mapped + // as potentially writable from this file. If a page has multiple mappings, + // each mapping is counted separately. + // + // This counter is susceptible to overflow as we can potentially count + // mappings from many VMAs. We count pages rather than bytes to slightly + // mitigate this. + // + // Protected by mapsMu. + writableMappingPages uint64 + + // dataMu protects the fields below. + dataMu sync.RWMutex // data maps offsets into the file to offsets into memFile that store // the file's data. + // + // Protected by dataMu. data fsutil.FileRangeSet - // size is the size of data, but accessed using atomic memory - // operations to avoid locking in inode.stat(). - size uint64 - // seals represents file seals on this inode. + // + // Protected by dataMu. seals uint32 + + // size is the size of data. + // + // Protected by both dataMu and inode.mu; reading it requires holding + // either mutex, while writing requires holding both AND using atomics. + // Readers that do not require consistency (like Stat) may read the + // value atomically without holding either lock. + size uint64 } func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { @@ -66,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod // truncate grows or shrinks the file to the given size. It returns true if the // file size was updated. -func (rf *regularFile) truncate(size uint64) (bool, error) { - rf.mu.Lock() - defer rf.mu.Unlock() +func (rf *regularFile) truncate(newSize uint64) (bool, error) { + rf.inode.mu.Lock() + defer rf.inode.mu.Unlock() + return rf.truncateLocked(newSize) +} - if size == rf.size { +// Preconditions: rf.inode.mu must be held. +func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) { + oldSize := rf.size + if newSize == oldSize { // Nothing to do. return false, nil } - if size > rf.size { - // Growing the file. + // Need to hold inode.mu and dataMu while modifying size. + rf.dataMu.Lock() + if newSize > oldSize { + // Can we grow the file? if rf.seals&linux.F_SEAL_GROW != 0 { - // Seal does not allow growth. + rf.dataMu.Unlock() return false, syserror.EPERM } - rf.size = size + // We only need to update the file size. + atomic.StoreUint64(&rf.size, newSize) + rf.dataMu.Unlock() return true, nil } - // Shrinking the file + // We are shrinking the file. First check if this is allowed. if rf.seals&linux.F_SEAL_SHRINK != 0 { - // Seal does not allow shrink. + rf.dataMu.Unlock() return false, syserror.EPERM } - // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have - // mappings. + // Update the file size. + atomic.StoreUint64(&rf.size, newSize) + rf.dataMu.Unlock() + + // Invalidate past translations of truncated pages. + oldpgend := fs.OffsetPageEnd(int64(oldSize)) + newpgend := fs.OffsetPageEnd(int64(newSize)) + if newpgend < oldpgend { + rf.mapsMu.Lock() + rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ + // Compare Linux's mm/shmem.c:shmem_setattr() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + rf.mapsMu.Unlock() + } - rf.data.Truncate(size, rf.memFile) - rf.size = size + // We are now guaranteed that there are no translations of truncated pages, + // and can remove them. + rf.dataMu.Lock() + rf.data.Truncate(newSize, rf.memFile) + rf.dataMu.Unlock() return true, nil } +// AddMapping implements memmap.Mappable.AddMapping. +func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + rf.mapsMu.Lock() + defer rf.mapsMu.Unlock() + rf.dataMu.RLock() + defer rf.dataMu.RUnlock() + + // Reject writable mapping if F_SEAL_WRITE is set. + if rf.seals&linux.F_SEAL_WRITE != 0 && writable { + return syserror.EPERM + } + + rf.mappings.AddMapping(ms, ar, offset, writable) + if writable { + pagesBefore := rf.writableMappingPages + + // ar is guaranteed to be page aligned per memmap.Mappable. + rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize) + + if rf.writableMappingPages < pagesBefore { + panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) + } + } + + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + rf.mapsMu.Lock() + defer rf.mapsMu.Unlock() + + rf.mappings.RemoveMapping(ms, ar, offset, writable) + + if writable { + pagesBefore := rf.writableMappingPages + + // ar is guaranteed to be page aligned per memmap.Mappable. + rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize) + + if rf.writableMappingPages > pagesBefore { + panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) + } + } +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + return rf.AddMapping(ctx, ms, dstAR, offset, writable) +} + +// Translate implements memmap.Mappable.Translate. +func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + rf.dataMu.Lock() + defer rf.dataMu.Unlock() + + // Constrain translations to f.attr.Size (rounded up) to prevent + // translation to pages that may be concurrently truncated. + pgend := fs.OffsetPageEnd(int64(rf.size)) + var beyondEOF bool + if required.End > pgend { + if required.Start >= pgend { + return nil, &memmap.BusError{io.EOF} + } + beyondEOF = true + required.End = pgend + } + if optional.End > pgend { + optional.End = pgend + } + + cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { + // Newly-allocated pages are zeroed, so we don't need to do anything. + return dsts.NumBytes(), nil + }) + + var ts []memmap.Translation + var translatedEnd uint64 + for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() { + segMR := seg.Range().Intersect(optional) + ts = append(ts, memmap.Translation{ + Source: segMR, + File: rf.memFile, + Offset: seg.FileRangeOf(segMR).Start, + Perms: usermem.AnyAccess, + }) + translatedEnd = segMR.End + } + + // Don't return the error returned by f.data.Fill if it occurred outside of + // required. + if translatedEnd < required.End && cerr != nil { + return ts, &memmap.BusError{cerr} + } + if beyondEOF { + return ts, &memmap.BusError{io.EOF} + } + return ts, nil +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (*regularFile) InvalidateUnsavable(context.Context) error { + return nil +} + type regularFileFD struct { fileDescription @@ -152,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Overflow. return 0, syserror.EFBIG } + f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) + f.inode.mu.Unlock() putRegularFileReadWriter(rw) return n, err } @@ -215,6 +378,12 @@ func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng return nil } +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + file := fd.inode().impl.(*regularFile) + return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts) +} + // regularFileReadWriter implements safemem.Reader and Safemem.Writer. type regularFileReadWriter struct { file *regularFile @@ -244,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) { // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - rw.file.mu.RLock() + rw.file.dataMu.RLock() + defer rw.file.dataMu.RUnlock() + size := rw.file.size // Compute the range to read (limited by file size and overflow-checked). - if rw.off >= rw.file.size { - rw.file.mu.RUnlock() + if rw.off >= size { return 0, io.EOF } - end := rw.file.size + end := size if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end { end = rend } @@ -265,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er // Get internal mappings. ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) if err != nil { - rw.file.mu.RUnlock() return done, err } @@ -275,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er rw.off += uint64(n) dsts = dsts.DropFirst64(n) if err != nil { - rw.file.mu.RUnlock() return done, err } @@ -291,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er rw.off += uint64(n) dsts = dsts.DropFirst64(n) if err != nil { - rw.file.mu.RUnlock() return done, err } @@ -299,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} } } - rw.file.mu.RUnlock() return done, nil } // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +// +// Preconditions: inode.mu must be held. func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - rw.file.mu.Lock() + // Hold dataMu so we can modify size. + rw.file.dataMu.Lock() + defer rw.file.dataMu.Unlock() // Compute the range to write (overflow-checked). end := rw.off + srcs.NumBytes() @@ -316,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, // Check if seals prevent either file growth or all writes. switch { case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed - rw.file.mu.Unlock() return 0, syserror.EPERM case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed // When growth is sealed, Linux effectively allows writes which would @@ -338,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, } if end <= rw.off { // Truncation would result in no data being written. - rw.file.mu.Unlock() return 0, syserror.EPERM } } @@ -395,9 +563,8 @@ exitLoop: // If the write ends beyond the file's previous size, it causes the // file to grow. if rw.off > rw.file.size { - atomic.StoreUint64(&rw.file.size, rw.off) + rw.file.size = rw.off } - rw.file.mu.Unlock() return done, retErr } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index c5bb17562..521206305 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -18,9 +18,10 @@ // Lock order: // // filesystem.mu -// regularFileFD.offMu -// regularFile.mu // inode.mu +// regularFileFD.offMu +// regularFile.mapsMu +// regularFile.dataMu package tmpfs import ( @@ -226,12 +227,15 @@ func (i *inode) tryIncRef() bool { func (i *inode) decRef() { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { - // This is unnecessary; it's mostly to simulate what tmpfs would do. if regFile, ok := i.impl.(*regularFile); ok { - regFile.mu.Lock() + // Hold inode.mu and regFile.dataMu while mutating + // size. + i.mu.Lock() + regFile.dataMu.Lock() regFile.data.DropAll(regFile.memFile) atomic.StoreUint64(®File.size, 0) - regFile.mu.Unlock() + regFile.dataMu.Unlock() + i.mu.Unlock() } } else if refs < 0 { panic("tmpfs.inode.decRef() called without holding a reference") @@ -320,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error { if mask&linux.STATX_SIZE != 0 { switch impl := i.impl.(type) { case *regularFile: - updated, err := impl.truncate(stat.Size) + updated, err := impl.truncateLocked(stat.Size) if err != nil { return err } diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 334432abf..07bf39fed 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -10,6 +10,7 @@ go_library( srcs = [ "context.go", "inet.go", + "namespace.go", "test_stack.go", ], deps = [ diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go new file mode 100644 index 000000000..c16667e7f --- /dev/null +++ b/pkg/sentry/inet/namespace.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inet + +// Namespace represents a network namespace. See network_namespaces(7). +// +// +stateify savable +type Namespace struct { + // stack is the network stack implementation of this network namespace. + stack Stack `state:"nosave"` + + // creator allows kernel to create new network stack for network namespaces. + // If nil, no networking will function if network is namespaced. + creator NetworkStackCreator + + // isRoot indicates whether this is the root network namespace. + isRoot bool +} + +// NewRootNamespace creates the root network namespace, with creator +// allowing new network namespaces to be created. If creator is nil, no +// networking will function if the network is namespaced. +func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace { + return &Namespace{ + stack: stack, + creator: creator, + isRoot: true, + } +} + +// NewNamespace creates a new network namespace from the root. +func NewNamespace(root *Namespace) *Namespace { + n := &Namespace{ + creator: root.creator, + } + n.init() + return n +} + +// Stack returns the network stack of n. Stack may return nil if no network +// stack is configured. +func (n *Namespace) Stack() Stack { + return n.stack +} + +// IsRoot returns whether n is the root network namespace. +func (n *Namespace) IsRoot() bool { + return n.isRoot +} + +// RestoreRootStack restores the root network namespace with stack. This should +// only be called when restoring kernel. +func (n *Namespace) RestoreRootStack(stack Stack) { + if !n.isRoot { + panic("RestoreRootStack can only be called on root network namespace") + } + if n.stack != nil { + panic("RestoreRootStack called after a stack has already been set") + } + n.stack = stack +} + +func (n *Namespace) init() { + // Root network namespace will have stack assigned later. + if n.isRoot { + return + } + if n.creator != nil { + var err error + n.stack, err = n.creator.CreateStack() + if err != nil { + panic(err) + } + } +} + +// afterLoad is invoked by stateify. +func (n *Namespace) afterLoad() { + n.init() +} + +// NetworkStackCreator allows new instances of a network stack to be created. It +// is used by the kernel to create new network namespaces when requested. +type NetworkStackCreator interface { + // CreateStack creates a new network stack for a network namespace. + CreateStack() (Stack, error) +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 23b88f7a6..58001d56c 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return fds, nil } +// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for +// the given file description. If it succeeds, it takes a reference on file. +func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { + if minfd < 0 { + // Don't accept negative FDs. + return -1, syscall.EINVAL + } + + // Default limit. + end := int32(math.MaxInt32) + + // Ensure we don't get past the provided limit. + if limitSet := limits.FromContext(ctx); limitSet != nil { + lim := limitSet.Get(limits.NumberOfFiles) + if lim.Cur != limits.Infinity { + end = int32(lim.Cur) + } + if minfd >= end { + return -1, syscall.EMFILE + } + } + + f.mu.Lock() + defer f.mu.Unlock() + + // From f.next to find available fd. + fd := minfd + if fd < f.next { + fd = f.next + } + for fd < end { + if d, _, _ := f.get(fd); d == nil { + f.setVFS2(fd, file, flags) + if fd == f.next { + // Update next search start position. + f.next = fd + 1 + } + return fd, nil + } + fd++ + } + return -1, syscall.EMFILE +} + // NewFDAt sets the file reference for the given FD. If there is an active // reference for that FD, the ref count for that existing reference is // decremented. @@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 return syscall.EBADF } - f.mu.Lock() - defer f.mu.Unlock() - // Check the limit for the provided file. if limitSet := limits.FromContext(ctx); limitSet != nil { if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur { @@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 } // Install the entry. + f.mu.Lock() + defer f.mu.Unlock() f.setAll(fd, file, fileVFS2, flags) return nil } diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index 7218aa24e..47f78df9a 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -244,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) { old.DecRef() } +// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd. +// +// This is not a valid call after free. +func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) { + if !vd.Ok() { + panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry") + } + + f.mu.Lock() + + if !f.rootVFS2.Ok() { + f.mu.Unlock() + panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd)) + } + + old := f.rootVFS2 + vd.IncRef() + f.rootVFS2 = vd + f.mu.Unlock() + old.DecRef() +} + // Umask returns the current umask. func (f *FSContext) Umask() uint { f.mu.Lock() diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 7da0368f1..8b76750e9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -111,7 +111,7 @@ type Kernel struct { timekeeper *Timekeeper tasks *TaskSet rootUserNamespace *auth.UserNamespace - networkStack inet.Stack `state:"nosave"` + rootNetworkNamespace *inet.Namespace applicationCores uint useHostCores bool extraAuxv []arch.AuxEntry @@ -247,6 +247,10 @@ type Kernel struct { // VFS keeps the filesystem state used across the kernel. vfs vfs.VirtualFilesystem + + // If set to true, report address space activation waits as if the task is in + // external wait so that the watchdog doesn't report the task stuck. + SleepForAddressSpaceActivation bool } // InitKernelArgs holds arguments to Init. @@ -260,8 +264,9 @@ type InitKernelArgs struct { // RootUserNamespace is the root user namespace. RootUserNamespace *auth.UserNamespace - // NetworkStack is the TCP/IP network stack. NetworkStack may be nil. - NetworkStack inet.Stack + // RootNetworkNamespace is the root network namespace. If nil, no networking + // will be available. + RootNetworkNamespace *inet.Namespace // ApplicationCores is the number of logical CPUs visible to sandboxed // applications. The set of logical CPU IDs is [0, ApplicationCores); thus @@ -320,7 +325,10 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.rootUTSNamespace = args.RootUTSNamespace k.rootIPCNamespace = args.RootIPCNamespace k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace - k.networkStack = args.NetworkStack + k.rootNetworkNamespace = args.RootNetworkNamespace + if k.rootNetworkNamespace == nil { + k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil) + } k.applicationCores = args.ApplicationCores if args.UseHostCores { k.useHostCores = true @@ -543,8 +551,6 @@ func (ts *TaskSet) unregisterEpollWaiters() { func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() - k.networkStack = net - initAppCores := k.applicationCores // Load the pre-saved CPUID FeatureSet. @@ -575,6 +581,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) log.Infof("Kernel load stats: %s", &stats) log.Infof("Kernel load took [%s].", time.Since(kernelStart)) + // rootNetworkNamespace should be populated after loading the state file. + // Restore the root network stack. + k.rootNetworkNamespace.RestoreRootStack(net) + // Load the memory file's state. memoryStart := time.Now() if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { @@ -905,6 +915,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, FSContext: fsContext, FDTable: args.FDTable, Credentials: args.Credentials, + NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores), UTSNamespace: args.UTSNamespace, IPCNamespace: args.IPCNamespace, @@ -1255,10 +1266,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace { return k.rootAbstractSocketNamespace } -// NetworkStack returns the network stack. NetworkStack may return nil if no -// network stack is available. -func (k *Kernel) NetworkStack() inet.Stack { - return k.networkStack +// RootNetworkNamespace returns the root network namespace, always non-nil. +func (k *Kernel) RootNetworkNamespace() *inet.Namespace { + return k.rootNetworkNamespace } // GlobalInit returns the thread group with ID 1 in the root PID namespace, or diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 18416643b..ded95f532 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() { } var cs linux.RSeqCriticalSection - if _, err := cs.CopyIn(t, critAddr); err != nil { + if err := cs.CopyIn(t, critAddr); err != nil { t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index a3443ff21..2cee2e6ed 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -486,13 +486,10 @@ type Task struct { numaPolicy int32 numaNodeMask uint64 - // If netns is true, the task is in a non-root network namespace. Network - // namespaces aren't currently implemented in full; being in a network - // namespace simply prevents the task from observing any network devices - // (including loopback) or using abstract socket addresses (see unix(7)). + // netns is the task's network namespace. netns is never nil. // - // netns is protected by mu. netns is owned by the task goroutine. - netns bool + // netns is protected by mu. + netns *inet.Namespace // If rseqPreempted is true, before the next call to p.Switch(), // interrupt rseq critical regions as defined by rseqAddr and @@ -792,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error) return fds[0], nil } +// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2. +// +// This automatically passes the task as the context. +// +// Precondition: same as FDTable.Get. +func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { + return t.fdTable.NewFDVFS2(t, fd, file, flags) +} + // NewFDAt is a convenience wrapper for t.FDTable().NewFDAt. // // This automatically passes the task as the context. @@ -801,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error { return t.fdTable.NewFDAt(t, fd, file, flags) } +// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2. +// +// This automatically passes the task as the context. +// +// Precondition: same as FDTable. +func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error { + return t.fdTable.NewFDAtVFS2(t, fd, file, flags) +} + // WithMuLocked executes f with t.mu locked. func (t *Task) WithMuLocked(f func(*Task)) { t.mu.Lock() diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index ba74b4c1c..78866f280 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -54,8 +55,7 @@ type SharingOptions struct { NewUserNamespace bool // If NewNetworkNamespace is true, the task should have an independent - // network namespace. (Note that network namespaces are not really - // implemented; see comment on Task.netns for details.) + // network namespace. NewNetworkNamespace bool // If NewFiles is true, the task should use an independent file descriptor @@ -199,6 +199,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { ipcns = NewIPCNamespace(userns) } + netns := t.NetworkNamespace() + if opts.NewNetworkNamespace { + netns = inet.NewNamespace(netns) + } + // TODO(b/63601033): Implement CLONE_NEWNS. mntnsVFS2 := t.mountNamespaceVFS2 if mntnsVFS2 != nil { @@ -268,7 +273,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { FDTable: fdTable, Credentials: creds, Niceness: t.Niceness(), - NetworkNamespaced: t.netns, + NetworkNamespace: netns, AllowedCPUMask: t.CPUMask(), UTSNamespace: utsns, IPCNamespace: ipcns, @@ -283,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else { cfg.InheritParent = t } - if opts.NewNetworkNamespace { - cfg.NetworkNamespaced = true - } nt, err := t.tg.pidns.owner.NewTask(cfg) if err != nil { if opts.NewThreadGroup { @@ -482,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error { t.mu.Unlock() return syserror.EPERM } - t.netns = true + t.netns = inet.NewNamespace(t.netns) } if opts.NewUTSNamespace { if !haveCapSysAdmin { diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 2be982684..0158b1788 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -140,7 +140,7 @@ func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*Task } // Prepare a new user address space to load into. - m := mm.NewMemoryManager(k, k) + m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) defer m.DecUsers(ctx) args.MemoryManager = m diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 8f57a34a6..00c425cca 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -220,7 +220,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.mu.Unlock() t.unstopVforkParent() // NOTE(b/30316266): All locks must be dropped prior to calling Activate. - t.MemoryManager().Activate() + t.MemoryManager().Activate(t) t.ptraceExec(oldTID) return (*runSyscallExit)(nil) diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index 6d737d3e5..eeccaa197 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -32,21 +32,21 @@ const ( // Infof logs an formatted info message by calling log.Infof. func (t *Task) Infof(fmt string, v ...interface{}) { if log.IsLogging(log.Info) { - log.Infof(t.logPrefix.Load().(string)+fmt, v...) + log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...) } } // Warningf logs a warning string by calling log.Warningf. func (t *Task) Warningf(fmt string, v ...interface{}) { if log.IsLogging(log.Warning) { - log.Warningf(t.logPrefix.Load().(string)+fmt, v...) + log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...) } } // Debugf creates a debug string that includes the task ID. func (t *Task) Debugf(fmt string, v ...interface{}) { if log.IsLogging(log.Debug) { - log.Debugf(t.logPrefix.Load().(string)+fmt, v...) + log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...) } } diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go index 172a31e1d..f7711232c 100644 --- a/pkg/sentry/kernel/task_net.go +++ b/pkg/sentry/kernel/task_net.go @@ -22,14 +22,23 @@ import ( func (t *Task) IsNetworkNamespaced() bool { t.mu.Lock() defer t.mu.Unlock() - return t.netns + return !t.netns.IsRoot() } // NetworkContext returns the network stack used by the task. NetworkContext // may return nil if no network stack is available. +// +// TODO(gvisor.dev/issue/1833): Migrate callers of this method to +// NetworkNamespace(). func (t *Task) NetworkContext() inet.Stack { - if t.IsNetworkNamespaced() { - return nil - } - return t.k.networkStack + t.mu.Lock() + defer t.mu.Unlock() + return t.netns.Stack() +} + +// NetworkNamespace returns the network namespace observed by the task. +func (t *Task) NetworkNamespace() *inet.Namespace { + t.mu.Lock() + defer t.mu.Unlock() + return t.netns } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index f9236a842..a5035bb7f 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" @@ -65,9 +66,8 @@ type TaskConfig struct { // Niceness is the niceness of the new task. Niceness int - // If NetworkNamespaced is true, the new task should observe a non-root - // network namespace. - NetworkNamespaced bool + // NetworkNamespace is the network namespace to be used for the new task. + NetworkNamespace *inet.Namespace // AllowedCPUMask contains the cpus that this task can run on. AllowedCPUMask sched.CPUSet @@ -133,7 +133,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { allowedCPUMask: cfg.AllowedCPUMask.Copy(), ioUsage: &usage.IO{}, niceness: cfg.Niceness, - netns: cfg.NetworkNamespaced, + netns: cfg.NetworkNamespace, utsns: cfg.UTSNamespace, ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index 2bf3ce8a8..b02044ad2 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown()) // Activate ensures that the task has an active address space. func (t *Task) Activate() { if mm := t.MemoryManager(); mm != nil { - if err := mm.Activate(); err != nil { + if err := mm.Activate(t); err != nil { panic("unable to activate mm: " + err.Error()) } } diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index e58a63deb..0332fc71c 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -18,7 +18,7 @@ import ( "fmt" "sync/atomic" - "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -39,11 +39,18 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace { // // When this MemoryManager is no longer needed by a task, it should call // Deactivate to release the reference. -func (mm *MemoryManager) Activate() error { +func (mm *MemoryManager) Activate(ctx context.Context) error { // Fast path: the MemoryManager already has an active // platform.AddressSpace, and we just need to indicate that we need it too. - if atomicbitops.IncUnlessZeroInt32(&mm.active) { - return nil + for { + active := atomic.LoadInt32(&mm.active) + if active == 0 { + // Fall back to the slow path. + break + } + if atomic.CompareAndSwapInt32(&mm.active, active, active+1) { + return nil + } } for { @@ -85,16 +92,20 @@ func (mm *MemoryManager) Activate() error { if as == nil { // AddressSpace is unavailable, we must wait. // - // activeMu must not be held while waiting, as the user - // of the address space we are waiting on may attempt - // to take activeMu. - // - // Don't call UninterruptibleSleepStart to register the - // wait to allow the watchdog stuck task to trigger in - // case a process is starved waiting for the address - // space. + // activeMu must not be held while waiting, as the user of the address + // space we are waiting on may attempt to take activeMu. mm.activeMu.Unlock() + + sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation + if sleep { + // Mark this task sleeping while waiting for the address space to + // prevent the watchdog from reporting it as a stuck task. + ctx.UninterruptibleSleepStart(false) + } <-c + if sleep { + ctx.UninterruptibleSleepFinish(false) + } continue } @@ -118,8 +129,15 @@ func (mm *MemoryManager) Activate() error { func (mm *MemoryManager) Deactivate() { // Fast path: this is not the last goroutine to deactivate the // MemoryManager. - if atomicbitops.DecUnlessOneInt32(&mm.active) { - return + for { + active := atomic.LoadInt32(&mm.active) + if active == 1 { + // Fall back to the slow path. + break + } + if atomic.CompareAndSwapInt32(&mm.active, active, active-1) { + return + } } mm.activeMu.Lock() diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 47b8fbf43..d8a5b9d29 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -18,7 +18,6 @@ import ( "fmt" "sync/atomic" - "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -29,16 +28,17 @@ import ( ) // NewMemoryManager returns a new MemoryManager with no mappings and 1 user. -func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager { +func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager { return &MemoryManager{ - p: p, - mfp: mfp, - haveASIO: p.SupportsAddressSpaceIO(), - privateRefs: &privateRefs{}, - users: 1, - auxv: arch.Auxv{}, - dumpability: UserDumpable, - aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + p: p, + mfp: mfp, + haveASIO: p.SupportsAddressSpaceIO(), + privateRefs: &privateRefs{}, + users: 1, + auxv: arch.Auxv{}, + dumpability: UserDumpable, + aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + sleepForActivation: sleepForActivation, } } @@ -80,9 +80,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { envv: mm.envv, auxv: append(arch.Auxv(nil), mm.auxv...), // IncRef'd below, once we know that there isn't an error. - executable: mm.executable, - dumpability: mm.dumpability, - aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + executable: mm.executable, + dumpability: mm.dumpability, + aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + sleepForActivation: mm.sleepForActivation, } // Copy vmas. @@ -229,7 +230,15 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { // IncUsers increments mm's user count and returns true. If the user count is // already 0, IncUsers does nothing and returns false. func (mm *MemoryManager) IncUsers() bool { - return atomicbitops.IncUnlessZeroInt32(&mm.users) + for { + users := atomic.LoadInt32(&mm.users) + if users == 0 { + return false + } + if atomic.CompareAndSwapInt32(&mm.users, users, users+1) { + return true + } + } } // DecUsers decrements mm's user count. If the user count reaches 0, all diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 637383c7a..c2195ae11 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -226,6 +226,11 @@ type MemoryManager struct { // aioManager keeps track of AIOContexts used for async IOs. AIOManager // must be cloned when CLONE_VM is used. aioManager aioManager + + // sleepForActivation indicates whether the task should report to be sleeping + // before trying to activate the address space. When set to true, delays in + // activation are not reported as stuck tasks by the watchdog. + sleepForActivation bool } // vma represents a virtual memory area. diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index edacca741..fdc308542 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -31,7 +31,7 @@ import ( func testMemoryManager(ctx context.Context) *MemoryManager { p := platform.FromContext(ctx) mfp := pgalloc.MemoryFileProviderFromContext(ctx) - mm := NewMemoryManager(p, mfp) + mm := NewMemoryManager(p, mfp, false) mm.layout = arch.MmapLayout{ MinAddr: p.MinUserAddress(), MaxAddr: p.MaxUserAddress(), diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 8076c7529..f1afc74dc 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -329,10 +329,12 @@ func (m *machine) Destroy() { } // Get gets an available vCPU. +// +// This will return with the OS thread locked. func (m *machine) Get() *vCPU { + m.mu.RLock() runtime.LockOSThread() tid := procid.Current() - m.mu.RLock() // Check for an exact match. if c := m.vCPUs[tid]; c != nil { @@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU { // The happy path failed. We now proceed to acquire an exclusive lock // (because the vCPU map may change), and scan all available vCPUs. + // In this case, we first unlock the OS thread. Otherwise, if mu is + // not available, the current system thread will be parked and a new + // system thread spawned. We avoid this situation by simply refreshing + // tid after relocking the system thread. m.mu.RUnlock() + runtime.UnlockOSThread() m.mu.Lock() + runtime.LockOSThread() + tid = procid.Current() + + // Recheck for an exact match. + if c := m.vCPUs[tid]; c != nil { + c.lock() + m.mu.Unlock() + return c + } for { // Scan for an available vCPU. diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4667373d2..8834a1e1a 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -329,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { } // PackTClass packs an IPV6_TCLASS socket control message. -func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { +func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IPV6, diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index c91ec7494..7cd2ce55b 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "extensions.go", "netfilter.go", + "targets.go", "tcp_matcher.go", "udp_matcher.go", ], diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 257cb485b..f68a2260d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -248,13 +248,15 @@ func marshalTarget(target iptables.Target) []byte { return marshalStandardTarget(iptables.RuleReturn) case iptables.RedirectTarget: return marshalRedirectTarget() + case JumpTarget: + return marshalJumpTarget(tg) default: panic(fmt.Errorf("unknown target of type %T", target)) } } func marshalStandardTarget(verdict iptables.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget) + nflog("convert to binary: marshalling standard target") // The target's name will be the empty string. target := linux.XTStandardTarget{ @@ -290,8 +292,25 @@ func marshalRedirectTarget() []byte { }, } copy(target.Target.Name[:], redirectTargetName) - + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalJumpTarget(jt JumpTarget) []byte { + nflog("convert to binary: marshalling jump target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + // Verdict is overloaded by the ABI. When positive, it holds + // the jump offset from the start of the table. + Verdict: int32(jt.Offset), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) return binary.Marshal(ret, usermem.ByteOrder, target) } @@ -358,7 +377,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Convert input into a list of rules and their offsets. var offset uint32 - var offsets []uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { nflog("set entries: processing entry at offset %d", offset) @@ -419,11 +439,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { Target: target, Matchers: matchers, }) - offsets = append(offsets, offset) + offsets[offset] = int(entryIdx) offset += uint32(entry.NextOffset) if initialOptValLen-len(optVal) != int(entry.NextOffset) { nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return syserr.ErrInvalidArgument } } @@ -432,13 +453,13 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<<hook) != 0 { hk := hookFromLinux(hook) - for ruleIdx, offset := range offsets { + for offset, ruleIdx := range offsets { if offset == replace.HookEntry[hook] { table.BuiltinChains[hk] = ruleIdx } if offset == replace.Underflow[hook] { if !validUnderflow(table.Rules[ruleIdx]) { - nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP.") + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP") return syserr.ErrInvalidArgument } table.Underflows[hk] = ruleIdx @@ -467,16 +488,35 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // - There's some other rule after it. // - There are no matchers. if ruleIdx == len(table.Rules)-1 { - nflog("user chain must have a rule or default policy.") + nflog("user chain must have a rule or default policy") return syserr.ErrInvalidArgument } if len(table.Rules[ruleIdx].Matchers) != 0 { - nflog("user chain's first node must have no matcheres.") + nflog("user chain's first node must have no matchers") return syserr.ErrInvalidArgument } table.UserChains[target.Name] = ruleIdx + 1 } + // Set each jump to point to the appropriate rule. Right now they hold byte + // offsets. + for ruleIdx, rule := range table.Rules { + jump, ok := rule.Target.(JumpTarget) + if !ok { + continue + } + + // Find the rule corresponding to the jump rule offset. + jumpTo, ok := offsets[jump.Offset] + if !ok { + nflog("failed to find a rule to jump to") + return syserr.ErrInvalidArgument + } + jump.RuleNum = jumpTo + rule.Target = jump + table.Rules[ruleIdx] = rule + } + // TODO(gvisor.dev/issue/170): Support other chains. // Since we only support modifying the INPUT chain and redirect for // PREROUTING chain right now, make sure all other chains point to @@ -572,7 +612,12 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target buf = optVal[:linux.SizeOfXTStandardTarget] binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - return translateToStandardTarget(standardTarget.Verdict) + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict) + } + // A verdict >= 0 indicates a jump. + return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil case errorTargetName: // Error target. diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go new file mode 100644 index 000000000..c421b87cf --- /dev/null +++ b/pkg/sentry/socket/netfilter/targets.go @@ -0,0 +1,35 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/iptables" +) + +// JumpTarget implements iptables.Target. +type JumpTarget struct { + // Offset is the byte offset of the rule to jump to. It is used for + // marshaling and unmarshaling. + Offset uint32 + + // RuleNum is the rule to jump to. + RuleNum int +} + +// Action implements iptables.Target.Action. +func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) { + return iptables.RuleJump, jt.RuleNum +} diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9757fbfba..e187276c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1318,6 +1318,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } return ib, nil + case linux.IPV6_RECVTCLASS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1803,6 +1819,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + case linux.IPV6_RECVTCLASS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + default: emitUnimplementedEventIPv6(t, name) } @@ -2086,7 +2110,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, - linux.IPV6_RECVTCLASS, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, linux.IPV6_TCLASS, @@ -2424,6 +2447,8 @@ func (s *SocketOperations) controlMessages() socket.ControlMessages { Timestamp: s.readCM.Timestamp, HasTOS: s.readCM.HasTOS, TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, HasIPPacketInfo: s.readCM.HasIPPacketInfo, PacketInfo: s.readCM.PacketInfo, }, diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 5afff2564..5f181f017 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -75,6 +75,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in switch protocol { case syscall.IPPROTO_ICMP: return header.ICMPv4ProtocolNumber, true, nil + case syscall.IPPROTO_ICMPV6: + return header.ICMPv6ProtocolNumber, true, nil case syscall.IPPROTO_UDP: return header.UDPProtocolNumber, true, nil case syscall.IPPROTO_TCP: diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 2f39a6f2b..88d5db9fc 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "capability.go", "clone.go", + "epoll.go", "futex.go", "linux64_amd64.go", "linux64_arm64.go", diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go new file mode 100644 index 000000000..a6e48b836 --- /dev/null +++ b/pkg/sentry/strace/epoll.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "fmt" + "strings" + + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/usermem" +) + +func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { + var e linux.EpollEvent + if _, err := t.CopyIn(eventAddr, &e); err != nil { + return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err) + } + var sb strings.Builder + fmt.Fprintf(&sb, "%#x ", eventAddr) + writeEpollEvent(&sb, e) + return sb.String() +} + +func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string { + var sb strings.Builder + fmt.Fprintf(&sb, "%#x {", eventsAddr) + addr := eventsAddr + for i := uint64(0); i < numEvents; i++ { + var e linux.EpollEvent + if _, err := t.CopyIn(addr, &e); err != nil { + fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err) + continue + } + writeEpollEvent(&sb, e) + if uint64(sb.Len()) >= maxBytes { + sb.WriteString("...") + break + } + if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok { + fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr) + continue + } + } + sb.WriteString("}") + return sb.String() +} + +func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) { + events := epollEventEvents.Parse(uint64(e.Events)) + fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1]) +} + +var epollCtlOps = abi.ValueSet{ + linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD", + linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL", + linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD", +} + +var epollEventEvents = abi.FlagSet{ + {Flag: linux.EPOLLIN, Name: "EPOLLIN"}, + {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"}, + {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"}, + {Flag: linux.EPOLLERR, Name: "EPOLLERR"}, + {Flag: linux.EPOLLHUP, Name: "EPULLHUP"}, + {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"}, + {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"}, + {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"}, + {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"}, + {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"}, + {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"}, + {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"}, + {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"}, + {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"}, + {Flag: linux.EPOLLET, Name: "EPOLLET"}, +} diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index a4de545e9..71b92eaee 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{ 229: makeSyscallInfo("clock_getres", Hex, PostTimespec), 230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec), 231: makeSyscallInfo("exit_group", Hex), - 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex), - 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex), + 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex), + 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent), 234: makeSyscallInfo("tgkill", Hex, Hex, Signal), 235: makeSyscallInfo("utimes", Path, Timeval), // 236: vserver (not implemented in the Linux kernel) @@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{ 278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex), 279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex), 280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex), - 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex), + 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex), 282: makeSyscallInfo("signalfd", Hex, Hex, Hex), 283: makeSyscallInfo("timerfd_create", Hex, Hex), 284: makeSyscallInfo("eventfd", Hex), diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index 8bc38545f..bd7361a52 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex), 19: makeSyscallInfo("eventfd2", Hex, Hex), 20: makeSyscallInfo("epoll_create1", Hex), - 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex), - 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex), + 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent), + 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex), 23: makeSyscallInfo("dup", FD), 24: makeSyscallInfo("dup3", FD, FD, Hex), 25: makeSyscallInfo("fcntl", FD, Hex, Hex), diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 46cb2a1cc..77655558e 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -481,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer())) case PollFDs: output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false)) + case EpollCtlOp: + output = append(output, epollCtlOps.Parse(uint64(args[arg].Int()))) + case EpollEvent: + output = append(output, epollEvent(t, args[arg].Pointer())) + case EpollEvents: + output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize))) case SelectFDSet: output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) case Oct: @@ -549,6 +555,8 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer()) case PollFDs: output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true) + case EpollEvents: + output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize)) case GetSockOptVal: output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval) case SetSockOptVal: diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 446d1e0f6..7e69b9279 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -228,6 +228,16 @@ const ( // SockOptLevel is the optname argument in getsockopt(2) and // setsockopt(2). SockOptName + + // EpollCtlOp is the op argument to epoll_ctl(2). + EpollCtlOp + + // EpollEvent is the event argument in epoll_ctl(2). + EpollEvent + + // EpollEvents is an array of struct epoll_event. It is the events + // argument in epoll_wait(2)/epoll_pwait(2). + EpollEvents ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index fbef5b376..3ab93fbde 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + // EpollCreate1 implements the epoll_create1(2) linux syscall. func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { flags := args[0].Int() @@ -164,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } + +// LINT.ThenChange(vfs2/epoll.go) diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 421845ebb..c21f14dc0 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string return path, dirPath, nil } +// LINT.IfChange + func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode) } +// LINT.ThenChange(vfs2/filesystem.go) + +// LINT.IfChange + // Ioctl implements linux syscall ioctl(2). func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } +// LINT.ThenChange(vfs2/ioctl.go) + +// LINT.IfChange + // Getcwd implements the linux syscall getcwd(2). func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() @@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, nil } +// LINT.ThenChange(vfs2/fscontext.go) + +// LINT.IfChange + // Close implements linux syscall close(2). func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } +// LINT.ThenChange(vfs2/fd.go) + const ( _FADV_NORMAL = 0 _FADV_RANDOM = 1 @@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.IfChange + func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty) } +// LINT.ThenChange(vfs2/filesystem.go) + +// LINT.IfChange + func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return n, nil, err } +// LINT.ThenChange(vfs2/stat.go) + +// LINT.IfChange + func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, unlinkAt(t, dirFD, addr) } +// LINT.ThenChange(vfs2/filesystem.go) + +// LINT.IfChange + // Truncate implements linux syscall truncate(2). func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() @@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.ThenChange(vfs2/setstat.go) + // Umask implements linux syscall umask(2). func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { mask := args[0].ModeT() @@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(mask), nil, nil } +// LINT.IfChange + // Change ownership of a file. // // uid and gid may be -1, in which case they will not be changed. @@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true) } +// LINT.ThenChange(vfs2/setstat.go) + +// LINT.IfChange + func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error { newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */) if err != nil { @@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr) } +// LINT.ThenChange(vfs2/filesystem.go) + // Fallocate implements linux system call fallocate(2). func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index f66f4ffde..b126fecc0 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -27,6 +27,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + // Getdents implements linux syscall getdents(2) for 64bit systems. func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error { func (ds *direntSerializer) Written() int { return ds.written } + +// LINT.ThenChange(vfs2/getdents.go) diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index 297e920c4..3f7691eae 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -21,6 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LINT.IfChange + // Lseek implements linux syscall lseek(2). func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } return uintptr(offset), nil, err } + +// LINT.ThenChange(vfs2/read_write.go) diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 9959f6e61..91694d374 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return uintptr(addr), nil, nil } +// LINT.IfChange + // Mmap implements linux syscall mmap(2). func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { prot := args[2].Int() @@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(rv), nil, err } +// LINT.ThenChange(vfs2/mmap.go) + // Munmap implements linux syscall munmap(2). func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64()) diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 227692f06..78a2cb750 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -28,6 +28,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + const ( // EventMaskRead contains events that can be triggered on reads. EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr @@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i return total, err } + +// LINT.ThenChange(vfs2/read_write.go) diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 8b66a9006..701b27b4a 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -23,6 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat { return linux.Stat{ Dev: sattr.DeviceID, @@ -131,8 +133,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err return err } s := statFromAttrs(t, d.Inode.StableAttr, uattr) - _, err = s.CopyOut(t, statAddr) - return err + return s.CopyOut(t, statAddr) } // fstat implements fstat for the given *fs.File. @@ -142,8 +143,7 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error { return err } s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr) - _, err = s.CopyOut(t, statAddr) - return err + return s.CopyOut(t, statAddr) } // Statx implements linux syscall statx(2). @@ -299,3 +299,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error { _, err = t.CopyOut(addr, &statfs) return err } + +// LINT.ThenChange(vfs2/stat.go) diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index 3e55235bd..5ad465ae3 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -22,6 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LINT.IfChange + // Sync implements linux system call sync(2). func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { t.MountNamespace().SyncAll(t) @@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) } + +// LINT.ThenChange(vfs2/sync.go) diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index aba892939..506ee54ce 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -28,6 +28,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + const ( // EventMaskWrite contains events that can be triggered on writes. // @@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) ( return total, err } + +// LINT.ThenChange(vfs2/read_write.go) diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 9d8140b8a..2de5e3422 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + // GetXattr implements linux syscall getxattr(2). func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return getXattrFromPath(t, args, true) @@ -418,3 +420,5 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error { return d.Inode.RemoveXattr(t, d, name) } + +// LINT.ThenChange(vfs2/xattr.go) diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 6b8a00b6e..f51761e81 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -5,18 +5,44 @@ package(licenses = ["notice"]) go_library( name = "vfs2", srcs = [ + "epoll.go", + "epoll_unsafe.go", + "execve.go", + "fd.go", + "filesystem.go", + "fscontext.go", + "getdents.go", + "ioctl.go", "linux64.go", "linux64_override_amd64.go", "linux64_override_arm64.go", - "sys_read.go", + "mmap.go", + "path.go", + "poll.go", + "read_write.go", + "setstat.go", + "stat.go", + "sync.go", + "xattr.go", ], + marshal = True, visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", + "//pkg/fspath", + "//pkg/gohacks", "//pkg/sentry/arch", + "//pkg/sentry/fsbridge", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", + "//pkg/sentry/limits", + "//pkg/sentry/loader", + "//pkg/sentry/memmap", "//pkg/sentry/syscalls", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go new file mode 100644 index 000000000..d6cb0e79a --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -0,0 +1,225 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "math" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// EpollCreate1 implements Linux syscall epoll_create1(2). +func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + flags := args[0].Int() + if flags&^linux.EPOLL_CLOEXEC != 0 { + return 0, nil, syserror.EINVAL + } + + file, err := t.Kernel().VFS().NewEpollInstanceFD() + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil +} + +// EpollCreate implements Linux syscall epoll_create(2). +func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + size := args[0].Int() + + // "Since Linux 2.6.8, the size argument is ignored, but must be greater + // than zero" - epoll_create(2) + if size <= 0 { + return 0, nil, syserror.EINVAL + } + + file, err := t.Kernel().VFS().NewEpollInstanceFD() + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil +} + +// EpollCtl implements Linux syscall epoll_ctl(2). +func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + op := args[1].Int() + fd := args[2].Int() + eventAddr := args[3].Pointer() + + epfile := t.GetFileVFS2(epfd) + if epfile == nil { + return 0, nil, syserror.EBADF + } + defer epfile.DecRef() + ep, ok := epfile.Impl().(*vfs.EpollInstance) + if !ok { + return 0, nil, syserror.EINVAL + } + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + if epfile == file { + return 0, nil, syserror.EINVAL + } + + var event linux.EpollEvent + switch op { + case linux.EPOLL_CTL_ADD: + if err := event.CopyIn(t, eventAddr); err != nil { + return 0, nil, err + } + return 0, nil, ep.AddInterest(file, fd, event) + case linux.EPOLL_CTL_DEL: + return 0, nil, ep.DeleteInterest(file, fd) + case linux.EPOLL_CTL_MOD: + if err := event.CopyIn(t, eventAddr); err != nil { + return 0, nil, err + } + return 0, nil, ep.ModifyInterest(file, fd, event) + default: + return 0, nil, syserror.EINVAL + } +} + +// EpollWait implements Linux syscall epoll_wait(2). +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeout := int(args[3].Int()) + + const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS + if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { + return 0, nil, syserror.EINVAL + } + + epfile := t.GetFileVFS2(epfd) + if epfile == nil { + return 0, nil, syserror.EBADF + } + defer epfile.DecRef() + ep, ok := epfile.Impl().(*vfs.EpollInstance) + if !ok { + return 0, nil, syserror.EINVAL + } + + // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent, + // maxEvents), so that the buffer can be allocated on the stack. + var ( + events [16]linux.EpollEvent + total int + ch chan struct{} + haveDeadline bool + deadline ktime.Time + ) + for { + batchEvents := len(events) + if batchEvents > maxEvents { + batchEvents = maxEvents + } + n := ep.ReadEvents(events[:batchEvents]) + maxEvents -= n + if n != 0 { + // Copy what we read out. + copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n]) + eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent) + total += copiedEvents + if err != nil { + if total != 0 { + return uintptr(total), nil, nil + } + return 0, nil, err + } + // If we've filled the application's event buffer, we're done. + if maxEvents == 0 { + return uintptr(total), nil, nil + } + // Loop if we read a full batch, under the expectation that there + // may be more events to read. + if n == batchEvents { + continue + } + } + // We get here if n != batchEvents. If we read any number of events + // (just now, or in a previous iteration of this loop), or if timeout + // is 0 (such that epoll_wait should be non-blocking), return the + // events we've read so far to the application. + if total != 0 || timeout == 0 { + return uintptr(total), nil, nil + } + // In the first iteration of this loop, register with the epoll + // instance for readability events, but then immediately continue the + // loop since we need to retry ReadEvents() before blocking. In all + // subsequent iterations, block until events are available, the timeout + // expires, or an interrupt arrives. + if ch == nil { + var w waiter.Entry + w, ch = waiter.NewChannelEntry(nil) + epfile.EventRegister(&w, waiter.EventIn) + defer epfile.EventUnregister(&w) + } else { + // Set up the timer if a timeout was specified. + if timeout > 0 && !haveDeadline { + timeoutDur := time.Duration(timeout) * time.Millisecond + deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) + haveDeadline = true + } + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + err = nil + } + // total must be 0 since otherwise we would have returned + // above. + return 0, nil, err + } + } + } +} + +// EpollPwait implements Linux syscall epoll_pwait(2). +func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + return EpollWait(t, args) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go new file mode 100644 index 000000000..825f325bf --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "reflect" + "runtime" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/usermem" +) + +const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{})) + +func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) { + if len(events) == 0 { + return 0, nil + } + // Cast events to a byte slice for copying. + var eventBytes []byte + eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes)) + eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0]))) + eventBytesHdr.Len = len(events) * sizeofEpollEvent + eventBytesHdr.Cap = len(events) * sizeofEpollEvent + copiedBytes, err := t.CopyOutBytes(addr, eventBytes) + runtime.KeepAlive(events) + copiedEvents := copiedBytes / sizeofEpollEvent // rounded down + return copiedEvents, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go new file mode 100644 index 000000000..aef0078a8 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -0,0 +1,137 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsbridge" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/loader" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Execve implements linux syscall execve(2). +func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathnameAddr := args[0].Pointer() + argvAddr := args[1].Pointer() + envvAddr := args[2].Pointer() + return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */) +} + +// Execveat implements linux syscall execveat(2). +func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathnameAddr := args[1].Pointer() + argvAddr := args[2].Pointer() + envvAddr := args[3].Pointer() + flags := args[4].Int() + return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags) +} + +func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + + pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) + if err != nil { + return 0, nil, err + } + var argv, envv []string + if argvAddr != 0 { + var err error + argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize) + if err != nil { + return 0, nil, err + } + } + if envvAddr != 0 { + var err error + envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize) + if err != nil { + return 0, nil, err + } + } + + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + var executable fsbridge.File + closeOnExec := false + if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute { + // We must open the executable ourselves since dirfd is used as the + // starting point while resolving path, but the task working directory + // is used as the starting point while resolving interpreters (Linux: + // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() => + // do_open_execat(fd=AT_FDCWD)), and the loader package is currently + // incapable of handling this correctly. + if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { + return 0, nil, syserror.ENOENT + } + dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd) + if dirfile == nil { + return 0, nil, syserror.EBADF + } + start := dirfile.VirtualDentry() + start.IncRef() + dirfile.DecRef() + closeOnExec = dirfileFlags.CloseOnExec + file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + FileExec: true, + }) + start.DecRef() + if err != nil { + return 0, nil, err + } + defer file.DecRef() + executable = fsbridge.NewVFSFile(file) + } + + // Load the new TaskContext. + mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change + defer mntns.DecRef() + wd := t.FSContext().WorkingDirectoryVFS2() + defer wd.DecRef() + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Opener: fsbridge.NewVFSLookup(mntns, root, wd), + RemainingTraversals: &remainingTraversals, + ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) + if se != nil { + return 0, nil, se.ToError() + } + + ctrl, err := t.Execve(tc) + return 0, ctrl, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go new file mode 100644 index 000000000..3afcea665 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -0,0 +1,147 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Close implements Linux syscall close(2). +func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + // Note that Remove provides a reference on the file that we may use to + // flush. It is still active until we drop the final reference below + // (and other reference-holding operations complete). + _, file := t.FDTable().Remove(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + err := file.OnClose(t) + return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file) +} + +// Dup implements Linux syscall dup(2). +func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) + if err != nil { + return 0, nil, syserror.EMFILE + } + return uintptr(newFD), nil, nil +} + +// Dup2 implements Linux syscall dup2(2). +func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldfd := args[0].Int() + newfd := args[1].Int() + + if oldfd == newfd { + // As long as oldfd is valid, dup2() does nothing and returns newfd. + file := t.GetFileVFS2(oldfd) + if file == nil { + return 0, nil, syserror.EBADF + } + file.DecRef() + return uintptr(newfd), nil, nil + } + + return dup3(t, oldfd, newfd, 0) +} + +// Dup3 implements Linux syscall dup3(2). +func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldfd := args[0].Int() + newfd := args[1].Int() + flags := args[2].Uint() + + if oldfd == newfd { + return 0, nil, syserror.EINVAL + } + + return dup3(t, oldfd, newfd, flags) +} + +func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) { + if flags&^linux.O_CLOEXEC != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(oldfd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{ + CloseOnExec: flags&linux.O_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + return uintptr(newfd), nil, nil +} + +// Fcntl implements linux syscall fcntl(2). +func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + cmd := args[1].Int() + + file, flags := t.FDTable().GetVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + switch cmd { + case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: + minfd := args[2].Int() + fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{ + CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC, + }) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil + case linux.F_GETFD: + return uintptr(flags.ToLinuxFDFlags()), nil, nil + case linux.F_SETFD: + flags := args[2].Uint() + t.FDTable().SetFlags(fd, kernel.FDFlags{ + CloseOnExec: flags&linux.FD_CLOEXEC != 0, + }) + return 0, nil, nil + case linux.F_GETFL: + return uintptr(file.StatusFlags()), nil, nil + case linux.F_SETFL: + return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint()) + default: + // TODO(gvisor.dev/issue/1623): Everything else is not yet supported. + return 0, nil, syserror.EINVAL + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go new file mode 100644 index 000000000..fc5ceea4c --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -0,0 +1,326 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Link implements Linux syscall link(2). +func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldpathAddr := args[0].Pointer() + newpathAddr := args[1].Pointer() + return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */) +} + +// Linkat implements Linux syscall linkat(2). +func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + olddirfd := args[0].Int() + oldpathAddr := args[1].Pointer() + newdirfd := args[2].Int() + newpathAddr := args[3].Pointer() + flags := args[4].Int() + return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags) +} + +func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 { + return syserror.EINVAL + } + if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) { + return syserror.ENOENT + } + + oldpath, err := copyInPath(t, oldpathAddr) + if err != nil { + return err + } + oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0)) + if err != nil { + return err + } + defer oldtpop.Release() + + newpath, err := copyInPath(t, newpathAddr) + if err != nil { + return err + } + newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer newtpop.Release() + + return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop) +} + +// Mkdir implements Linux syscall mkdir(2). +func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + mode := args[1].ModeT() + return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode) +} + +// Mkdirat implements Linux syscall mkdirat(2). +func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + addr := args[1].Pointer() + mode := args[2].ModeT() + return 0, nil, mkdirat(t, dirfd, addr, mode) +} + +func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error { + path, err := copyInPath(t, addr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{ + Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()), + }) +} + +// Mknod implements Linux syscall mknod(2). +func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + mode := args[1].ModeT() + dev := args[2].Uint() + return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev) +} + +// Mknodat implements Linux syscall mknodat(2). +func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + addr := args[1].Pointer() + mode := args[2].ModeT() + dev := args[3].Uint() + return 0, nil, mknodat(t, dirfd, addr, mode, dev) +} + +func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error { + path, err := copyInPath(t, addr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + major, minor := linux.DecodeDeviceID(dev) + return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{ + Mode: linux.FileMode(mode &^ t.FSContext().Umask()), + DevMajor: uint32(major), + DevMinor: minor, + }) +} + +// Open implements Linux syscall open(2). +func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + flags := args[1].Uint() + mode := args[2].ModeT() + return openat(t, linux.AT_FDCWD, addr, flags, mode) +} + +// Openat implements Linux syscall openat(2). +func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + addr := args[1].Pointer() + flags := args[2].Uint() + mode := args[3].ModeT() + return openat(t, dirfd, addr, flags, mode) +} + +// Creat implements Linux syscall creat(2). +func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + mode := args[1].ModeT() + return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode) +} + +func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) { + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0)) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{ + Flags: flags, + Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()), + }) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.O_CLOEXEC != 0, + }) + return uintptr(fd), nil, err +} + +// Rename implements Linux syscall rename(2). +func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldpathAddr := args[0].Pointer() + newpathAddr := args[1].Pointer() + return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */) +} + +// Renameat implements Linux syscall renameat(2). +func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + olddirfd := args[0].Int() + oldpathAddr := args[1].Pointer() + newdirfd := args[2].Int() + newpathAddr := args[3].Pointer() + return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */) +} + +// Renameat2 implements Linux syscall renameat2(2). +func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + olddirfd := args[0].Int() + oldpathAddr := args[1].Pointer() + newdirfd := args[2].Int() + newpathAddr := args[3].Pointer() + flags := args[4].Uint() + return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags) +} + +func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error { + oldpath, err := copyInPath(t, oldpathAddr) + if err != nil { + return err + } + // "If oldpath refers to a symbolic link, the link is renamed" - rename(2) + oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer oldtpop.Release() + + newpath, err := copyInPath(t, newpathAddr) + if err != nil { + return err + } + newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer newtpop.Release() + + return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{ + Flags: flags, + }) +} + +// Rmdir implements Linux syscall rmdir(2). +func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr) +} + +func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop) +} + +// Unlink implements Linux syscall unlink(2). +func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr) +} + +func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop) +} + +// Unlinkat implements Linux syscall unlinkat(2). +func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + flags := args[2].Int() + + if flags&^linux.AT_REMOVEDIR != 0 { + return 0, nil, syserror.EINVAL + } + + if flags&linux.AT_REMOVEDIR != 0 { + return 0, nil, rmdirat(t, dirfd, pathAddr) + } + return 0, nil, unlinkat(t, dirfd, pathAddr) +} + +// Symlink implements Linux syscall symlink(2). +func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + targetAddr := args[0].Pointer() + linkpathAddr := args[1].Pointer() + return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr) +} + +// Symlinkat implements Linux syscall symlinkat(2). +func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + targetAddr := args[0].Pointer() + newdirfd := args[1].Int() + linkpathAddr := args[2].Pointer() + return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr) +} + +func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error { + target, err := t.CopyInString(targetAddr, linux.PATH_MAX) + if err != nil { + return err + } + linkpath, err := copyInPath(t, linkpathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go new file mode 100644 index 000000000..317409a18 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Getcwd implements Linux syscall getcwd(2). +func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + size := args[1].SizeT() + + root := t.FSContext().RootDirectoryVFS2() + wd := t.FSContext().WorkingDirectoryVFS2() + s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd) + root.DecRef() + wd.DecRef() + if err != nil { + return 0, nil, err + } + + // Note this is >= because we need a terminator. + if uint(len(s)) >= size { + return 0, nil, syserror.ERANGE + } + + // Construct a byte slice containing a NUL terminator. + buf := t.CopyScratchBuffer(len(s) + 1) + copy(buf, s) + buf[len(buf)-1] = 0 + + // Write the pathname slice. + n, err := t.CopyOutBytes(addr, buf) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Chdir implements Linux syscall chdir(2). +func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + return 0, nil, err + } + t.FSContext().SetWorkingDirectoryVFS2(vd) + vd.DecRef() + return 0, nil, nil +} + +// Fchdir implements Linux syscall fchdir(2). +func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + return 0, nil, err + } + t.FSContext().SetWorkingDirectoryVFS2(vd) + vd.DecRef() + return 0, nil, nil +} + +// Chroot implements Linux syscall chroot(2). +func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + + if !t.HasCapability(linux.CAP_SYS_CHROOT) { + return 0, nil, syserror.EPERM + } + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + return 0, nil, err + } + t.FSContext().SetRootDirectoryVFS2(vd) + vd.DecRef() + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go new file mode 100644 index 000000000..ddc140b65 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Getdents implements Linux syscall getdents(2). +func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getdents(t, args, false /* isGetdents64 */) +} + +// Getdents64 implements Linux syscall getdents64(2). +func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getdents(t, args, true /* isGetdents64 */) +} + +func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := int(args[2].Uint()) + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + cb := getGetdentsCallback(t, addr, size, isGetdents64) + err := file.IterDirents(t, cb) + n := size - cb.remaining + putGetdentsCallback(cb) + if n == 0 { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +type getdentsCallback struct { + t *kernel.Task + addr usermem.Addr + remaining int + isGetdents64 bool +} + +var getdentsCallbackPool = sync.Pool{ + New: func() interface{} { + return &getdentsCallback{} + }, +} + +func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback { + cb := getdentsCallbackPool.Get().(*getdentsCallback) + *cb = getdentsCallback{ + t: t, + addr: addr, + remaining: size, + isGetdents64: isGetdents64, + } + return cb +} + +func putGetdentsCallback(cb *getdentsCallback) { + cb.t = nil + getdentsCallbackPool.Put(cb) +} + +// Handle implements vfs.IterDirentsCallback.Handle. +func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { + var buf []byte + if cb.isGetdents64 { + // struct linux_dirent64 { + // ino64_t d_ino; /* 64-bit inode number */ + // off64_t d_off; /* 64-bit offset to next structure */ + // unsigned short d_reclen; /* Size of this dirent */ + // unsigned char d_type; /* File type */ + // char d_name[]; /* Filename (null-terminated) */ + // }; + size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name) + if size < cb.remaining { + return syserror.EINVAL + } + buf = cb.t.CopyScratchBuffer(size) + usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino) + usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) + usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) + buf[18] = dirent.Type + copy(buf[19:], dirent.Name) + buf[size-1] = 0 // NUL terminator + } else { + // struct linux_dirent { + // unsigned long d_ino; /* Inode number */ + // unsigned long d_off; /* Offset to next linux_dirent */ + // unsigned short d_reclen; /* Length of this linux_dirent */ + // char d_name[]; /* Filename (null-terminated) */ + // /* length is actually (d_reclen - 2 - + // offsetof(struct linux_dirent, d_name)) */ + // /* + // char pad; // Zero padding byte + // char d_type; // File type (only since Linux + // // 2.6.4); offset is (d_reclen - 1) + // */ + // }; + if cb.t.Arch().Width() != 8 { + panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width())) + } + size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name) + if size < cb.remaining { + return syserror.EINVAL + } + buf = cb.t.CopyScratchBuffer(size) + usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino) + usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) + usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) + copy(buf[18:], dirent.Name) + buf[size-3] = 0 // NUL terminator + buf[size-2] = 0 // zero padding byte + buf[size-1] = dirent.Type + } + n, err := cb.t.CopyOutBytes(cb.addr, buf) + if err != nil { + // Don't report partially-written dirents by advancing cb.addr or + // cb.remaining. + return err + } + cb.addr += usermem.Addr(n) + cb.remaining -= n + return nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go new file mode 100644 index 000000000..5a2418da9 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -0,0 +1,35 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Ioctl implements Linux syscall ioctl(2). +func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + ret, err := file.Ioctl(t, t.MemoryManager(), args) + return ret, nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go index e0ac32b33..7d220bc20 100644 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build amd64 + package vfs2 import ( @@ -22,110 +24,142 @@ import ( // Override syscall table to add syscalls implementations from this package. func Override(table map[uintptr]kernel.Syscall) { table[0] = syscalls.Supported("read", Read) - - // Remove syscalls that haven't been converted yet. It's better to get ENOSYS - // rather than a SIGSEGV deep in the stack. - delete(table, 1) // write - delete(table, 2) // open - delete(table, 3) // close - delete(table, 4) // stat - delete(table, 5) // fstat - delete(table, 6) // lstat - delete(table, 7) // poll - delete(table, 8) // lseek - delete(table, 9) // mmap - delete(table, 16) // ioctl - delete(table, 17) // pread64 - delete(table, 18) // pwrite64 - delete(table, 19) // readv - delete(table, 20) // writev - delete(table, 21) // access - delete(table, 22) // pipe - delete(table, 32) // dup - delete(table, 33) // dup2 - delete(table, 40) // sendfile - delete(table, 59) // execve - delete(table, 72) // fcntl - delete(table, 73) // flock - delete(table, 74) // fsync - delete(table, 75) // fdatasync - delete(table, 76) // truncate - delete(table, 77) // ftruncate - delete(table, 78) // getdents - delete(table, 79) // getcwd - delete(table, 80) // chdir - delete(table, 81) // fchdir - delete(table, 82) // rename - delete(table, 83) // mkdir - delete(table, 84) // rmdir - delete(table, 85) // creat - delete(table, 86) // link - delete(table, 87) // unlink - delete(table, 88) // symlink - delete(table, 89) // readlink - delete(table, 90) // chmod - delete(table, 91) // fchmod - delete(table, 92) // chown - delete(table, 93) // fchown - delete(table, 94) // lchown - delete(table, 133) // mknod - delete(table, 137) // statfs - delete(table, 138) // fstatfs - delete(table, 161) // chroot - delete(table, 162) // sync + table[1] = syscalls.Supported("write", Write) + table[2] = syscalls.Supported("open", Open) + table[3] = syscalls.Supported("close", Close) + table[4] = syscalls.Supported("stat", Stat) + table[5] = syscalls.Supported("fstat", Fstat) + table[6] = syscalls.Supported("lstat", Lstat) + table[7] = syscalls.Supported("poll", Poll) + table[8] = syscalls.Supported("lseek", Lseek) + table[9] = syscalls.Supported("mmap", Mmap) + table[16] = syscalls.Supported("ioctl", Ioctl) + table[17] = syscalls.Supported("pread64", Pread64) + table[18] = syscalls.Supported("pwrite64", Pwrite64) + table[19] = syscalls.Supported("readv", Readv) + table[20] = syscalls.Supported("writev", Writev) + table[21] = syscalls.Supported("access", Access) + delete(table, 22) // pipe + table[23] = syscalls.Supported("select", Select) + table[32] = syscalls.Supported("dup", Dup) + table[33] = syscalls.Supported("dup2", Dup2) + delete(table, 40) // sendfile + delete(table, 41) // socket + delete(table, 42) // connect + delete(table, 43) // accept + delete(table, 44) // sendto + delete(table, 45) // recvfrom + delete(table, 46) // sendmsg + delete(table, 47) // recvmsg + delete(table, 48) // shutdown + delete(table, 49) // bind + delete(table, 50) // listen + delete(table, 51) // getsockname + delete(table, 52) // getpeername + delete(table, 53) // socketpair + delete(table, 54) // setsockopt + delete(table, 55) // getsockopt + table[59] = syscalls.Supported("execve", Execve) + table[72] = syscalls.Supported("fcntl", Fcntl) + delete(table, 73) // flock + table[74] = syscalls.Supported("fsync", Fsync) + table[75] = syscalls.Supported("fdatasync", Fdatasync) + table[76] = syscalls.Supported("truncate", Truncate) + table[77] = syscalls.Supported("ftruncate", Ftruncate) + table[78] = syscalls.Supported("getdents", Getdents) + table[79] = syscalls.Supported("getcwd", Getcwd) + table[80] = syscalls.Supported("chdir", Chdir) + table[81] = syscalls.Supported("fchdir", Fchdir) + table[82] = syscalls.Supported("rename", Rename) + table[83] = syscalls.Supported("mkdir", Mkdir) + table[84] = syscalls.Supported("rmdir", Rmdir) + table[85] = syscalls.Supported("creat", Creat) + table[86] = syscalls.Supported("link", Link) + table[87] = syscalls.Supported("unlink", Unlink) + table[88] = syscalls.Supported("symlink", Symlink) + table[89] = syscalls.Supported("readlink", Readlink) + table[90] = syscalls.Supported("chmod", Chmod) + table[91] = syscalls.Supported("fchmod", Fchmod) + table[92] = syscalls.Supported("chown", Chown) + table[93] = syscalls.Supported("fchown", Fchown) + table[94] = syscalls.Supported("lchown", Lchown) + table[132] = syscalls.Supported("utime", Utime) + table[133] = syscalls.Supported("mknod", Mknod) + table[137] = syscalls.Supported("statfs", Statfs) + table[138] = syscalls.Supported("fstatfs", Fstatfs) + table[161] = syscalls.Supported("chroot", Chroot) + table[162] = syscalls.Supported("sync", Sync) delete(table, 165) // mount delete(table, 166) // umount2 - delete(table, 172) // iopl - delete(table, 173) // ioperm delete(table, 187) // readahead - delete(table, 188) // setxattr - delete(table, 189) // lsetxattr - delete(table, 190) // fsetxattr - delete(table, 191) // getxattr - delete(table, 192) // lgetxattr - delete(table, 193) // fgetxattr + table[188] = syscalls.Supported("setxattr", Setxattr) + table[189] = syscalls.Supported("lsetxattr", Lsetxattr) + table[190] = syscalls.Supported("fsetxattr", Fsetxattr) + table[191] = syscalls.Supported("getxattr", Getxattr) + table[192] = syscalls.Supported("lgetxattr", Lgetxattr) + table[193] = syscalls.Supported("fgetxattr", Fgetxattr) + table[194] = syscalls.Supported("listxattr", Listxattr) + table[195] = syscalls.Supported("llistxattr", Llistxattr) + table[196] = syscalls.Supported("flistxattr", Flistxattr) + table[197] = syscalls.Supported("removexattr", Removexattr) + table[198] = syscalls.Supported("lremovexattr", Lremovexattr) + table[199] = syscalls.Supported("fremovexattr", Fremovexattr) delete(table, 206) // io_setup delete(table, 207) // io_destroy delete(table, 208) // io_getevents delete(table, 209) // io_submit delete(table, 210) // io_cancel - delete(table, 213) // epoll_create - delete(table, 214) // epoll_ctl_old - delete(table, 215) // epoll_wait_old - delete(table, 216) // remap_file_pages - delete(table, 217) // getdents64 - delete(table, 232) // epoll_wait - delete(table, 233) // epoll_ctl + table[213] = syscalls.Supported("epoll_create", EpollCreate) + table[217] = syscalls.Supported("getdents64", Getdents64) + delete(table, 221) // fdavise64 + table[232] = syscalls.Supported("epoll_wait", EpollWait) + table[233] = syscalls.Supported("epoll_ctl", EpollCtl) + table[235] = syscalls.Supported("utimes", Utimes) delete(table, 253) // inotify_init delete(table, 254) // inotify_add_watch delete(table, 255) // inotify_rm_watch - delete(table, 257) // openat - delete(table, 258) // mkdirat - delete(table, 259) // mknodat - delete(table, 260) // fchownat - delete(table, 261) // futimesat - delete(table, 262) // fstatat - delete(table, 263) // unlinkat - delete(table, 264) // renameat - delete(table, 265) // linkat - delete(table, 266) // symlinkat - delete(table, 267) // readlinkat - delete(table, 268) // fchmodat - delete(table, 269) // faccessat - delete(table, 270) // pselect - delete(table, 271) // ppoll + table[257] = syscalls.Supported("openat", Openat) + table[258] = syscalls.Supported("mkdirat", Mkdirat) + table[259] = syscalls.Supported("mknodat", Mknodat) + table[260] = syscalls.Supported("fchownat", Fchownat) + table[261] = syscalls.Supported("futimens", Futimens) + table[262] = syscalls.Supported("newfstatat", Newfstatat) + table[263] = syscalls.Supported("unlinkat", Unlinkat) + table[264] = syscalls.Supported("renameat", Renameat) + table[265] = syscalls.Supported("linkat", Linkat) + table[266] = syscalls.Supported("symlinkat", Symlinkat) + table[267] = syscalls.Supported("readlinkat", Readlinkat) + table[268] = syscalls.Supported("fchmodat", Fchmodat) + table[269] = syscalls.Supported("faccessat", Faccessat) + table[270] = syscalls.Supported("pselect", Pselect) + table[271] = syscalls.Supported("ppoll", Ppoll) + delete(table, 275) // splice + delete(table, 276) // tee + table[277] = syscalls.Supported("sync_file_range", SyncFileRange) + table[280] = syscalls.Supported("utimensat", Utimensat) + table[281] = syscalls.Supported("epoll_pwait", EpollPwait) + delete(table, 282) // signalfd + delete(table, 283) // timerfd_create + delete(table, 284) // eventfd delete(table, 285) // fallocate - delete(table, 291) // epoll_create1 - delete(table, 292) // dup3 + delete(table, 286) // timerfd_settime + delete(table, 287) // timerfd_gettime + delete(table, 288) // accept4 + delete(table, 289) // signalfd4 + delete(table, 290) // eventfd2 + table[291] = syscalls.Supported("epoll_create1", EpollCreate1) + table[292] = syscalls.Supported("dup3", Dup3) delete(table, 293) // pipe2 delete(table, 294) // inotify_init1 - delete(table, 295) // preadv - delete(table, 296) // pwritev - delete(table, 306) // syncfs - delete(table, 316) // renameat2 + table[295] = syscalls.Supported("preadv", Preadv) + table[296] = syscalls.Supported("pwritev", Pwritev) + delete(table, 299) // recvmmsg + table[306] = syscalls.Supported("syncfs", Syncfs) + delete(table, 307) // sendmmsg + table[316] = syscalls.Supported("renameat2", Renameat2) delete(table, 319) // memfd_create - delete(table, 322) // execveat - delete(table, 327) // preadv2 - delete(table, 328) // pwritev2 - delete(table, 332) // statx + table[322] = syscalls.Supported("execveat", Execveat) + table[327] = syscalls.Supported("preadv2", Preadv2) + table[328] = syscalls.Supported("pwritev2", Pwritev2) + table[332] = syscalls.Supported("statx", Statx) } diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go index 6af5c400f..a6b367468 100644 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package vfs2 import ( diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go new file mode 100644 index 000000000..60a43f0a0 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -0,0 +1,92 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Mmap implements Linux syscall mmap(2). +func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + prot := args[2].Int() + flags := args[3].Int() + fd := args[4].Int() + fixed := flags&linux.MAP_FIXED != 0 + private := flags&linux.MAP_PRIVATE != 0 + shared := flags&linux.MAP_SHARED != 0 + anon := flags&linux.MAP_ANONYMOUS != 0 + map32bit := flags&linux.MAP_32BIT != 0 + + // Require exactly one of MAP_PRIVATE and MAP_SHARED. + if private == shared { + return 0, nil, syserror.EINVAL + } + + opts := memmap.MMapOpts{ + Length: args[1].Uint64(), + Offset: args[5].Uint64(), + Addr: args[0].Pointer(), + Fixed: fixed, + Unmap: fixed, + Map32Bit: map32bit, + Private: private, + Perms: usermem.AccessType{ + Read: linux.PROT_READ&prot != 0, + Write: linux.PROT_WRITE&prot != 0, + Execute: linux.PROT_EXEC&prot != 0, + }, + MaxPerms: usermem.AnyAccess, + GrowsDown: linux.MAP_GROWSDOWN&flags != 0, + Precommit: linux.MAP_POPULATE&flags != 0, + } + if linux.MAP_LOCKED&flags != 0 { + opts.MLockMode = memmap.MLockEager + } + defer func() { + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef() + } + }() + + if !anon { + // Convert the passed FD to a file reference. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // mmap unconditionally requires that the FD is readable. + if !file.IsReadable() { + return 0, nil, syserror.EACCES + } + // MAP_SHARED requires that the FD be writable for PROT_WRITE. + if shared && !file.IsWritable() { + opts.MaxPerms.Write = false + } + + if err := file.ConfigureMMap(t, &opts); err != nil { + return 0, nil, err + } + } + + rv, err := t.MemoryManager().MMap(t, opts) + return uintptr(rv), nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go new file mode 100644 index 000000000..97da6c647 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/path.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) { + pathname, err := t.CopyInString(addr, linux.PATH_MAX) + if err != nil { + return fspath.Path{}, err + } + return fspath.Parse(pathname), nil +} + +type taskPathOperation struct { + pop vfs.PathOperation + haveStartRef bool +} + +func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) { + root := t.FSContext().RootDirectoryVFS2() + start := root + haveStartRef := false + if !path.Absolute { + if !path.HasComponents() && !bool(shouldAllowEmptyPath) { + root.DecRef() + return taskPathOperation{}, syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + haveStartRef = true + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + root.DecRef() + return taskPathOperation{}, syserror.EBADF + } + start = dirfile.VirtualDentry() + start.IncRef() + haveStartRef = true + dirfile.DecRef() + } + } + return taskPathOperation{ + pop: vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: bool(shouldFollowFinalSymlink), + }, + haveStartRef: haveStartRef, + }, nil +} + +func (tpop *taskPathOperation) Release() { + tpop.pop.Root.DecRef() + if tpop.haveStartRef { + tpop.pop.Start.DecRef() + tpop.haveStartRef = false + } +} + +type shouldAllowEmptyPath bool + +const ( + disallowEmptyPath shouldAllowEmptyPath = false + allowEmptyPath shouldAllowEmptyPath = true +) + +type shouldFollowFinalSymlink bool + +const ( + nofollowFinalSymlink shouldFollowFinalSymlink = false + followFinalSymlink shouldFollowFinalSymlink = true +) diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go new file mode 100644 index 000000000..dbf4882da --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -0,0 +1,584 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// fileCap is the maximum allowable files for poll & select. This has no +// equivalent in Linux; it exists in gVisor since allocation failure in Go is +// unrecoverable. +const fileCap = 1024 * 1024 + +// Masks for "readable", "writable", and "exceptional" events as defined by +// select(2). +const ( + // selectReadEvents is analogous to the Linux kernel's + // fs/select.c:POLLIN_SET. + selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR + + // selectWriteEvents is analogous to the Linux kernel's + // fs/select.c:POLLOUT_SET. + selectWriteEvents = linux.POLLOUT | linux.POLLERR + + // selectExceptEvents is analogous to the Linux kernel's + // fs/select.c:POLLEX_SET. + selectExceptEvents = linux.POLLPRI +) + +// pollState tracks the associated file description and waiter of a PollFD. +type pollState struct { + file *vfs.FileDescription + waiter waiter.Entry +} + +// initReadiness gets the current ready mask for the file represented by the FD +// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is +// used to register with the file for event notifications, and a reference to +// the file is stored in "state". +func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) { + if pfd.FD < 0 { + pfd.REvents = 0 + return + } + + file := t.GetFileVFS2(pfd.FD) + if file == nil { + pfd.REvents = linux.POLLNVAL + return + } + + if ch == nil { + defer file.DecRef() + } else { + state.file = file + state.waiter, _ = waiter.NewChannelEntry(ch) + file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events))) + } + + r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events))) + pfd.REvents = int16(r.ToLinux()) & pfd.Events +} + +// releaseState releases all the pollState in "state". +func releaseState(state []pollState) { + for i := range state { + if state[i].file != nil { + state[i].file.EventUnregister(&state[i].waiter) + state[i].file.DecRef() + } + } +} + +// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout" +// when "timeout" is greater than zero. +// +// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or +// positive if interrupted by a signal. +func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) { + var ch chan struct{} + if timeout != 0 { + ch = make(chan struct{}, 1) + } + + // Register for event notification in the files involved if we may + // block (timeout not zero). Once we find a file that has a non-zero + // result, we stop registering for events but still go through all files + // to get their ready masks. + state := make([]pollState, len(pfd)) + defer releaseState(state) + n := uintptr(0) + for i := range pfd { + initReadiness(t, &pfd[i], &state[i], ch) + if pfd[i].REvents != 0 { + n++ + ch = nil + } + } + + if timeout == 0 { + return timeout, n, nil + } + + haveTimeout := timeout >= 0 + + for n == 0 { + var err error + // Wait for a notification. + timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout) + if err != nil { + if err == syserror.ETIMEDOUT { + err = nil + } + return timeout, 0, err + } + + // We got notified, count how many files are ready. If none, + // then this was a spurious notification, and we just go back + // to sleep with the remaining timeout. + for i := range state { + if state[i].file == nil { + continue + } + + r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events))) + rl := int16(r.ToLinux()) & pfd[i].Events + if rl != 0 { + pfd[i].REvents = rl + n++ + } + } + } + + return timeout, n, nil +} + +// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max. +func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) { + if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) { + return nil, syserror.EINVAL + } + + pfd := make([]linux.PollFD, nfds) + if nfds > 0 { + if _, err := t.CopyIn(addr, &pfd); err != nil { + return nil, err + } + } + + return pfd, nil +} + +func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) { + pfd, err := copyInPollFDs(t, addr, nfds) + if err != nil { + return timeout, 0, err + } + + // Compatibility warning: Linux adds POLLHUP and POLLERR just before + // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after + // polling, changing event masks here is an application-visible difference. + // (Linux also doesn't copy out event masks at all, only revents.) + for i := range pfd { + pfd[i].Events |= linux.POLLHUP | linux.POLLERR + } + remainingTimeout, n, err := pollBlock(t, pfd, timeout) + err = syserror.ConvertIntr(err, syserror.EINTR) + + // The poll entries are copied out regardless of whether + // any are set or not. This aligns with the Linux behavior. + if nfds > 0 && err == nil { + if _, err := t.CopyOut(addr, pfd); err != nil { + return remainingTimeout, 0, err + } + } + + return remainingTimeout, n, err +} + +// CopyInFDSet copies an fd set from select(2)/pselect(2). +func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { + set := make([]byte, nBytes) + + if addr != 0 { + if _, err := t.CopyIn(addr, &set); err != nil { + return nil, err + } + // If we only use part of the last byte, mask out the extraneous bits. + // + // N.B. This only works on little-endian architectures. + if nBitsInLastPartialByte != 0 { + set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte + } + } + return set, nil +} + +func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { + if nfds < 0 || nfds > fileCap { + return 0, syserror.EINVAL + } + + // Calculate the size of the fd sets (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 + + // Capture all the provided input vectors. + r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err + } + w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err + } + e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err + } + + // Count how many FDs are actually being requested so that we can build + // a PollFD array. + fdCount := 0 + for i := 0; i < nBytes; i++ { + v := r[i] | w[i] | e[i] + for v != 0 { + v &= (v - 1) + fdCount++ + } + } + + // Build the PollFD array. + pfd := make([]linux.PollFD, 0, fdCount) + var fd int32 + for i := 0; i < nBytes; i++ { + rV, wV, eV := r[i], w[i], e[i] + v := rV | wV | eV + m := byte(1) + for j := 0; j < 8; j++ { + if (v & m) != 0 { + // Make sure the fd is valid and decrement the reference + // immediately to ensure we don't leak. Note, another thread + // might be about to close fd. This is racy, but that's + // OK. Linux is racy in the same way. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + file.DecRef() + + var mask int16 + if (rV & m) != 0 { + mask |= selectReadEvents + } + + if (wV & m) != 0 { + mask |= selectWriteEvents + } + + if (eV & m) != 0 { + mask |= selectExceptEvents + } + + pfd = append(pfd, linux.PollFD{ + FD: fd, + Events: mask, + }) + } + + fd++ + m <<= 1 + } + } + + // Do the syscall, then count the number of bits set. + if _, _, err = pollBlock(t, pfd, timeout); err != nil { + return 0, syserror.ConvertIntr(err, syserror.EINTR) + } + + // r, w, and e are currently event mask bitsets; unset bits corresponding + // to events that *didn't* occur. + bitSetCount := uintptr(0) + for idx := range pfd { + events := pfd[idx].REvents + i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8) + m := byte(1) << j + if r[i]&m != 0 { + if (events & selectReadEvents) != 0 { + bitSetCount++ + } else { + r[i] &^= m + } + } + if w[i]&m != 0 { + if (events & selectWriteEvents) != 0 { + bitSetCount++ + } else { + w[i] &^= m + } + } + if e[i]&m != 0 { + if (events & selectExceptEvents) != 0 { + bitSetCount++ + } else { + e[i] &^= m + } + } + } + + // Copy updated vectors back. + if readFDs != 0 { + if _, err := t.CopyOut(readFDs, r); err != nil { + return 0, err + } + } + + if writeFDs != 0 { + if _, err := t.CopyOut(writeFDs, w); err != nil { + return 0, err + } + } + + if exceptFDs != 0 { + if _, err := t.CopyOut(exceptFDs, e); err != nil { + return 0, err + } + } + + return bitSetCount, nil +} + +// timeoutRemaining returns the amount of time remaining for the specified +// timeout or 0 if it has elapsed. +// +// startNs must be from CLOCK_MONOTONIC. +func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration { + now := t.Kernel().MonotonicClock().Now() + remaining := timeout - now.Sub(startNs) + if remaining < 0 { + remaining = 0 + } + return remaining +} + +// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr. +// +// startNs must be from CLOCK_MONOTONIC. +func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error { + if timeout <= 0 { + return nil + } + remaining := timeoutRemaining(t, startNs, timeout) + tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds()) + return tsRemaining.CopyOut(t, timespecAddr) +} + +// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr. +// +// startNs must be from CLOCK_MONOTONIC. +func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error { + if timeout <= 0 { + return nil + } + remaining := timeoutRemaining(t, startNs, timeout) + tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds()) + return tvRemaining.CopyOut(t, timevalAddr) +} + +// pollRestartBlock encapsulates the state required to restart poll(2) via +// restart_syscall(2). +// +// +stateify savable +type pollRestartBlock struct { + pfdAddr usermem.Addr + nfds uint + timeout time.Duration +} + +// Restart implements kernel.SyscallRestartBlock.Restart. +func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { + return poll(t, p.pfdAddr, p.nfds, p.timeout) +} + +func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) { + remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) + // On an interrupt poll(2) is restarted with the remaining timeout. + if err == syserror.EINTR { + t.SetSyscallRestartBlock(&pollRestartBlock{ + pfdAddr: pfdAddr, + nfds: nfds, + timeout: remainingTimeout, + }) + return 0, kernel.ERESTART_RESTARTBLOCK + } + return n, err +} + +// Poll implements linux syscall poll(2). +func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pfdAddr := args[0].Pointer() + nfds := uint(args[1].Uint()) // poll(2) uses unsigned long. + timeout := time.Duration(args[2].Int()) * time.Millisecond + n, err := poll(t, pfdAddr, nfds, timeout) + return n, nil, err +} + +// Ppoll implements linux syscall ppoll(2). +func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pfdAddr := args[0].Pointer() + nfds := uint(args[1].Uint()) // poll(2) uses unsigned long. + timespecAddr := args[2].Pointer() + maskAddr := args[3].Pointer() + maskSize := uint(args[4].Uint()) + + timeout, err := copyTimespecInToDuration(t, timespecAddr) + if err != nil { + return 0, nil, err + } + + var startNs ktime.Time + if timeout > 0 { + startNs = t.Kernel().MonotonicClock().Now() + } + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + _, n, err := doPoll(t, pfdAddr, nfds, timeout) + copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) + // doPoll returns EINTR if interrupted, but ppoll is normally restartable + // if interrupted by something other than a signal handled by the + // application (i.e. returns ERESTARTNOHAND). However, if + // copyOutTimespecRemaining failed, then the restarted ppoll would use the + // wrong timeout, so the error should be left as EINTR. + // + // Note that this means that if err is nil but copyErr is not, copyErr is + // ignored. This is consistent with Linux. + if err == syserror.EINTR && copyErr == nil { + err = kernel.ERESTARTNOHAND + } + return n, nil, err +} + +// Select implements linux syscall select(2). +func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + nfds := int(args[0].Int()) // select(2) uses an int. + readFDs := args[1].Pointer() + writeFDs := args[2].Pointer() + exceptFDs := args[3].Pointer() + timevalAddr := args[4].Pointer() + + // Use a negative Duration to indicate "no timeout". + timeout := time.Duration(-1) + if timevalAddr != 0 { + var timeval linux.Timeval + if err := timeval.CopyIn(t, timevalAddr); err != nil { + return 0, nil, err + } + if timeval.Sec < 0 || timeval.Usec < 0 { + return 0, nil, syserror.EINVAL + } + timeout = time.Duration(timeval.ToNsecCapped()) + } + startNs := t.Kernel().MonotonicClock().Now() + n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) + copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) + // See comment in Ppoll. + if err == syserror.EINTR && copyErr == nil { + err = kernel.ERESTARTNOHAND + } + return n, nil, err +} + +// Pselect implements linux syscall pselect(2). +func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + nfds := int(args[0].Int()) // select(2) uses an int. + readFDs := args[1].Pointer() + writeFDs := args[2].Pointer() + exceptFDs := args[3].Pointer() + timespecAddr := args[4].Pointer() + maskWithSizeAddr := args[5].Pointer() + + timeout, err := copyTimespecInToDuration(t, timespecAddr) + if err != nil { + return 0, nil, err + } + + var startNs ktime.Time + if timeout > 0 { + startNs = t.Kernel().MonotonicClock().Now() + } + + if maskWithSizeAddr != 0 { + if t.Arch().Width() != 8 { + panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width())) + } + var maskStruct sigSetWithSize + if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil { + return 0, nil, err + } + if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil { + return 0, nil, err + } + } + + n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) + copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) + // See comment in Ppoll. + if err == syserror.EINTR && copyErr == nil { + err = kernel.ERESTARTNOHAND + } + return n, nil, err +} + +// +marshal +type sigSetWithSize struct { + sigsetAddr uint64 + sizeofSigset uint64 +} + +// copyTimespecInToDuration copies a Timespec from the untrusted app range, +// validates it and converts it to a Duration. +// +// If the Timespec is larger than what can be represented in a Duration, the +// returned value is the maximum that Duration will allow. +// +// If timespecAddr is NULL, the returned value is negative. +func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) { + // Use a negative Duration to indicate "no timeout". + timeout := time.Duration(-1) + if timespecAddr != 0 { + var timespec linux.Timespec + if err := timespec.CopyIn(t, timespecAddr); err != nil { + return 0, err + } + if !timespec.Valid() { + return 0, syserror.EINVAL + } + timeout = time.Duration(timespec.ToNsecCapped()) + } + return timeout, nil +} + +func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error { + if maskAddr == 0 { + return nil + } + if maskSize != linux.SignalSetSize { + return syserror.EINVAL + } + var mask linux.SignalSet + if err := mask.CopyIn(t, maskAddr); err != nil { + return err + } + mask &^= kernel.UnblockableSignals + oldmask := t.SignalMask() + t.SetSignalMask(mask) + t.SetSavedSignalMask(oldmask) + return nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go new file mode 100644 index 000000000..35f6308d6 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -0,0 +1,511 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr + eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr +) + +// Read implements Linux syscall read(2). +func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := read(t, file, dst, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file) +} + +// Readv implements Linux syscall readv(2). +func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Get the destination of the read. + dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := read(t, file, dst, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file) +} + +func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + n, err := file.Read(t, dst, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskRead) + + total := n + for { + // Shorten dst to reflect bytes previously read. + dst = dst.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.Read(t, dst, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Pread64 implements Linux syscall pread64(2). +func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file) +} + +// Preadv implements Linux syscall preadv(2). +func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file) +} + +// Preadv2 implements Linux syscall preadv2(2). +func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // While the glibc signature is + // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) + // the actual syscall + // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142) + // splits the offset argument into a high/low value for compatibility with + // 32-bit architectures. The flags argument is the 6th argument (index 5). + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + flags := args[5].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < -1 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + opts := vfs.ReadOptions{ + Flags: uint32(flags), + } + var n int64 + if offset == -1 { + n, err = read(t, file, dst, opts) + } else { + n, err = pread(t, file, dst, offset, opts) + } + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) +} + +func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + n, err := file.PRead(t, dst, offset, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskRead) + + total := n + for { + // Shorten dst to reflect bytes previously read. + dst = dst.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.PRead(t, dst, offset+total, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Write implements Linux syscall write(2). +func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := write(t, file, src, vfs.WriteOptions{}) + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file) +} + +// Writev implements Linux syscall writev(2). +func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Get the source of the write. + src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := write(t, file, src, vfs.WriteOptions{}) + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file) +} + +func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + n, err := file.Write(t, src, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskWrite) + + total := n + for { + // Shorten src to reflect bytes previously written. + src = src.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.Write(t, src, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Pwrite64 implements Linux syscall pwrite64(2). +func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file) +} + +// Pwritev implements Linux syscall pwritev(2). +func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file) +} + +// Pwritev2 implements Linux syscall pwritev2(2). +func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // While the glibc signature is + // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) + // the actual syscall + // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162) + // splits the offset argument into a high/low value for compatibility with + // 32-bit architectures. The flags argument is the 6th argument (index 5). + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + flags := args[5].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < -1 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + opts := vfs.WriteOptions{ + Flags: uint32(flags), + } + var n int64 + if offset == -1 { + n, err = write(t, file, src, opts) + } else { + n, err = pwrite(t, file, src, offset, opts) + } + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) +} + +func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, err := file.PWrite(t, src, offset, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskWrite) + + total := n + for { + // Shorten src to reflect bytes previously written. + src = src.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.PWrite(t, src, offset+total, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Lseek implements Linux syscall lseek(2). +func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + whence := args[2].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + newoff, err := file.Seek(t, offset, whence) + return uintptr(newoff), nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go new file mode 100644 index 000000000..9250659ff --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -0,0 +1,380 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX + +// Chmod implements Linux syscall chmod(2). +func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + mode := args[1].ModeT() + return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode) +} + +// Fchmodat implements Linux syscall fchmodat(2). +func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + mode := args[2].ModeT() + return 0, nil, fchmodat(t, dirfd, pathAddr, mode) +} + +func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error { + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + + return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_MODE, + Mode: uint16(mode & chmodMask), + }, + }) +} + +// Fchmod implements Linux syscall fchmod(2). +func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + mode := args[1].ModeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.SetStat(t, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_MODE, + Mode: uint16(mode & chmodMask), + }, + }) +} + +// Chown implements Linux syscall chown(2). +func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + owner := args[1].Int() + group := args[2].Int() + return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */) +} + +// Lchown implements Linux syscall lchown(2). +func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + owner := args[1].Int() + group := args[2].Int() + return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW) +} + +// Fchownat implements Linux syscall fchownat(2). +func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + owner := args[2].Int() + group := args[3].Int() + flags := args[4].Int() + return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags) +} + +func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil { + return err + } + + return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts) +} + +func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error { + userns := t.UserNamespace() + if owner != -1 { + kuid := userns.MapToKUID(auth.UID(owner)) + if !kuid.Ok() { + return syserror.EINVAL + } + opts.Stat.Mask |= linux.STATX_UID + opts.Stat.UID = uint32(kuid) + } + if group != -1 { + kgid := userns.MapToKGID(auth.GID(group)) + if !kgid.Ok() { + return syserror.EINVAL + } + opts.Stat.Mask |= linux.STATX_GID + opts.Stat.GID = uint32(kgid) + } + return nil +} + +// Fchown implements Linux syscall fchown(2). +func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + owner := args[1].Int() + group := args[2].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil { + return 0, nil, err + } + return 0, nil, file.SetStat(t, opts) +} + +// Truncate implements Linux syscall truncate(2). +func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + length := args[1].Int64() + + if length < 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + + return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_SIZE, + Size: uint64(length), + }, + }) +} + +// Ftruncate implements Linux syscall ftruncate(2). +func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + length := args[1].Int64() + + if length < 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.SetStat(t, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_SIZE, + Size: uint64(length), + }, + }) +} + +// Utime implements Linux syscall utime(2). +func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + timesAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + opts := vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_ATIME | linux.STATX_MTIME, + }, + } + if timesAddr == 0 { + opts.Stat.Atime.Nsec = linux.UTIME_NOW + opts.Stat.Mtime.Nsec = linux.UTIME_NOW + } else { + var times linux.Utime + if err := times.CopyIn(t, timesAddr); err != nil { + return 0, nil, err + } + opts.Stat.Atime.Sec = times.Actime + opts.Stat.Mtime.Sec = times.Modtime + } + + return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts) +} + +// Utimes implements Linux syscall utimes(2). +func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + timesAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + opts := vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_ATIME | linux.STATX_MTIME, + }, + } + if timesAddr == 0 { + opts.Stat.Atime.Nsec = linux.UTIME_NOW + opts.Stat.Mtime.Nsec = linux.UTIME_NOW + } else { + var times [2]linux.Timeval + if _, err := t.CopyIn(timesAddr, ×); err != nil { + return 0, nil, err + } + opts.Stat.Atime = linux.StatxTimestamp{ + Sec: times[0].Sec, + Nsec: uint32(times[0].Usec * 1000), + } + opts.Stat.Mtime = linux.StatxTimestamp{ + Sec: times[1].Sec, + Nsec: uint32(times[1].Usec * 1000), + } + } + + return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts) +} + +// Utimensat implements Linux syscall utimensat(2). +func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + timesAddr := args[2].Pointer() + flags := args[3].Int() + + if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil { + return 0, nil, err + } + + return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts) +} + +// Futimens implements Linux syscall futimens(2). +func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + timesAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil { + return 0, nil, err + } + + return 0, nil, file.SetStat(t, opts) +} + +func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error { + if timesAddr == 0 { + opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME + opts.Stat.Atime.Nsec = linux.UTIME_NOW + opts.Stat.Mtime.Nsec = linux.UTIME_NOW + return nil + } + var times [2]linux.Timespec + if _, err := t.CopyIn(timesAddr, ×); err != nil { + return err + } + if times[0].Nsec != linux.UTIME_OMIT { + opts.Stat.Mask |= linux.STATX_ATIME + opts.Stat.Atime = linux.StatxTimestamp{ + Sec: times[0].Sec, + Nsec: uint32(times[0].Nsec), + } + } + if times[1].Nsec != linux.UTIME_OMIT { + opts.Stat.Mask |= linux.STATX_MTIME + opts.Stat.Mtime = linux.StatxTimestamp{ + Sec: times[1].Sec, + Nsec: uint32(times[1].Nsec), + } + } + return nil +} + +func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error { + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + if !path.Absolute { + if !path.HasComponents() && !bool(shouldAllowEmptyPath) { + return syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + return syserror.EBADF + } + if !path.HasComponents() { + // Use FileDescription.SetStat() instead of + // VirtualFilesystem.SetStatAt(), since the former may be able + // to use opened file state to expedite the SetStat. + err := dirfile.SetStat(t, *opts) + dirfile.DecRef() + return err + } + start = dirfile.VirtualDentry() + start.IncRef() + defer start.DecRef() + dirfile.DecRef() + } + } + return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: bool(shouldFollowFinalSymlink), + }, opts) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go new file mode 100644 index 000000000..dca8d7011 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -0,0 +1,346 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Stat implements Linux syscall stat(2). +func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + statAddr := args[1].Pointer() + return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */) +} + +// Lstat implements Linux syscall lstat(2). +func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + statAddr := args[1].Pointer() + return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW) +} + +// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2). +func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + statAddr := args[2].Pointer() + flags := args[3].Int() + return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags) +} + +func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return syserror.EINVAL + } + + opts := vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS, + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + if !path.Absolute { + if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { + return syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + return syserror.EBADF + } + if !path.HasComponents() { + // Use FileDescription.Stat() instead of + // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the + // former may be able to use opened file state to expedite the + // Stat. + statx, err := dirfile.Stat(t, opts) + dirfile.DecRef() + if err != nil { + return err + } + var stat linux.Stat + convertStatxToUserStat(t, &statx, &stat) + return stat.CopyOut(t, statAddr) + } + start = dirfile.VirtualDentry() + start.IncRef() + defer start.DecRef() + dirfile.DecRef() + } + } + + statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + }, &opts) + if err != nil { + return err + } + var stat linux.Stat + convertStatxToUserStat(t, &statx, &stat) + return stat.CopyOut(t, statAddr) +} + +// This takes both input and output as pointer arguments to avoid copying large +// structs. +func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) { + // Linux just copies fields from struct kstat without regard to struct + // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too. + userns := t.UserNamespace() + *stat = linux.Stat{ + Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)), + Ino: statx.Ino, + Nlink: uint64(statx.Nlink), + Mode: uint32(statx.Mode), + UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()), + GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()), + Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)), + Size: int64(statx.Size), + Blksize: int64(statx.Blksize), + Blocks: int64(statx.Blocks), + ATime: timespecFromStatxTimestamp(statx.Atime), + MTime: timespecFromStatxTimestamp(statx.Mtime), + CTime: timespecFromStatxTimestamp(statx.Ctime), + } +} + +func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec { + return linux.Timespec{ + Sec: sxts.Sec, + Nsec: int64(sxts.Nsec), + } +} + +// Fstat implements Linux syscall fstat(2). +func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + statAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + statx, err := file.Stat(t, vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS, + }) + if err != nil { + return 0, nil, err + } + var stat linux.Stat + convertStatxToUserStat(t, &statx, &stat) + return 0, nil, stat.CopyOut(t, statAddr) +} + +// Statx implements Linux syscall statx(2). +func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + flags := args[2].Int() + mask := args[3].Uint() + statxAddr := args[4].Pointer() + + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + + opts := vfs.StatOptions{ + Mask: mask, + Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE), + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + if !path.Absolute { + if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { + return 0, nil, syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + return 0, nil, syserror.EBADF + } + if !path.HasComponents() { + // Use FileDescription.Stat() instead of + // VirtualFilesystem.StatAt() for statx(fd, ""), since the + // former may be able to use opened file state to expedite the + // Stat. + statx, err := dirfile.Stat(t, opts) + dirfile.DecRef() + if err != nil { + return 0, nil, err + } + userifyStatx(t, &statx) + return 0, nil, statx.CopyOut(t, statxAddr) + } + start = dirfile.VirtualDentry() + start.IncRef() + defer start.DecRef() + dirfile.DecRef() + } + } + + statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + }, &opts) + if err != nil { + return 0, nil, err + } + userifyStatx(t, &statx) + return 0, nil, statx.CopyOut(t, statxAddr) +} + +func userifyStatx(t *kernel.Task, statx *linux.Statx) { + userns := t.UserNamespace() + statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow()) + statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow()) +} + +// Readlink implements Linux syscall readlink(2). +func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + bufAddr := args[1].Pointer() + size := args[2].SizeT() + return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size) +} + +// Access implements Linux syscall access(2). +func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // FIXME(jamieliu): actually implement + return 0, nil, nil +} + +// Faccessat implements Linux syscall access(2). +func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // FIXME(jamieliu): actually implement + return 0, nil, nil +} + +// Readlinkat implements Linux syscall mknodat(2). +func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + bufAddr := args[2].Pointer() + size := args[3].SizeT() + return readlinkat(t, dirfd, pathAddr, bufAddr, size) +} + +func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) { + if int(size) <= 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + // "Since Linux 2.6.39, pathname can be an empty string, in which case the + // call operates on the symbolic link referred to by dirfd ..." - + // readlinkat(2) + tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + + if len(target) > int(size) { + target = target[:size] + } + n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target)) + if n == 0 { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Statfs implements Linux syscall statfs(2). +func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + bufAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + + return 0, nil, statfs.CopyOut(t, bufAddr) +} + +// Fstatfs implements Linux syscall fstatfs(2). +func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + bufAddr := args[1].Pointer() + + tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + + return 0, nil, statfs.CopyOut(t, bufAddr) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go new file mode 100644 index 000000000..365250b0b --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Sync implements Linux syscall sync(2). +func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t) +} + +// Syncfs implements Linux syscall syncfs(2). +func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.SyncFS(t) +} + +// Fsync implements Linux syscall fsync(2). +func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.Sync(t) +} + +// Fdatasync implements Linux syscall fdatasync(2). +func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata. + return Fsync(t, args) +} + +// SyncFileRange implements Linux syscall sync_file_range(2). +func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + nbytes := args[2].Int64() + flags := args[3].Uint() + + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if nbytes < 0 { + return 0, nil, syserror.EINVAL + } + if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of + // [offset, offset+nbytes). + return 0, nil, file.Sync(t) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go deleted file mode 100644 index 7667524c7..000000000 --- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs2 - -import ( - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // EventMaskRead contains events that can be triggered on reads. - EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr -) - -// Read implements linux syscall read(2). Note that we try to get a buffer that -// is exactly the size requested because some applications like qemu expect -// they can do large reads all at once. Bug for bug. Same for other read -// calls below. -func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - fd := args[0].Int() - addr := args[1].Pointer() - size := args[2].SizeT() - - file := t.GetFileVFS2(fd) - if file == nil { - return 0, nil, syserror.EBADF - } - defer file.DecRef() - - // Check that the size is legitimate. - si := int(size) - if si < 0 { - return 0, nil, syserror.EINVAL - } - - // Get the destination of the read. - dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ - AddressSpaceActive: true, - }) - if err != nil { - return 0, nil, err - } - - n, err := read(t, file, dst, vfs.ReadOptions{}) - t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file) -} - -func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - n, err := file.Read(t, dst, opts) - if err != syserror.ErrWouldBlock { - return n, err - } - - // Register for notifications. - w, ch := waiter.NewChannelEntry(nil) - file.EventRegister(&w, EventMaskRead) - - total := n - for { - // Shorten dst to reflect bytes previously read. - dst = dst.DropFirst(int(n)) - - // Issue the request and break out if it completes with anything other than - // "would block". - n, err := file.Read(t, dst, opts) - total += n - if err != syserror.ErrWouldBlock { - break - } - if err := t.Block(ch); err != nil { - break - } - } - file.EventUnregister(&w) - - return total, err -} diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go new file mode 100644 index 000000000..89e9ff4d7 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -0,0 +1,353 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "bytes" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Listxattr implements Linux syscall listxattr(2). +func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return listxattr(t, args, followFinalSymlink) +} + +// Llistxattr implements Linux syscall llistxattr(2). +func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return listxattr(t, args, nofollowFinalSymlink) +} + +func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + listAddr := args[1].Pointer() + size := args[2].SizeT() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrNameList(t, listAddr, size, names) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Flistxattr implements Linux syscall flistxattr(2). +func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + listAddr := args[1].Pointer() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + names, err := file.Listxattr(t) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrNameList(t, listAddr, size, names) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Getxattr implements Linux syscall getxattr(2). +func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getxattr(t, args, followFinalSymlink) +} + +// Lgetxattr implements Linux syscall lgetxattr(2). +func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getxattr(t, args, nofollowFinalSymlink) +} + +func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + + value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrValue(t, valueAddr, size, value) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Fgetxattr implements Linux syscall fgetxattr(2). +func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + + value, err := file.Getxattr(t, name) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrValue(t, valueAddr, size, value) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Setxattr implements Linux syscall setxattr(2). +func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, setxattr(t, args, followFinalSymlink) +} + +// Lsetxattr implements Linux syscall lsetxattr(2). +func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, setxattr(t, args, nofollowFinalSymlink) +} + +func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Int() + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + value, err := copyInXattrValue(t, valueAddr, size) + if err != nil { + return err + } + + return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{ + Name: name, + Value: value, + Flags: uint32(flags), + }) +} + +// Fsetxattr implements Linux syscall fsetxattr(2). +func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Int() + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + value, err := copyInXattrValue(t, valueAddr, size) + if err != nil { + return 0, nil, err + } + + return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{ + Name: name, + Value: value, + Flags: uint32(flags), + }) +} + +// Removexattr implements Linux syscall removexattr(2). +func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, removexattr(t, args, followFinalSymlink) +} + +// Lremovexattr implements Linux syscall lremovexattr(2). +func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, removexattr(t, args, nofollowFinalSymlink) +} + +func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + + return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name) +} + +// Fremovexattr implements Linux syscall fremovexattr(2). +func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + nameAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + + return 0, nil, file.Removexattr(t, name) +} + +func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { + name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) + if err != nil { + if err == syserror.ENAMETOOLONG { + return "", syserror.ERANGE + } + return "", err + } + if len(name) == 0 { + return "", syserror.ERANGE + } + return name, nil +} + +func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) { + if size > linux.XATTR_LIST_MAX { + size = linux.XATTR_LIST_MAX + } + var buf bytes.Buffer + for _, name := range names { + buf.WriteString(name) + buf.WriteByte(0) + } + if size == 0 { + // Return the size that would be required to accomodate the list. + return buf.Len(), nil + } + if buf.Len() > int(size) { + if size >= linux.XATTR_LIST_MAX { + return 0, syserror.E2BIG + } + return 0, syserror.ERANGE + } + return t.CopyOutBytes(listAddr, buf.Bytes()) +} + +func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) { + if size > linux.XATTR_SIZE_MAX { + return "", syserror.E2BIG + } + buf := make([]byte, size) + if _, err := t.CopyInBytes(valueAddr, buf); err != nil { + return "", err + } + return gohacks.StringFromImmutableBytes(buf), nil +} + +func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) { + if size > linux.XATTR_SIZE_MAX { + size = linux.XATTR_SIZE_MAX + } + if size == 0 { + // Return the size that would be required to accomodate the value. + return len(value), nil + } + if len(value) > int(size) { + if size >= linux.XATTR_SIZE_MAX { + return 0, syserror.E2BIG + } + return 0, syserror.ERANGE + } + return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value)) +} diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0b4f18ab5..07c8383e6 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/gohacks", "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index eed41139b..3da45d744 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin // Add epi to file.epolls so that it is removed when the last // FileDescription reference is dropped. file.epollMu.Lock() + if file.epolls == nil { + file.epolls = make(map[*epollInterest]struct{}) + } file.epolls[epi] = struct{}{} file.epollMu.Unlock() diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index 1fe766a44..bc7581698 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -26,6 +26,7 @@ import ( "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/sync" ) @@ -160,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) loop: for { @@ -361,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr //go:linkname rand32 runtime.fastrand func rand32() uint32 - -// This is copy/pasted from runtime.noescape(), and is needed because arguments -// apparently escape from all functions defined by linkname. -// -//go:nosplit -func noescape(p unsafe.Pointer) unsafe.Pointer { - x := uintptr(p) - return unsafe.Pointer(x ^ 0) -} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 8a0b382f6..eb4ebb511 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() { rp.pit = next } else { // at end of path segment, continue with next one rp.curPart-- - rp.pit = rp.parts[rp.curPart-1] + rp.pit = rp.parts[rp.curPart] } } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 8f29031b2..73f8043be 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -385,15 +385,11 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential // Only a regular file can be executed. stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE}) if err != nil { + fd.DecRef() return nil, err } - if stat.Mask&linux.STATX_TYPE != 0 { - // This shouldn't happen, but if type can't be retrieved, file can't - // be executed. - return nil, syserror.EACCES - } - if t := linux.FileMode(stat.Mode).FileType(); t != linux.ModeRegular { - ctx.Infof("%q is not a regular file: %v", pop.Path, t) + if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG { + fd.DecRef() return nil, syserror.EACCES } } diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD new file mode 100644 index 000000000..0500a22cf --- /dev/null +++ b/pkg/syncevent/BUILD @@ -0,0 +1,39 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +licenses(["notice"]) + +go_library( + name = "syncevent", + srcs = [ + "broadcaster.go", + "receiver.go", + "source.go", + "syncevent.go", + "waiter_amd64.s", + "waiter_arm64.s", + "waiter_asm_unsafe.go", + "waiter_noasm_unsafe.go", + "waiter_unsafe.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/atomicbitops", + "//pkg/sync", + ], +) + +go_test( + name = "syncevent_test", + size = "small", + srcs = [ + "broadcaster_test.go", + "syncevent_example_test.go", + "waiter_test.go", + ], + library = ":syncevent", + deps = [ + "//pkg/sleep", + "//pkg/sync", + "//pkg/waiter", + ], +) diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go new file mode 100644 index 000000000..4bff59e7d --- /dev/null +++ b/pkg/syncevent/broadcaster.go @@ -0,0 +1,218 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncevent + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// Broadcaster is an implementation of Source that supports any number of +// subscribed Receivers. +// +// The zero value of Broadcaster is valid and has no subscribed Receivers. +// Broadcaster is not copyable by value. +// +// All Broadcaster methods may be called concurrently from multiple goroutines. +type Broadcaster struct { + // Broadcaster is implemented as a hash table where keys are assigned by + // the Broadcaster and returned as SubscriptionIDs, making it safe to use + // the identity function for hashing. The hash table resolves collisions + // using linear probing and features Robin Hood insertion and backward + // shift deletion in order to support a relatively high load factor + // efficiently, which matters since the cost of Broadcast is linear in the + // size of the table. + + // mu protects the following fields. + mu sync.Mutex + + // Invariants: len(table) is 0 or a power of 2. + table []broadcasterSlot + + // load is the number of entries in table with receiver != nil. + load int + + lastID SubscriptionID +} + +type broadcasterSlot struct { + // Invariants: If receiver == nil, then filter == NoEvents and id == 0. + // Otherwise, id != 0. + receiver *Receiver + filter Set + id SubscriptionID +} + +const ( + broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1 + + broadcasterMaxLoadNum = 13 + broadcasterMaxLoadDen = 16 +) + +// SubscribeEvents implements Source.SubscribeEvents. +func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID { + b.mu.Lock() + + // Assign an ID for this subscription. + b.lastID++ + id := b.lastID + + // Expand the table if over the maximum load factor: + // + // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen + // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table) + b.load++ + if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) { + // Double the number of slots in the new table. + newlen := broadcasterMinNonZeroTableSize + if len(b.table) != 0 { + newlen = 2 * len(b.table) + } + if newlen <= cap(b.table) { + // Reuse excess capacity in the current table, moving entries not + // already in their first-probed positions to better ones. + newtable := b.table[:newlen] + newmask := uint64(newlen - 1) + for i := range b.table { + if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) { + entry := b.table[i] + b.table[i] = broadcasterSlot{} + broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter) + } + } + b.table = newtable + } else { + newtable := make([]broadcasterSlot, newlen) + // Copy existing entries to the new table. + for i := range b.table { + if b.table[i].receiver != nil { + broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter) + } + } + // Switch to the new table. + b.table = newtable + } + } + + broadcasterTableInsert(b.table, id, r, filter) + b.mu.Unlock() + return id +} + +// Preconditions: table must not be full. len(table) is a power of 2. +func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) { + entry := broadcasterSlot{ + receiver: r, + filter: filter, + id: id, + } + mask := uint64(len(table) - 1) + i := uint64(id) & mask + disp := uint64(0) + for { + if table[i].receiver == nil { + table[i] = entry + return + } + // If we've been displaced farther from our first-probed slot than the + // element stored in this one, swap elements and switch to inserting + // the replaced one. (This is Robin Hood insertion.) + slotDisp := (i - uint64(table[i].id)) & mask + if disp > slotDisp { + table[i], entry = entry, table[i] + disp = slotDisp + } + i = (i + 1) & mask + disp++ + } +} + +// UnsubscribeEvents implements Source.UnsubscribeEvents. +func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) { + b.mu.Lock() + + mask := uint64(len(b.table) - 1) + i := uint64(id) & mask + for { + if b.table[i].id == id { + // Found the element to remove. Move all subsequent elements + // backward until we either find an empty slot, or an element that + // is already in its first-probed slot. (This is backward shift + // deletion.) + for { + next := (i + 1) & mask + if b.table[next].receiver == nil { + break + } + if uint64(b.table[next].id)&mask == next { + break + } + b.table[i] = b.table[next] + i = next + } + b.table[i] = broadcasterSlot{} + break + } + i = (i + 1) & mask + } + + // If a table 1/4 of the current size would still be at or under the + // maximum load factor (i.e. the current table size is at least two + // expansions bigger than necessary), halve the size of the table to reduce + // the cost of Broadcast. Since we are concerned with iteration time and + // not memory usage, reuse the existing slice to reduce future allocations + // from table re-expansion. + b.load-- + if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) { + newlen := len(b.table) / 2 + newtable := b.table[:newlen] + for i := newlen; i < len(b.table); i++ { + if b.table[i].receiver != nil { + broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter) + b.table[i] = broadcasterSlot{} + } + } + b.table = newtable + } + + b.mu.Unlock() +} + +// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset +// of events to which they subscribed. The order in which Receivers are +// notified is unspecified. +func (b *Broadcaster) Broadcast(events Set) { + b.mu.Lock() + for i := range b.table { + if intersection := events & b.table[i].filter; intersection != 0 { + // We don't need to check if broadcasterSlot.receiver is nil, since + // if it is then broadcasterSlot.filter is 0. + b.table[i].receiver.Notify(intersection) + } + } + b.mu.Unlock() +} + +// FilteredEvents returns the set of events for which Broadcast will notify at +// least one Receiver, i.e. the union of filters for all subscribed Receivers. +func (b *Broadcaster) FilteredEvents() Set { + var es Set + b.mu.Lock() + for i := range b.table { + es |= b.table[i].filter + } + b.mu.Unlock() + return es +} diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go new file mode 100644 index 000000000..e88779e23 --- /dev/null +++ b/pkg/syncevent/broadcaster_test.go @@ -0,0 +1,376 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncevent + +import ( + "fmt" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/waiter" +) + +func TestBroadcasterFilter(t *testing.T) { + const numReceivers = 2 * MaxEvents + + var br Broadcaster + ws := make([]Waiter, numReceivers) + for i := range ws { + ws[i].Init() + br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents)) + } + for ev := 0; ev < MaxEvents; ev++ { + br.Broadcast(1 << ev) + for i := range ws { + want := NoEvents + if i%MaxEvents == ev { + want = 1 << ev + } + if got := ws[i].Receiver().PendingAndAckAll(); got != want { + t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want) + } + } + } +} + +// TestBroadcasterManySubscriptions tests that subscriptions are not lost by +// table expansion/compaction. +func TestBroadcasterManySubscriptions(t *testing.T) { + const numReceivers = 5000 // arbitrary + + var br Broadcaster + ws := make([]Waiter, numReceivers) + for i := range ws { + ws[i].Init() + } + + ids := make([]SubscriptionID, numReceivers) + for i := 0; i < numReceivers; i++ { + // Subscribe receiver i. + ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1) + // Check that receivers [0, i] are subscribed. + br.Broadcast(1) + for j := 0; j <= i; j++ { + if ws[j].Pending() != 1 { + t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i) + } + ws[j].Ack(1) + } + } + + // Generate a random order for unsubscriptions. + unsub := rand.Perm(numReceivers) + for i := 0; i < numReceivers; i++ { + // Unsubscribe receiver unsub[i]. + br.UnsubscribeEvents(ids[unsub[i]]) + // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that + // receivers (unsub[i], unsub[numReceivers]) are still subscribed. + br.Broadcast(1) + for j := 0; j <= i; j++ { + if ws[unsub[j]].Pending() != 0 { + t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i]) + } + } + for j := i + 1; j < numReceivers; j++ { + if ws[unsub[j]].Pending() != 1 { + t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i]) + } + ws[unsub[j]].Ack(1) + } + } +} + +var ( + receiverCountsNonZero = []int{1, 4, 16, 64} + receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...) +) + +// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage +// pattern X (described in terms of Broadcaster) with Broadcaster, a +// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively. + +// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe +// cycle. + +func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) { + var br Broadcaster + var w Waiter + w.Init() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := br.SubscribeEvents(w.Receiver(), 1) + br.UnsubscribeEvents(id) + } +} + +func BenchmarkMapSubscribeUnsubscribe(b *testing.B) { + var mu sync.Mutex + m := make(map[*Receiver]Set) + var w Waiter + w.Init() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m[w.Receiver()] = Set(1) + mu.Unlock() + mu.Lock() + delete(m, w.Receiver()) + mu.Unlock() + } +} + +func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) { + var q waiter.Queue + e, _ := waiter.NewChannelEntry(nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.EventRegister(&e, 1) + q.EventUnregister(&e) + } +} + +// BenchmarkXxxSubscribeUnsubscribeBatch is similar to +// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large +// number of Receivers at a time in order to measure the amortized overhead of +// table expansion/compaction. (Since waiter.Queue is implemented using a +// linked list, BenchmarkQueueSubscribeUnsubscribe and +// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same +// result.) + +const numBatchReceivers = 1000 + +func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) { + var br Broadcaster + ws := make([]Waiter, numBatchReceivers) + for i := range ws { + ws[i].Init() + } + ids := make([]SubscriptionID, numBatchReceivers) + + // Generate a random order for unsubscriptions. + unsub := rand.Perm(numBatchReceivers) + + b.ResetTimer() + for i := 0; i < b.N/numBatchReceivers; i++ { + for j := 0; j < numBatchReceivers; j++ { + ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1) + } + for j := 0; j < numBatchReceivers; j++ { + br.UnsubscribeEvents(ids[unsub[j]]) + } + } +} + +func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) { + var mu sync.Mutex + m := make(map[*Receiver]Set) + ws := make([]Waiter, numBatchReceivers) + for i := range ws { + ws[i].Init() + } + + // Generate a random order for unsubscriptions. + unsub := rand.Perm(numBatchReceivers) + + b.ResetTimer() + for i := 0; i < b.N/numBatchReceivers; i++ { + for j := 0; j < numBatchReceivers; j++ { + mu.Lock() + m[ws[j].Receiver()] = Set(1) + mu.Unlock() + } + for j := 0; j < numBatchReceivers; j++ { + mu.Lock() + delete(m, ws[unsub[j]].Receiver()) + mu.Unlock() + } + } +} + +func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) { + var q waiter.Queue + es := make([]waiter.Entry, numBatchReceivers) + for i := range es { + es[i], _ = waiter.NewChannelEntry(nil) + } + + // Generate a random order for unsubscriptions. + unsub := rand.Perm(numBatchReceivers) + + b.ResetTimer() + for i := 0; i < b.N/numBatchReceivers; i++ { + for j := 0; j < numBatchReceivers; j++ { + q.EventRegister(&es[j], 1) + } + for j := 0; j < numBatchReceivers; j++ { + q.EventUnregister(&es[unsub[j]]) + } + } +} + +// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast +// already-pending events to multiple Receivers. + +func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) { + for _, n := range receiverCountsIncludingZero { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var br Broadcaster + ws := make([]Waiter, n) + for i := range ws { + ws[i].Init() + br.SubscribeEvents(ws[i].Receiver(), 1) + } + br.Broadcast(1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + br.Broadcast(1) + } + }) + } +} + +func BenchmarkMapBroadcastRedundant(b *testing.B) { + for _, n := range receiverCountsIncludingZero { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var mu sync.Mutex + m := make(map[*Receiver]Set) + ws := make([]Waiter, n) + for i := range ws { + ws[i].Init() + m[ws[i].Receiver()] = Set(1) + } + mu.Lock() + for r := range m { + r.Notify(1) + } + mu.Unlock() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + for r := range m { + r.Notify(1) + } + mu.Unlock() + } + }) + } +} + +func BenchmarkQueueBroadcastRedundant(b *testing.B) { + for _, n := range receiverCountsIncludingZero { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var q waiter.Queue + for i := 0; i < n; i++ { + e, _ := waiter.NewChannelEntry(nil) + q.EventRegister(&e, 1) + } + q.Notify(1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.Notify(1) + } + }) + } +} + +// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to +// multiple Receivers, check that all Receivers have received the event, and +// clear the event from all Receivers. + +func BenchmarkBroadcasterBroadcastAck(b *testing.B) { + for _, n := range receiverCountsNonZero { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var br Broadcaster + ws := make([]Waiter, n) + for i := range ws { + ws[i].Init() + br.SubscribeEvents(ws[i].Receiver(), 1) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + br.Broadcast(1) + for j := range ws { + if got, want := ws[j].Pending(), Set(1); got != want { + b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want) + } + ws[j].Ack(1) + } + } + }) + } +} + +func BenchmarkMapBroadcastAck(b *testing.B) { + for _, n := range receiverCountsNonZero { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var mu sync.Mutex + m := make(map[*Receiver]Set) + ws := make([]Waiter, n) + for i := range ws { + ws[i].Init() + m[ws[i].Receiver()] = Set(1) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + for r := range m { + r.Notify(1) + } + mu.Unlock() + for j := range ws { + if got, want := ws[j].Pending(), Set(1); got != want { + b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want) + } + ws[j].Ack(1) + } + } + }) + } +} + +func BenchmarkQueueBroadcastAck(b *testing.B) { + for _, n := range receiverCountsNonZero { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var q waiter.Queue + chs := make([]chan struct{}, n) + for i := range chs { + e, ch := waiter.NewChannelEntry(nil) + q.EventRegister(&e, 1) + chs[i] = ch + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.Notify(1) + for _, ch := range chs { + select { + case <-ch: + default: + b.Fatalf("channel did not receive event") + } + } + } + }) + } +} diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go new file mode 100644 index 000000000..5c86e5400 --- /dev/null +++ b/pkg/syncevent/receiver.go @@ -0,0 +1,103 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncevent + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/atomicbitops" +) + +// Receiver is an event sink that holds pending events and invokes a callback +// whenever new events become pending. Receiver's methods may be called +// concurrently from multiple goroutines. +// +// Receiver.Init() must be called before first use. +type Receiver struct { + // pending is the set of pending events. pending is accessed using atomic + // memory operations. + pending uint64 + + // cb is notified when new events become pending. cb is immutable after + // Init(). + cb ReceiverCallback +} + +// ReceiverCallback receives callbacks from a Receiver. +type ReceiverCallback interface { + // NotifyPending is called when the corresponding Receiver has new pending + // events. + // + // NotifyPending is called synchronously from Receiver.Notify(), so + // implementations must not take locks that may be held by callers of + // Receiver.Notify(). NotifyPending may be called concurrently from + // multiple goroutines. + NotifyPending() +} + +// Init must be called before first use of r. +func (r *Receiver) Init(cb ReceiverCallback) { + r.cb = cb +} + +// Pending returns the set of pending events. +func (r *Receiver) Pending() Set { + return Set(atomic.LoadUint64(&r.pending)) +} + +// Notify sets the given events as pending. +func (r *Receiver) Notify(es Set) { + p := Set(atomic.LoadUint64(&r.pending)) + // Optimization: Skip the atomic CAS on r.pending if all events are + // already pending. + if p&es == es { + return + } + // When this is uncontended (the common case), CAS is faster than + // atomic-OR because the former is inlined and the latter (which we + // implement in assembly ourselves) is not. + if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) { + // If the CAS fails, fall back to atomic-OR. + atomicbitops.OrUint64(&r.pending, uint64(es)) + } + r.cb.NotifyPending() +} + +// Ack unsets the given events as pending. +func (r *Receiver) Ack(es Set) { + p := Set(atomic.LoadUint64(&r.pending)) + // Optimization: Skip the atomic CAS on r.pending if all events are + // already not pending. + if p&es == 0 { + return + } + // When this is uncontended (the common case), CAS is faster than + // atomic-AND because the former is inlined and the latter (which we + // implement in assembly ourselves) is not. + if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) { + // If the CAS fails, fall back to atomic-AND. + atomicbitops.AndUint64(&r.pending, ^uint64(es)) + } +} + +// PendingAndAckAll unsets all events as pending and returns the set of +// previously-pending events. +// +// PendingAndAckAll should only be used in preference to a call to Pending +// followed by a conditional call to Ack when the caller expects events to be +// pending (e.g. after a call to ReceiverCallback.NotifyPending()). +func (r *Receiver) PendingAndAckAll() Set { + return Set(atomic.SwapUint64(&r.pending, 0)) +} diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go new file mode 100644 index 000000000..ddffb171a --- /dev/null +++ b/pkg/syncevent/source.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncevent + +// Source represents an event source. +type Source interface { + // SubscribeEvents causes the Source to notify the given Receiver of the + // given subset of events. + // + // Preconditions: r != nil. The ReceiverCallback for r must not take locks + // that are ordered prior to the Source; for example, it cannot call any + // Source methods. + SubscribeEvents(r *Receiver, filter Set) SubscriptionID + + // UnsubscribeEvents causes the Source to stop notifying the Receiver + // subscribed by a previous call to SubscribeEvents that returned the given + // SubscriptionID. + // + // Preconditions: UnsubscribeEvents may be called at most once for any + // given SubscriptionID. + UnsubscribeEvents(id SubscriptionID) +} + +// SubscriptionID identifies a call to Source.SubscribeEvents. +type SubscriptionID uint64 + +// UnsubscribeAndAck is a convenience function that unsubscribes r from the +// given events from src and also clears them from r. +func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) { + src.UnsubscribeEvents(id) + r.Ack(filter) +} + +// NoopSource implements Source by never sending events to subscribed +// Receivers. +type NoopSource struct{} + +// SubscribeEvents implements Source.SubscribeEvents. +func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID { + return 0 +} + +// UnsubscribeEvents implements Source.UnsubscribeEvents. +func (NoopSource) UnsubscribeEvents(SubscriptionID) { +} + +// See Broadcaster for a non-noop implementations of Source. diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go new file mode 100644 index 000000000..9fb6a06de --- /dev/null +++ b/pkg/syncevent/syncevent.go @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package syncevent provides efficient primitives for goroutine +// synchronization based on event bitmasks. +package syncevent + +// Set is a bitmask where each bit represents a distinct user-defined event. +// The event package does not treat any bits in Set specially. +type Set uint64 + +const ( + // NoEvents is a Set containing no events. + NoEvents = Set(0) + + // AllEvents is a Set containing all possible events. + AllEvents = ^Set(0) + + // MaxEvents is the number of distinct events that can be represented by a Set. + MaxEvents = 64 +) diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go new file mode 100644 index 000000000..bfb18e2ea --- /dev/null +++ b/pkg/syncevent/syncevent_example_test.go @@ -0,0 +1,108 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncevent + +import ( + "fmt" + "sync/atomic" + "time" +) + +func Example_ioReadinessInterrputible() { + const ( + evReady = Set(1 << iota) + evInterrupt + ) + errNotReady := fmt.Errorf("not ready for I/O") + + // State of some I/O object. + var ( + br Broadcaster + ready uint32 + ) + doIO := func() error { + if atomic.LoadUint32(&ready) == 0 { + return errNotReady + } + return nil + } + go func() { + // The I/O object eventually becomes ready for I/O. + time.Sleep(100 * time.Millisecond) + // When it does, it first ensures that future calls to isReady() return + // true, then broadcasts the readiness event to Receivers. + atomic.StoreUint32(&ready, 1) + br.Broadcast(evReady) + }() + + // Each user of the I/O object owns a Waiter. + var w Waiter + w.Init() + // The Waiter may be asynchronously interruptible, e.g. for signal + // handling in the sentry. + go func() { + time.Sleep(200 * time.Millisecond) + w.Receiver().Notify(evInterrupt) + }() + + // To use the I/O object: + // + // Optionally, if the I/O object is likely to be ready, attempt I/O first. + err := doIO() + if err == nil { + // Success, we're done. + return /* nil */ + } + if err != errNotReady { + // Failure, I/O failed for some reason other than readiness. + return /* err */ + } + // Subscribe for readiness events from the I/O object. + id := br.SubscribeEvents(w.Receiver(), evReady) + // When we are finished blocking, unsubscribe from readiness events and + // remove readiness events from the pending event set. + defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id) + for { + // Attempt I/O again. This must be done after the call to SubscribeEvents, + // since the I/O object might have become ready between the previous call + // to doIO and the call to SubscribeEvents. + err = doIO() + if err == nil { + return /* nil */ + } + if err != errNotReady { + return /* err */ + } + // Block until either the I/O object indicates it is ready, or we are + // interrupted. + events := w.Wait() + if events&evInterrupt != 0 { + // In the specific case of sentry signal handling, signal delivery + // is handled by another system, so we aren't responsible for + // acknowledging evInterrupt. + return /* errInterrupted */ + } + // Note that, in a concurrent context, the I/O object might become + // ready and then not ready again. To handle this: + // + // - evReady must be acknowledged before calling doIO() again (rather + // than after), so that if the I/O object becomes ready *again* after + // the call to doIO(), the readiness event is not lost. + // + // - We must loop instead of just calling doIO() once after receiving + // evReady. + w.Ack(evReady) + } +} diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s new file mode 100644 index 000000000..985b56ae5 --- /dev/null +++ b/pkg/syncevent/waiter_amd64.s @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// See waiter_noasm_unsafe.go for a description of waiterUnlock. +// +// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool +TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 + MOVQ g+0(FP), DI + MOVQ wg+8(FP), SI + + MOVQ $·preparingG(SB), AX + LOCK + CMPXCHGQ DI, 0(SI) + + SETEQ AX + MOVB AX, ret+16(FP) + + RET + diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s new file mode 100644 index 000000000..20d7ac23b --- /dev/null +++ b/pkg/syncevent/waiter_arm64.s @@ -0,0 +1,34 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// See waiter_noasm_unsafe.go for a description of waiterUnlock. +// +// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool +TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 + MOVD wg+8(FP), R0 + MOVD $·preparingG(SB), R1 + MOVD g+0(FP), R2 +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE ok + STLXR R2, (R0), R3 + CBNZ R3, again +ok: + CSET EQ, R0 + MOVB R0, ret+16(FP) + RET + diff --git a/pkg/fspath/builder_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go index 75606808d..0995e9053 100644 --- a/pkg/fspath/builder_unsafe.go +++ b/pkg/syncevent/waiter_asm_unsafe.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package fspath +// +build amd64 arm64 + +package syncevent import ( "unsafe" ) -// String returns the accumulated string. No other methods should be called -// after String. -func (b *Builder) String() string { - bs := b.buf[b.start:] - // Compare strings.Builder.String(). - return *(*string)(unsafe.Pointer(&bs)) -} +// See waiter_noasm_unsafe.go for a description of waiterUnlock. +func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go new file mode 100644 index 000000000..1c4b0e39a --- /dev/null +++ b/pkg/syncevent/waiter_noasm_unsafe.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// waiterUnlock is called from g0, so when the race detector is enabled, +// waiterUnlock must be implemented in assembly since no race context is +// available. +// +// +build !race +// +build !amd64,!arm64 + +package syncevent + +import ( + "sync/atomic" + "unsafe" +) + +// waiterUnlock is the "unlock function" passed to runtime.gopark by +// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g. +// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping +// should be aborted. +// +//go:nosplit +func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool { + // The only way this CAS can fail is if a call to Waiter.NotifyPending() + // has replaced *wg with nil, in which case we should not sleep. + return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g) +} diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go new file mode 100644 index 000000000..3c8cbcdd8 --- /dev/null +++ b/pkg/syncevent/waiter_test.go @@ -0,0 +1,414 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncevent + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" +) + +func TestWaiterAlreadyPending(t *testing.T) { + var w Waiter + w.Init() + want := Set(1) + w.Notify(want) + if got := w.Wait(); got != want { + t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want) + } +} + +func TestWaiterAsyncNotify(t *testing.T) { + var w Waiter + w.Init() + want := Set(1) + go func() { + time.Sleep(100 * time.Millisecond) + w.Notify(want) + }() + if got := w.Wait(); got != want { + t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want) + } +} + +func TestWaiterWaitFor(t *testing.T) { + var w Waiter + w.Init() + evWaited := Set(1) + evOther := Set(2) + w.Notify(evOther) + notifiedEvent := uint32(0) + go func() { + time.Sleep(100 * time.Millisecond) + atomic.StoreUint32(¬ifiedEvent, 1) + w.Notify(evWaited) + }() + if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want { + t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want) + } + if atomic.LoadUint32(¬ifiedEvent) == 0 { + t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event") + } +} + +func TestWaiterWaitAndAckAll(t *testing.T) { + var w Waiter + w.Init() + w.Notify(AllEvents) + if got := w.WaitAndAckAll(); got != AllEvents { + t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents) + } + if got := w.Pending(); got != NoEvents { + t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got) + } +} + +// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage +// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and +// buffered chan struct{} respectively. When the maximum number of event +// sources is relevant, we use 3 event sources because this is representative +// of the kernel.Task.block() use case: an interrupt source, a timeout source, +// and the actual event source being waited on. + +// Event set used by most benchmarks. +const evBench Set = 1 + +// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of +// an event that is already pending. + +func BenchmarkWaiterNotifyRedundant(b *testing.B) { + var w Waiter + w.Init() + w.Notify(evBench) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w.Notify(evBench) + } +} + +func BenchmarkSleeperNotifyRedundant(b *testing.B) { + var s sleep.Sleeper + var w sleep.Waker + s.AddWaker(&w, 0) + w.Assert() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w.Assert() + } +} + +func BenchmarkChannelNotifyRedundant(b *testing.B) { + ch := make(chan struct{}, 1) + ch <- struct{}{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + select { + case ch <- struct{}{}: + default: + } + } +} + +// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an +// event, return that event using a blocking check, and then unset the event as +// pending. + +func BenchmarkWaiterNotifyWaitAck(b *testing.B) { + var w Waiter + w.Init() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w.Notify(evBench) + w.Wait() + w.Ack(evBench) + } +} + +func BenchmarkSleeperNotifyWaitAck(b *testing.B) { + var s sleep.Sleeper + var w sleep.Waker + s.AddWaker(&w, 0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w.Assert() + s.Fetch(true) + } +} + +func BenchmarkChannelNotifyWaitAck(b *testing.B) { + ch := make(chan struct{}, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // notify + select { + case ch <- struct{}{}: + default: + } + + // wait + ack + <-ch + } +} + +// BenchmarkSleeperMultiNotifyWaitAck is equivalent to +// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a +// temporary sleep.Waker. This is necessary when multiple goroutines may wait +// for the same event, since each sleep.Waker can wake only a single +// sleep.Sleeper. +// +// The syncevent package does not require a distinct object for each +// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and +// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state +// for channels, runtime.sudog, is inescapably runtime-allocated, so +// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would +// also be identical. + +func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) { + var s sleep.Sleeper + // The sleep package doesn't provide sync.Pool allocation of Wakers; + // we do for a fairer comparison. + wakerPool := sync.Pool{ + New: func() interface{} { + return &sleep.Waker{} + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w := wakerPool.Get().(*sleep.Waker) + s.AddWaker(w, 0) + w.Assert() + s.Fetch(true) + s.Done() + wakerPool.Put(w) + } +} + +// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also +// includes allocation of a temporary Waiter. This models the case where a +// goroutine not already associated with a Waiter needs one in order to block. +// +// The analogous state for channels is built into runtime.g, so +// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be +// identical. + +func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + w := GetWaiter() + w.Notify(evBench) + w.Wait() + w.Ack(evBench) + PutWaiter(w) + } +} + +func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) { + // The sleep package doesn't provide sync.Pool allocation of Sleepers; + // we do for a fairer comparison. + sleeperPool := sync.Pool{ + New: func() interface{} { + return &sleep.Sleeper{} + }, + } + var w sleep.Waker + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s := sleeperPool.Get().(*sleep.Sleeper) + s.AddWaker(&w, 0) + w.Assert() + s.Fetch(true) + s.Done() + sleeperPool.Put(s) + } +} + +// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows +// for multiple event sources. + +func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) { + var w Waiter + w.Init() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w.Notify(evBench) + if e := w.Wait(); e != evBench { + b.Fatalf("Wait: got %#x, wanted %#x", e, evBench) + } + w.Ack(evBench) + } +} + +func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) { + var s sleep.Sleeper + var ws [3]sleep.Waker + for i := range ws { + s.AddWaker(&ws[i], i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ws[0].Assert() + if id, _ := s.Fetch(true); id != 0 { + b.Fatalf("Fetch: got %d, wanted 0", id) + } + } +} + +func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) { + ch0 := make(chan struct{}, 1) + ch1 := make(chan struct{}, 1) + ch2 := make(chan struct{}, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // notify + select { + case ch0 <- struct{}{}: + default: + } + + // wait + clear + select { + case <-ch0: + // ok + case <-ch1: + b.Fatalf("received from ch1") + case <-ch2: + b.Fatalf("received from ch2") + } + } +} + +// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an +// event while another goroutine signals the event. This assumes that a new +// goroutine doesn't run immediately (i.e. the creator of a new goroutine is +// allowed to go to sleep before the new goroutine has a chance to run). + +func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) { + var w Waiter + w.Init() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + w.Notify(1) + }() + w.Wait() + w.Ack(evBench) + } +} + +func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) { + var s sleep.Sleeper + var w sleep.Waker + s.AddWaker(&w, 0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + w.Assert() + }() + s.Fetch(true) + } +} + +func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) { + ch := make(chan struct{}, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + select { + case ch <- struct{}{}: + default: + } + }() + <-ch + } +} + +// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but +// allows for multiple event sources. + +func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) { + var w Waiter + w.Init() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + w.Notify(evBench) + }() + if e := w.Wait(); e != evBench { + b.Fatalf("Wait: got %#x, wanted %#x", e, evBench) + } + w.Ack(evBench) + } +} + +func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { + var s sleep.Sleeper + var ws [3]sleep.Waker + for i := range ws { + s.AddWaker(&ws[i], i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + ws[0].Assert() + }() + if id, _ := s.Fetch(true); id != 0 { + b.Fatalf("Fetch: got %d, expected 0", id) + } + } +} + +func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) { + ch0 := make(chan struct{}, 1) + ch1 := make(chan struct{}, 1) + ch2 := make(chan struct{}, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + select { + case ch0 <- struct{}{}: + default: + } + }() + + select { + case <-ch0: + // ok + case <-ch1: + b.Fatalf("received from ch1") + case <-ch2: + b.Fatalf("received from ch2") + } + } +} diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go new file mode 100644 index 000000000..112e0e604 --- /dev/null +++ b/pkg/syncevent/waiter_unsafe.go @@ -0,0 +1,206 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.11 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package syncevent + +import ( + "sync/atomic" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +//go:linkname gopark runtime.gopark +func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int) + +//go:linkname goready runtime.goready +func goready(g unsafe.Pointer, traceskip int) + +const ( + waitReasonSelect = 9 // Go: src/runtime/runtime2.go + traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go +) + +// Waiter allows a goroutine to block on pending events received by a Receiver. +// +// Waiter.Init() must be called before first use. +type Waiter struct { + r Receiver + + // g is one of: + // + // - nil: No goroutine is blocking in Wait. + // + // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet + // completed waiterUnlock(). Thus the wait can only be interrupted by + // replacing the value of g with nil (the G may not be in state Gwaiting + // yet, so we can't call goready.) + // + // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the + // goroutine blocked in Wait, which can only be woken by calling goready. + g unsafe.Pointer `state:"zerovalue"` +} + +// Sentinel object for Waiter.g. +var preparingG struct{} + +// Init must be called before first use of w. +func (w *Waiter) Init() { + w.r.Init(w) +} + +// Receiver returns the Receiver that receives events that unblock calls to +// w.Wait(). +func (w *Waiter) Receiver() *Receiver { + return &w.r +} + +// Pending returns the set of pending events. +func (w *Waiter) Pending() Set { + return w.r.Pending() +} + +// Wait blocks until at least one event is pending, then returns the set of +// pending events. It does not affect the set of pending events; callers must +// call w.Ack() to do so, or use w.WaitAndAck() instead. +// +// Precondition: Only one goroutine may call any Wait* method at a time. +func (w *Waiter) Wait() Set { + return w.WaitFor(AllEvents) +} + +// WaitFor blocks until at least one event in es is pending, then returns the +// set of pending events (including those not in es). It does not affect the +// set of pending events; callers must call w.Ack() to do so. +// +// Precondition: Only one goroutine may call any Wait* method at a time. +func (w *Waiter) WaitFor(es Set) Set { + for { + // Optimization: Skip the atomic store to w.g if an event is already + // pending. + if p := w.r.Pending(); p&es != NoEvents { + return p + } + + // Indicate that we're preparing to go to sleep. + atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG)) + + // If an event is pending, abort the sleep. + if p := w.r.Pending(); p&es != NoEvents { + atomic.StorePointer(&w.g, nil) + return p + } + + // If w.g is still preparingG (i.e. w.NotifyPending() has not been + // called or has not reached atomic.SwapPointer()), go to sleep until + // w.NotifyPending() => goready(). + gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0) + } +} + +// Ack marks the given events as not pending. +func (w *Waiter) Ack(es Set) { + w.r.Ack(es) +} + +// WaitAndAckAll blocks until at least one event is pending, then marks all +// events as not pending and returns the set of previously-pending events. +// +// Precondition: Only one goroutine may call any Wait* method at a time. +func (w *Waiter) WaitAndAckAll() Set { + // Optimization: Skip the atomic store to w.g if an event is already + // pending. Call Pending() first since, in the common case that events are + // not yet pending, this skips an atomic swap on w.r.pending. + if w.r.Pending() != NoEvents { + if p := w.r.PendingAndAckAll(); p != NoEvents { + return p + } + } + + for { + // Indicate that we're preparing to go to sleep. + atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG)) + + // If an event is pending, abort the sleep. + if w.r.Pending() != NoEvents { + if p := w.r.PendingAndAckAll(); p != NoEvents { + atomic.StorePointer(&w.g, nil) + return p + } + } + + // If w.g is still preparingG (i.e. w.NotifyPending() has not been + // called or has not reached atomic.SwapPointer()), go to sleep until + // w.NotifyPending() => goready(). + gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0) + + // Check for pending events. We call PendingAndAckAll() directly now since + // we only expect to be woken after events become pending. + if p := w.r.PendingAndAckAll(); p != NoEvents { + return p + } + } +} + +// Notify marks the given events as pending, possibly unblocking concurrent +// calls to w.Wait() or w.WaitFor(). +func (w *Waiter) Notify(es Set) { + w.r.Notify(es) +} + +// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter +// should not call NotifyPending. +func (w *Waiter) NotifyPending() { + // Optimization: Skip the atomic swap on w.g if there is no sleeping + // goroutine. NotifyPending is called after w.r.Pending() is updated, so + // concurrent and future calls to w.Wait() will observe pending events and + // abort sleeping. + if atomic.LoadPointer(&w.g) == nil { + return + } + // Wake a sleeping G, or prevent a G that is preparing to sleep from doing + // so. Swap is needed here to ensure that only one call to NotifyPending + // calls goready. + if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) { + goready(g, 0) + } +} + +var waiterPool = sync.Pool{ + New: func() interface{} { + w := &Waiter{} + w.Init() + return w + }, +} + +// GetWaiter returns an unused Waiter. PutWaiter should be called to release +// the Waiter once it is no longer needed. +// +// Where possible, users should prefer to associate each goroutine that calls +// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of +// Waiters in hot paths. +func GetWaiter() *Waiter { + return waiterPool.Get().(*Waiter) +} + +// PutWaiter releases an unused Waiter previously returned by GetWaiter. +func PutWaiter(w *Waiter) { + waiterPool.Put(w) +} diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 2269f6237..4b5a0fca6 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -29,6 +29,7 @@ var ( EACCES = error(syscall.EACCES) EAGAIN = error(syscall.EAGAIN) EBADF = error(syscall.EBADF) + EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) ECONNREFUSED = error(syscall.ECONNREFUSED) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ea0a0409a..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 150310c11..17e94c562 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { vv.views = append(vv.views, vv2.views...) vv.size += vv2.size } + +// AppendView appends the given view into this vectorised view. +func (vv *VectorisedView) AppendView(v View) { + vv.views = append(vv.views, v) + vv.size += len(v) +} diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 4d6ae0871..c6c160dfc 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTClass creates a checker that checks the TCLASS field in +// ControlMessages. +func ReceiveTClass(want uint32) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTClass { + t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) + } + if got := cm.TClass; got != want { + t.Fatalf("got cm.TClass = %d, want %d", got, want) + } + } +} + // ReceiveTOS creates a checker that checks the TOS field in ControlMessages. func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index f7dc4f720..80ddbd442 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -156,25 +156,54 @@ func EmptyNatTable() Table { } } +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + + // Check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. // // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { - // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because - // we're missing features. Jumps, the call stack, etc. aren't checked - // for yet because we're yet to support them. - // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { - switch verdict := it.checkTable(hook, pkt, tablename); verdict { + table := it.Tables[tablename] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { // If the table returns Accept, move on to the next table. - case TableAccept: + case chainAccept: continue // The Drop verdict is final. - case TableDrop: + case chainDrop: return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + default: panic(fmt.Sprintf("Unknown verdict %v.", verdict)) } @@ -185,37 +214,37 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { } // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) TableVerdict { +func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. - table := it.Tables[tablename] - for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ { - switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict { + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { case RuleAccept: - return TableAccept + return chainAccept case RuleDrop: - return TableDrop - - case RuleContinue: - continue + return chainDrop case RuleReturn: - // TODO(gvisor.dev/issue/170): We don't implement jump - // yet, so any Return is from a built-in chain. That - // means we have to to call the underflow. - underflow := table.Rules[table.Underflows[hook]] - // Underflow is guaranteed to be an unconditional - // ACCEPT or DROP. - switch v, _ := underflow.Target.Action(pkt, underflow.Filter); v { - case RuleAccept: - return TableAccept - case RuleDrop: - return TableDrop - case RuleContinue, RuleReturn: - panic("Underflows should only return RuleAccept or RuleDrop.") + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue default: - panic(fmt.Sprintf("Unknown verdict: %d", v)) + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) } default: @@ -226,11 +255,11 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri // We got through the entire table without a decision. Default to DROP // for safety. - return TableDrop + return chainDrop } // Precondition: pk.NetworkHeader is set. -func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) RuleVerdict { +func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in @@ -242,7 +271,8 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru // First check whether the packet matches the IP header filter. // TODO(gvisor.dev/issue/170): Support other fields of the filter. if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() { - return RuleContinue + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 } // Go through each rule matcher. If they all match, run @@ -250,14 +280,14 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru for _, matcher := range rule.Matchers { matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { - return RuleDrop + return RuleDrop, 0 } if !matches { - return RuleContinue + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 } } // All the matchers matched, so run the target. - verdict, _ := rule.Target.Action(pkt, rule.Filter) - return verdict + return rule.Target.Action(pkt, rule.Filter) } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index a75938da3..5dbb28145 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This file contains various Targets. - package iptables import ( @@ -26,16 +24,16 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { - return RuleAccept, "" +func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) { + return RuleAccept, 0 } // DropTarget drops packets. type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { - return RuleDrop, "" +func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) { + return RuleDrop, 0 } // ErrorTarget logs an error and drops the packet. It represents a target that @@ -43,9 +41,9 @@ func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (Rule type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { +func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") - return RuleDrop, "" + return RuleDrop, 0 } // UserChainTarget marks a rule as the beginning of a user chain. @@ -54,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) { +func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -63,8 +61,8 @@ func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) { - return RuleReturn, "" +func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, int) { + return RuleReturn, 0 } // RedirectTarget redirects the packet by modifying the destination port/IP. @@ -88,13 +86,13 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { +func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) { headerView := pkt.Data.First() // Network header should be set. netHeader := header.IPv4(headerView) if netHeader == nil { - return RuleDrop, "" + return RuleDrop, 0 } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if @@ -111,7 +109,7 @@ func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) ( tcp := header.TCP(headerView[hlen:]) tcp.SetDestinationPort(rt.MinPort) default: - return RuleDrop, "" + return RuleDrop, 0 } - return RuleAccept, "" + return RuleAccept, 0 } diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 0102831d0..8bd3a2c94 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -56,17 +56,6 @@ const ( NumHooks ) -// A TableVerdict is what a table decides should be done with a packet. -type TableVerdict int - -const ( - // TableAccept indicates the packet should continue through netstack. - TableAccept TableVerdict = iota - - // TableDrop indicates the packet should be dropped. - TableDrop -) - // A RuleVerdict is what a rule decides should be done with a packet. type RuleVerdict int @@ -74,12 +63,12 @@ const ( // RuleAccept indicates the packet should continue through netstack. RuleAccept RuleVerdict = iota - // RuleContinue indicates the packet should continue to the next rule. - RuleContinue - // RuleDrop indicates the packet should be dropped. RuleDrop + // RuleJump indicates the packet should jump to another chain. + RuleJump + // RuleReturn indicates the packet should return to the previous chain. RuleReturn ) @@ -175,5 +164,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the name of the chain to jump to. - Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) + Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) } diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 3974c464e..b8b93e78e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["channel.go"], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 78d447acd..5944ba190 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -20,6 +20,7 @@ package channel import ( "context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,6 +34,118 @@ type PacketInfo struct { Route stack.Route } +// Notification is the interface for receiving notification from the packet +// queue. +type Notification interface { + // WriteNotify will be called when a write happens to the queue. + WriteNotify() +} + +// NotificationHandle is an opaque handle to the registered notification target. +// It can be used to unregister the notification when no longer interested. +// +// +stateify savable +type NotificationHandle struct { + n Notification +} + +type queue struct { + // mu protects fields below. + mu sync.RWMutex + // c is the outbound packet channel. Sending to c should hold mu. + c chan PacketInfo + numWrite int + numRead int + notify []*NotificationHandle +} + +func (q *queue) Close() { + close(q.c) +} + +func (q *queue) Read() (PacketInfo, bool) { + q.mu.Lock() + defer q.mu.Unlock() + select { + case p := <-q.c: + q.numRead++ + return p, true + default: + return PacketInfo{}, false + } +} + +func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { + // We have to receive from channel without holding the lock, since it can + // block indefinitely. This will cause a window that numWrite - numRead + // produces a larger number, but won't go to negative. numWrite >= numRead + // still holds. + select { + case pkt := <-q.c: + q.mu.Lock() + defer q.mu.Unlock() + q.numRead++ + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + +func (q *queue) Write(p PacketInfo) bool { + wrote := false + + // It's important to make sure nobody can see numWrite until we increment it, + // so numWrite >= numRead holds. + q.mu.Lock() + select { + case q.c <- p: + wrote = true + q.numWrite++ + default: + } + notify := q.notify + q.mu.Unlock() + + if wrote { + // Send notification outside of lock. + for _, h := range notify { + h.n.WriteNotify() + } + } + return wrote +} + +func (q *queue) Num() int { + q.mu.RLock() + defer q.mu.RUnlock() + n := q.numWrite - q.numRead + if n < 0 { + panic("numWrite < numRead") + } + return n +} + +func (q *queue) AddNotify(notify Notification) *NotificationHandle { + q.mu.Lock() + defer q.mu.Unlock() + h := &NotificationHandle{n: notify} + q.notify = append(q.notify, h) + return h +} + +func (q *queue) RemoveNotify(handle *NotificationHandle) { + q.mu.Lock() + defer q.mu.Unlock() + // Make a copy, since we reads the array outside of lock when notifying. + notify := make([]*NotificationHandle, 0, len(q.notify)) + for _, h := range q.notify { + if h != handle { + notify = append(notify, h) + } + } + q.notify = notify +} + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -41,14 +154,16 @@ type Endpoint struct { linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities - // c is where outbound packets are queued. - c chan PacketInfo + // Outbound packet queue. + q *queue } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - c: make(chan PacketInfo, size), + q: &queue{ + c: make(chan PacketInfo, size), + }, mtu: mtu, linkAddr: linkAddr, } @@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { // Close closes e. Further packet injections will panic. Reads continue to // succeed until all packets are read. func (e *Endpoint) Close() { - close(e.c) + e.q.Close() } -// Read does non-blocking read for one packet from the outbound packet queue. +// Read does non-blocking read one packet from the outbound packet queue. func (e *Endpoint) Read() (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - default: - return PacketInfo{}, false - } + return e.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns false. func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - case <-ctx.Done(): - return PacketInfo{}, false - } + return e.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { - select { - case <-e.c: - c++ - default: + if _, ok := e.Read(); !ok { return c } + c++ } } +// NumQueued returns the number of packet queued for outbound. +func (e *Endpoint) NumQueued() int { + return e.q.Num() +} + // InjectInbound injects an inbound packet. func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) @@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne Route: route, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } @@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac route.Release() payloadView := pkts[0].Data.ToView() n := 0 -packetLoop: for _, pkt := range pkts { off := pkt.DataOffset size := pkt.DataSize @@ -185,12 +289,10 @@ packetLoop: Route: route, } - select { - case e.c <- p: - n++ - default: - break packetLoop + if !e.q.Write(p) { + break } + n++ } return n, nil @@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { GSO: nil, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} + +// AddNotify adds a notification target for receiving event about outgoing +// packets. +func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { + return e.q.AddNotify(notify) +} + +// RemoveNotify removes handle from the list of notification targets. +func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { + e.q.RemoveNotify(handle) +} diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e5096ea38..e0db6cf54 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -4,6 +4,22 @@ package(licenses = ["notice"]) go_library( name = "tun", - srcs = ["tun_unsafe.go"], + srcs = [ + "device.go", + "protocol.go", + "tun_unsafe.go", + ], visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/refs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], ) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go new file mode 100644 index 000000000..6ff47a742 --- /dev/null +++ b/pkg/tcpip/link/tun/device.go @@ -0,0 +1,352 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // drivers/net/tun.c:tun_net_init() + defaultDevMtu = 1500 + + // Queue length for outbound packet, arriving at fd side for read. Overflow + // causes packet drops. gVisor implementation-specific. + defaultDevOutQueueLen = 1024 +) + +var zeroMAC [6]byte + +// Device is an opened /dev/net/tun device. +// +// +stateify savable +type Device struct { + waiter.Queue + + mu sync.RWMutex `state:"nosave"` + endpoint *tunEndpoint + notifyHandle *channel.NotificationHandle + flags uint16 +} + +// beforeSave is invoked by stateify. +func (d *Device) beforeSave() { + d.mu.Lock() + defer d.mu.Unlock() + // TODO(b/110961832): Restore the device to stack. At this moment, the stack + // is not savable. + if d.endpoint != nil { + panic("/dev/net/tun does not support save/restore when a device is associated with it.") + } +} + +// Release implements fs.FileOperations.Release. +func (d *Device) Release() { + d.mu.Lock() + defer d.mu.Unlock() + + // Decrease refcount if there is an endpoint associated with this file. + if d.endpoint != nil { + d.endpoint.RemoveNotify(d.notifyHandle) + d.endpoint.DecRef() + d.endpoint = nil + } +} + +// SetIff services TUNSETIFF ioctl(2) request. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.endpoint != nil { + return syserror.EINVAL + } + + // Input validations. + isTun := flags&linux.IFF_TUN != 0 + isTap := flags&linux.IFF_TAP != 0 + supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) + if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { + return syserror.EINVAL + } + + prefix := "tun" + if isTap { + prefix = "tap" + } + + endpoint, err := attachOrCreateNIC(s, name, prefix) + if err != nil { + return syserror.EINVAL + } + + d.endpoint = endpoint + d.notifyHandle = d.endpoint.AddNotify(d) + d.flags = flags + return nil +} + +func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) { + for { + // 1. Try to attach to an existing NIC. + if name != "" { + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if !ok { + // Not a NIC created by tun device. + return nil, syserror.EOPNOTSUPP + } + if !endpoint.TryIncRef() { + // Race detected: NIC got deleted in between. + continue + } + return endpoint, nil + } + } + + // 2. Creating a new NIC. + id := tcpip.NICID(s.UniqueID()) + endpoint := &tunEndpoint{ + Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), + stack: s, + nicID: id, + name: name, + } + if endpoint.name == "" { + endpoint.name = fmt.Sprintf("%s%d", prefix, id) + } + err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ + Name: endpoint.name, + }) + switch err { + case nil: + return endpoint, nil + case tcpip.ErrDuplicateNICID: + // Race detected: A NIC has been created in between. + continue + default: + return nil, syserror.EINVAL + } + } +} + +// Write inject one inbound packet to the network interface. +func (d *Device) Write(data []byte) (int64, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return 0, syserror.EBADFD + } + if !endpoint.IsAttached() { + return 0, syserror.EIO + } + + dataLen := int64(len(data)) + + // Packet information. + var pktInfoHdr PacketInfoHeader + if !d.hasFlags(linux.IFF_NO_PI) { + if len(data) < PacketInfoHeaderSize { + // Ignore bad packet. + return dataLen, nil + } + pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize]) + data = data[PacketInfoHeaderSize:] + } + + // Ethernet header (TAP only). + var ethHdr header.Ethernet + if d.hasFlags(linux.IFF_TAP) { + if len(data) < header.EthernetMinimumSize { + // Ignore bad packet. + return dataLen, nil + } + ethHdr = header.Ethernet(data[:header.EthernetMinimumSize]) + data = data[header.EthernetMinimumSize:] + } + + // Try to determine network protocol number, default zero. + var protocol tcpip.NetworkProtocolNumber + switch { + case pktInfoHdr != nil: + protocol = pktInfoHdr.Protocol() + case ethHdr != nil: + protocol = ethHdr.Type() + } + + // Try to determine remote link address, default zero. + var remote tcpip.LinkAddress + switch { + case ethHdr != nil: + remote = ethHdr.SourceAddress() + default: + remote = tcpip.LinkAddress(zeroMAC[:]) + } + + pkt := tcpip.PacketBuffer{ + Data: buffer.View(data).ToVectorisedView(), + } + if ethHdr != nil { + pkt.LinkHeader = buffer.View(ethHdr) + } + endpoint.InjectLinkAddr(protocol, remote, pkt) + return dataLen, nil +} + +// Read reads one outgoing packet from the network interface. +func (d *Device) Read() ([]byte, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return nil, syserror.EBADFD + } + + for { + info, ok := endpoint.Read() + if !ok { + return nil, syserror.ErrWouldBlock + } + + v, ok := d.encodePkt(&info) + if !ok { + // Ignore unsupported packet. + continue + } + return v, nil + } +} + +// encodePkt encodes packet for fd side. +func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { + var vv buffer.VectorisedView + + // Packet information. + if !d.hasFlags(linux.IFF_NO_PI) { + hdr := make(PacketInfoHeader, PacketInfoHeaderSize) + hdr.Encode(&PacketInfoFields{ + Protocol: info.Proto, + }) + vv.AppendView(buffer.View(hdr)) + } + + // If the packet does not already have link layer header, and the route + // does not exist, we can't compute it. This is possibly a raw packet, tun + // device doesn't support this at the moment. + if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + return nil, false + } + + // Ethernet header (TAP only). + if d.hasFlags(linux.IFF_TAP) { + // Add ethernet header if not provided. + if info.Pkt.LinkHeader == nil { + hdr := &header.EthernetFields{ + SrcAddr: info.Route.LocalLinkAddress, + DstAddr: info.Route.RemoteLinkAddress, + Type: info.Proto, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = d.endpoint.LinkAddress() + } + + eth := make(header.Ethernet, header.EthernetMinimumSize) + eth.Encode(hdr) + vv.AppendView(buffer.View(eth)) + } else { + vv.AppendView(info.Pkt.LinkHeader) + } + } + + // Append upper headers. + vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + // Append data payload. + vv.Append(info.Pkt.Data) + + return vv.ToView(), true +} + +// Name returns the name of the attached network interface. Empty string if +// unattached. +func (d *Device) Name() string { + d.mu.RLock() + defer d.mu.RUnlock() + if d.endpoint != nil { + return d.endpoint.name + } + return "" +} + +// Flags returns the flags set for d. Zero value if unset. +func (d *Device) Flags() uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + return d.flags +} + +func (d *Device) hasFlags(flags uint16) bool { + return d.flags&flags == flags +} + +// Readiness implements watier.Waitable.Readiness. +func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint != nil && endpoint.NumQueued() == 0 { + mask &= ^waiter.EventIn + } + } + return mask & (waiter.EventIn | waiter.EventOut) +} + +// WriteNotify implements channel.Notification.WriteNotify. +func (d *Device) WriteNotify() { + d.Notify(waiter.EventIn) +} + +// tunEndpoint is the link endpoint for the NIC created by the tun device. +// +// It is ref-counted as multiple opening files can attach to the same NIC. +// The last owner is responsible for deleting the NIC. +type tunEndpoint struct { + *channel.Endpoint + + refs.AtomicRefCount + + stack *stack.Stack + nicID tcpip.NICID + name string +} + +// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +func (e *tunEndpoint) DecRef() { + e.DecRefWithDestructor(func() { + e.stack.RemoveNIC(e.nicID) + }) +} diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go new file mode 100644 index 000000000..89d9d91a9 --- /dev/null +++ b/pkg/tcpip/link/tun/protocol.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // PacketInfoHeaderSize is the size of the packet information header. + PacketInfoHeaderSize = 4 + + offsetFlags = 0 + offsetProtocol = 2 +) + +// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is +// not set. +type PacketInfoFields struct { + Flags uint16 + Protocol tcpip.NetworkProtocolNumber +} + +// PacketInfoHeader is the wire representation of the packet information sent if +// IFF_NO_PI flag is not set. +type PacketInfoHeader []byte + +// Encode encodes f into h. +func (h PacketInfoHeader) Encode(f *PacketInfoFields) { + binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags) + binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol)) +} + +// Flags returns the flag field in h. +func (h PacketInfoHeader) Flags() uint16 { + return binary.BigEndian.Uint16(h[offsetFlags:]) +} + +// Protocol returns the protocol field in h. +func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:])) +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4da13c5df..e9fcc89a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi }, nil } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ RemoteLinkAddress: broadcastMAC, @@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return broadcastMAC, true @@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6597e6781..4f1742938 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 180a480fd..9aef5234b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 045409bda..f651871ce 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// cleanupHostOnlyState cleans up any state that is only useful for hosts. +// cleanupState cleans up ndp's state. // -// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a -// host to a router. This function will invalidate all discovered on-link -// prefixes, discovered routers, and auto-generated addresses as routers do not -// normally process Router Advertisements to discover default routers and -// on-link prefixes, and auto-generate addresses via SLAAC. +// If hostOnly is true, then only host-specific state will be cleaned up. +// +// cleanupState MUST be called with hostOnly set to true when ndp's NIC is +// transitioning from a host to a router. This function will invalidate all +// discovered on-link prefixes, discovered routers, and auto-generated +// addresses. +// +// If hostOnly is true, then the link-local auto-generated address will not be +// invalidated as routers are also expected to generate a link-local address. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupHostOnlyState() { +func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() linkLocalAddrs := 0 for addr := range ndp.autoGenAddresses { // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them. - if linkLocalSubnet.Contains(addr) { + // link-local address so we do not invalidate them if we are cleaning up + // host-only state. + if hostOnly && linkLocalSubnet.Contains(addr) { linkLocalAddrs++ continue } @@ -1230,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() { } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1f6f77439..6e9306d09 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s } } +// channelLinkWithHeaderLength is a channel.Endpoint with a configurable +// header length. +type channelLinkWithHeaderLength struct { + *channel.Endpoint + headerLength uint16 +} + +func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { + return l.headerLength +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { @@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) { // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +// This tests also validates the NDP NS packet that is transmitted. func TestDADResolve(t *testing.T) { const nicID = 1 tests := []struct { name string + linkHeaderLen uint16 dupAddrDetectTransmits uint8 retransTimer time.Duration expectedRetransmitTimer time.Duration }{ - {"1:1s:1s", 1, time.Second, time.Second}, - {"2:1s:1s", 2, time.Second, time.Second}, - {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + { + name: "1:1s:1s", + dupAddrDetectTransmits: 1, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "2:1s:1s", + linkHeaderLen: 1, + dupAddrDetectTransmits: 2, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "1:2s:2s", + linkHeaderLen: 2, + dupAddrDetectTransmits: 1, + retransTimer: 2 * time.Second, + expectedRetransmitTimer: 2 * time.Second, + }, // 0s is an invalid RetransmitTimer timer and will be fixed to // the default RetransmitTimer value of 1s. - {"1:0s:1s", 1, 0, time.Second}, + { + name: "1:0s:1s", + linkHeaderLen: 3, + dupAddrDetectTransmits: 1, + retransTimer: 0, + expectedRetransmitTimer: time.Second, + }, } for _, test := range tests { @@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) { checker.NDPNSTargetAddress(addr1), checker.NDPNSOptions(nil), )) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } }) } @@ -592,70 +635,94 @@ func TestDADFail(t *testing.T) { } } -// TestDADStop tests to make sure that the DAD process stops when an address is -// removed. func TestDADStop(t *testing.T) { const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } + tests := []struct { + name string + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests to make sure that DAD stops when an address is removed. + { + name: "Remove address", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) + } + }, + }, - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + // Tests to make sure that DAD stops when the NIC is disabled. + { + name: "Disable NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + }, + }, } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - // Remove the address. This should stop DAD. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + + test.stopFn(t, s) + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Errorf("got NeighborSolicit = %d, want <= 1", got) + } + }) } } @@ -2886,17 +2953,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } -// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers -// and prefixes, and non-linklocal auto-generated addresses are invalidated when -// a NIC becomes a router. -func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { +// TestCleanupNDPState tests that all discovered routers and prefixes, and +// auto-generated addresses are invalidated when a NIC becomes a router. +func TestCleanupNDPState(t *testing.T) { t.Parallel() const ( - lifetimeSeconds = 5 - maxEvents = 4 - nicID1 = 1 - nicID2 = 2 + lifetimeSeconds = 5 + maxRouterAndPrefixEvents = 4 + nicID1 = 1 + nicID2 = 2 ) prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) @@ -2912,254 +2978,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { PrefixLen: 64, } - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, + tests := []struct { + name string + cleanupFn func(t *testing.T, s *stack.Stack) + keepAutoGenLinkLocal bool + maxAutoGenAddrEvents int + }{ + // A NIC should still keep its auto-generated link-local address when + // becoming a router. + { + name: "Forwarding Enable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, + keepAutoGenLinkLocal: true, + maxAutoGenAddrEvents: 4, }, - NDPDisp: &ndpDisp, - }) - expectRouterEvent := func() (bool, ndpRouterEvent) { - select { - case e := <-ndpDisp.routerC: - return true, e - default: - } + // A NIC should cleanup all NDP state when it is disabled. + { + name: "NIC Disable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - return false, ndpRouterEvent{} + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + }, } - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) - return false, ndpPrefixEvent{} - } + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } + return false, ndpRouterEvent{} + } - return false, ndpAutoGenAddrEvent{} - } + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() + return false, ndpPrefixEvent{} + } - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. + return false, ndpAutoGenAddrEvent{} + } - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + expectAutoGenAddrEvent() - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } - s.SetForwarding(true) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } - // Collect invalidation events after becoming a router - gotRouterEvents := make(map[ndpRouterEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectRouterEvent() - if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) - break - } - gotRouterEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, - } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } + test.cleanupFn(t, s) - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + // Collect invalidation events after having NDP state cleaned up. + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < test.maxAutoGenAddrEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } - // Should not get any more events (invalidation timers should have been - // cancelled when we transitioned into a router). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) - select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + + if !test.keepAutoGenLinkLocal { + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 + } + + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + } + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when the NDP state was cleaned up). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } + }) } } @@ -3259,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { func TestRouterSolicitation(t *testing.T) { t.Parallel() + const nicID = 1 + tests := []struct { name string + linkHeaderLen uint16 maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3277,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with delay", + linkHeaderLen: 1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3285,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Single RS without delay", + linkHeaderLen: 2, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3293,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS without delay and invalid zero interval", + linkHeaderLen: 3, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3330,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() ctx, _ := context.WithTimeout(context.Background(), timeout) @@ -3357,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS(), ) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } waitForNothing := func(timeout time.Duration) { t.Helper() @@ -3373,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Make sure each RS got sent at the right @@ -3406,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) { }) } -// TestStopStartSolicitingRouters tests that when forwarding is enabled or -// disabled, router solicitations are stopped or started, respecitively. func TestStopStartSolicitingRouters(t *testing.T) { t.Parallel() + const nicID = 1 const interval = 500 * time.Millisecond const delay = time.Second const maxRtrSolicitations = 3 - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, + tests := []struct { + name string + startFn func(t *testing.T, s *stack.Stack) + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests that when forwarding is enabled or disabled, router solicitations + // are stopped or started, respectively. + { + name: "Forwarding enabled and disabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(false) + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, }, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Enable forwarding which should stop router solicitations. - s.SetForwarding(true) - ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok = e.ReadContext(ctx); ok { - t.Fatal("Should not have sent more than one RS message") - } - } + // Tests that when a NIC is enabled or disabled, router solicitations + // are started or stopped, respectively. + { + name: "NIC disabled and enabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Enabling forwarding again should do nothing. - s.SetForwarding(true) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") - } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Disable forwarding which should start router solicitations. - s.SetForwarding(false) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + }, + }, } - // Disabling forwarding again should do nothing. - s.SetForwarding(false) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + p, ok := e.ReadContext(ctx) + if !ok { + t.Fatal("timed out waiting for packet") + return + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Stop soliciting routers. + test.stopFn(t, s) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + // A single RS may have been sent before forwarding was enabled. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok = e.ReadContext(ctx); ok { + t.Fatal("should not have sent more than one RS message") + } + } + + // Stopping router solicitations after it has already been stopped should + // do nothing. + test.stopFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") + } + + // Start soliciting routers. + test.startFn(t, s) + waitForPkt(delay + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + } + + // Starting router solicitations after it has already completed should do + // nothing. + test.startFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after finishing router solicitations") + } + }) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a75dc0322..271e55be0 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -28,6 +28,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/iptables" ) +var ipv4BroadcastAddr = tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 8 * header.IPv4AddressSize, + }, +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -133,11 +141,77 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + nic.linkEP.Attach(nic) + return nic } -// enable enables the NIC. enable will attach the link to its LinkEndpoint and -// join the IPv6 All-Nodes Multicast address (ff02::1). +// enabled returns true if n is enabled. +func (n *NIC) enabled() bool { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + return enabled +} + +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if !enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if !n.mu.enabled { + return nil + } + + // TODO(b/147015577): Should Routes that are currently bound to n be + // invalidated? Currently, Routes will continue to work when a NIC is enabled + // again, and applications may not know that the underlying NIC was ever + // disabled. + + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + n.mu.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupState(false /* hostOnly */) + + // Stop DAD for all the unicast IPv6 endpoints that are in the + // permanentTentative state. + for _, r := range n.mu.endpoints { + if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + n.mu.ndp.stopDuplicateAddressDetection(addr) + } + } + + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The address may have already been removed. + if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + n.mu.enabled = false + return nil +} + +// enable enables n. +// +// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast +// address (ff02::1), start DAD for permanent addresses, and start soliciting +// routers if the stack is not operating as a router. If the stack is also +// configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { n.mu.RLock() enabled := n.mu.enabled @@ -155,14 +229,9 @@ func (n *NIC) enable() *tcpip.Error { n.mu.enabled = true - n.attachLinkEndpoint() - // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { + if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -184,6 +253,14 @@ func (n *NIC) enable() *tcpip.Error { return nil } + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // @@ -201,10 +278,6 @@ func (n *NIC) enable() *tcpip.Error { } } - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - // Do not auto-generate an IPv6 link-local address for loopback devices. if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { // The valid and preferred lifetime is infinite for the auto-generated @@ -226,6 +299,33 @@ func (n *NIC) enable() *tcpip.Error { return nil } +// remove detaches NIC from the link endpoint, and marks existing referenced +// network endpoints expired. This guarantees no packets between this NIC and +// the network stack. +func (n *NIC) remove() *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + // Remove permanent and permanentTentative addresses, so no packet goes out. + var errs []*tcpip.Error + for nid, ref := range n.mu.endpoints { + switch ref.getKind() { + case permanentTentative, permanent: + if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return errs[0] + } + + return nil +} + // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered @@ -235,7 +335,7 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.cleanupState(true /* hostOnly */) n.mu.ndp.stopSolicitingRouters() } @@ -250,12 +350,6 @@ func (n *NIC) becomeIPv6Host() { n.mu.ndp.startSolicitingRouters() } -// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it -// to start delivering packets. -func (n *NIC) attachLinkEndpoint() { - n.linkEP.Attach(n) -} - // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() @@ -713,6 +807,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { case permanentExpired, temporary: continue } + addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -1010,6 +1105,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } +// isInGroup returns true if n has joined the multicast group addr. +func (n *NIC) isInGroup(addr tcpip.Address) bool { + n.mu.RLock() + joins := n.mu.mcastJoins[NetworkEndpointID{addr}] + n.mu.RUnlock() + + return joins != 0 +} + func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1237,6 +1341,11 @@ func (n *NIC) Stack() *Stack { return n.stack } +// LinkEndpoint returns the link endpoint of n. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // isAddrTentative returns true if addr is tentative on n. // // Note that if addr is not associated with n, then this function will return @@ -1423,7 +1532,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { // // r's NIC must be read locked. func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - return r.getKind() != permanentExpired || r.nic.mu.spoofing + return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ec91f60dd..f9fd8f18f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -74,10 +74,11 @@ type TransportEndpoint interface { // HandleControlPacket takes ownership of pkt. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) - // Close puts the endpoint in a closed state and frees all resources - // associated with it. This cleanup may happen asynchronously. Wait can - // be used to block on this asynchronous cleanup. - Close() + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() // Wait waits for any worker goroutines owned by the endpoint to stop. // @@ -160,6 +161,13 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // TransportDispatcher contains the methods used by the network stack to deliver @@ -277,7 +285,7 @@ type NetworkProtocol interface { // DefaultPrefixLen returns the protocol's default prefix length. DefaultPrefixLen() int - // ParsePorts returns the source and destination addresses stored in a + // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) @@ -293,6 +301,13 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6eac16e16..ebb6c5e3b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -881,6 +881,8 @@ type NICOptions struct { // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and // NICOptions. See the documentation on type NICOptions for details on how // NICs can be configured. +// +// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) - s.nics[id] = n if !opts.Disabled { return n.enable() @@ -910,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls -// `LinkEndpoint.Attach` to start delivering packets to it. +// LinkEndpoint.Attach to bind ep with a NetworkDispatcher. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, nic := range s.nics { + if nic.Name() == name { + return nic, true + } + } + return nil, false +} + // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } return nic.enable() } +// DisableNIC disables the given NIC. +func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.disable() +} + // CheckNIC checks if a NIC is usable. func (s *Stack) CheckNIC(id tcpip.NICID) bool { s.mu.RLock() + defer s.mu.RUnlock() + nic, ok := s.nics[id] - s.mu.RUnlock() - if ok { - return nic.linkEP.IsAttached() + if !ok { + return false } - return false + + return nic.enabled() +} + +// RemoveNIC removes NIC and all related routes from the network stack. +func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + delete(s.nics, id) + + // Remove routes in-place. n tracks the number of routes written. + n := 0 + for i, r := range s.routeTable { + if r.NIC != id { + // Keep this route. + if i > n { + s.routeTable[n] = r + } + n++ + } + } + s.routeTable = s.routeTable[:n] + + return nic.remove() } // NICAddressRanges returns a map of NICIDs to their associated subnets. @@ -989,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.linkEP.IsAttached(), + Running: nic.enabled(), Promiscuous: nic.isPromiscuousMode(), Loopback: nic.isLoopback(), } @@ -1151,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok { + if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } @@ -1161,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok { + if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route @@ -1391,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { // Endpoints created or modified during this call may not get closed. func (s *Stack) Close() { for _, e := range s.RegisteredEndpoints() { - e.Close() + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() } } @@ -1409,6 +1470,12 @@ func (s *Stack) Wait() { for _, e := range s.CleanupEndpoints() { e.Wait() } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } s.mu.RLock() defer s.mu.RUnlock() @@ -1614,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC return tcpip.ErrUnknownNICID } +// IsInGroup returns true if the NIC with ID nicID has joined the multicast +// group multicastAddr. +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.isInGroup(multicastAddr), nil + } + return false, tcpip.ErrUnknownNICID +} + // IPTables returns the stack's iptables. func (s *Stack) IPTables() iptables.IPTables { s.tablesMu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7ba604442..edf6bec52 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } +// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify +// that LinkEndpoint.Attach was called. +type linkEPWithMockedAttach struct { + stack.LinkEndpoint + attached bool +} + +// Attach implements stack.LinkEndpoint.Attach. +func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { + l.LinkEndpoint.Attach(d) + l.attached = true +} + +func (l *linkEPWithMockedAttach) isAttached() bool { + return l.attached +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to +// a NetworkDispatcher when the NIC is created. +func TestAttachToLinkEndpointImmediately(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + nicOpts stack.NICOptions + }{ + { + name: "Create enabled NIC", + nicOpts: stack.NICOptions{Disabled: false}, + }, + { + name: "Create disabled NIC", + nicOpts: stack.NICOptions{Disabled: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + + if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) + } + if !e.isAttached() { + t.Fatalf("link endpoint not attached to a network disatcher") + } + }) + } +} + +func TestDisableUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := loopback.New() + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + checkNIC := func(enabled bool) { + t.Helper() + + allNICInfo := s.NICInfo() + nicInfo, ok := allNICInfo[nicID] + if !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } else if nicInfo.Flags.Running != enabled { + t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) + } + + if got := s.CheckNIC(nicID); got != enabled { + t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) + } + } + + // NIC should initially report itself as disabled. + checkNIC(false) + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + checkNIC(true) + + // If the NIC is not reporting a correct enabled status, we cannot trust the + // next check so end the test here. + if t.Failed() { + t.FailNow() + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + checkNIC(false) +} + +func TestRoutesWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Disabling NIC1 should result in no routes to odd addresses. Routes to even + // addresses should continue to be available as NIC2 is still enabled. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + nic1Dst := tcpip.Address("\x05") + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + nic2Dst := tcpip.Address("\x06") + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Disabling NIC2 should result in no routes to even addresses. No route + // should be available to any address as routes to odd addresses were made + // unavailable by disabling NIC1 above. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + // Enabling NIC1 should make routes to odd addresses available again. Routes + // to even addresses should continue to be unavailable as NIC2 is still + // disabled. + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) +} + +func TestRouteWritePacketWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + nic1Dst := tcpip.Address("\x05") + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() + + nic2Dst := tcpip.Address("\x06") + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() + + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } + + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC1 should fail. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC2 should fail. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + // Writes with Routes that use the re-enabled NIC1 should succeed. + // TODO(b/147015577): Should we instead completely invalidate all Routes that + // were bound to a disabled NIC at some point? + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) +} + func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has @@ -2173,13 +2487,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { e := channel.New(0, 1280, test.linkAddr) s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName} + nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - var expectedMainAddr tcpip.AddressWithPrefix + // A new disabled NIC should not have any address, even if auto generation + // was enabled. + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + // Enabling the NIC should attempt auto-generation of a link-local + // address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + var expectedMainAddr tcpip.AddressWithPrefix if test.shouldGen { expectedMainAddr = tcpip.AddressWithPrefix{ Address: test.expectedAddr, @@ -2609,6 +2939,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } } +func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + + // Enabling the NIC should add the IPv4 broadcast address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 1 { + t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) + } + want := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } + if allNICAddrs[0] != want { + t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + } + + // Disabling the NIC should remove the IPv4 broadcast address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } +} + +func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + // Should not be in the IPv6 all-nodes multicast group yet because the NIC has + // not been enabled yet. + isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } +} + // TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC // was disabled have DAD performed on them when the NIC is enabled. func TestDoDADWhenNICEnabled(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d686e6eb8..778c0a4d6 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p ep.mu.RUnlock() // Don't use defer for performance reasons. } -// Close implements stack.TransportEndpoint.Close. -func (ep *multiPortEndpoint) Close() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Close() - } -} - -// Wait implements stack.TransportEndpoint.Wait. -func (ep *multiPortEndpoint) Wait() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Wait() - } -} - // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 869c69a6d..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } +func (f *fakeTransportEndpoint) Abort() { + f.Close() +} + func (f *fakeTransportEndpoint) Close() { f.route.Release() } @@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 9ca39ce40..3dc5d87d6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -323,11 +323,11 @@ type ControlMessages struct { // TOS is the IPv4 type of service of the associated packet. TOS uint8 - // HasTClass indicates whether Tclass is valid/set. + // HasTClass indicates whether TClass is valid/set. HasTClass bool - // Tclass is the IPv6 traffic class of the associated packet. - TClass int32 + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 // HasIPPacketInfo indicates whether PacketInfo is set. HasIPPacketInfo bool @@ -341,9 +341,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also @@ -502,9 +508,13 @@ type WriteOptions struct { type SockOptBool int const ( + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the + // IPV6_TCLASS ancillary message is passed with incoming packets. + ReceiveTClassOption SockOptBool = iota + // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. - ReceiveTOSOption SockOptBool = iota + ReceiveTOSOption // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. @@ -514,6 +524,9 @@ const ( // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption + + // TODO(b/146901447): convert existing bool socket options to be handled via + // Get/SetSockOptBool ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 48764b978..2f98a996f 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -25,6 +25,8 @@ import ( ) // StdClock implements Clock with the time package. +// +// +stateify savable type StdClock struct{} var _ Clock = (*StdClock)(nil) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 42afb3f5b..426da1ee6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9ce500e80..113d92901 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index fc5bc69fa..5722815e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ee9c4c58b..2ef5fac76 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 08afb7c17..13e383ffc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5c5397823..7730e6445 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f2be0e651..f1ad19dac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + if e.EndpointState().connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index cc118c993..5b2b16afa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { ), ) - // Wait for the TIME-WAIT state to transition to CLOSED. - time.Sleep(1 * time.Second) + // Wait for a little more than the TIME-WAIT duration for the socket to + // transition to CLOSED state. + time.Sleep(1200 * time.Millisecond) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3fe91cac2..1c6a600b8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,7 +32,8 @@ type udpPacket struct { packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - tos uint8 + // tos stores either the receiveTOS or receiveTClass value. + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -119,6 +120,10 @@ type endpoint struct { // as ancillary data to ControlMessages on Read. receiveTOS bool + // receiveTClass determines if the incoming IPv6 TClass header field is + // passed as ancillary data to ControlMessages on Read. + receiveTClass bool + // receiveIPPacketInfo determines if the packet info is returned by Read. receiveIPPacketInfo bool @@ -181,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -258,13 +268,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess } e.mu.RLock() receiveTOS := e.receiveTOS + receiveTClass := e.receiveTClass receiveIPPacketInfo := e.receiveIPPacketInfo e.mu.RUnlock() if receiveTOS { cm.HasTOS = true cm.TOS = p.tos } - + if receiveTClass { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } if receiveIPPacketInfo { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo @@ -490,6 +505,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrNotSupported + } + + e.mu.Lock() + e.receiveTClass = v + e.mu.Unlock() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -709,6 +735,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrNotSupported + } + + e.mu.RLock() + v := e.receiveTClass + e.mu.RUnlock() + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1273,6 +1310,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.packetInfo.LocalAddr = r.LocalAddress packet.packetInfo.DestinationAddr = r.RemoteAddress packet.packetInfo.NIC = r.NICID() + case header.IPv6ProtocolNumber: + packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } packet.timestamp = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 259c3072a..8df089d22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index f0ff3fe71..34b7c2360 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, @@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) { } } -func TestTOSV4(t *testing.T) { +func TestSetTOS(t *testing.T) { for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) { const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) } if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) } if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) { } } -func TestTOSV6(t *testing.T) { +func TestSetTClass(t *testing.T) { for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tClass = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - c.t.Errorf("SetSockOpt failed: %s", err) + if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { + c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) } if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if want := tcpip.IPv6TrafficClassOption(tClass); v != want { + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) } - testWrite(c, flow, checker.TOS(tos, 0)) + // The header getter for TClass is called TOS, so use that checker. + testWrite(c, flow, checker.TOS(tClass, 0)) }) } } -func TestReceiveTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() +func TestReceiveTosTClass(t *testing.T) { + testCases := []struct { + name string + getReceiveOption tcpip.SockOptBool + tests []testFlow + }{ + {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, + {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + } + for _, testCase := range testCases { + for _, flow := range testCase.tests { + t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpointForFlow(flow) + c.createEndpointForFlow(flow) + option := testCase.getReceiveOption + name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) - } + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) + } - want := true - if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { - c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) - } + want := true + if err := c.ep.SetSockOptBool(option, want); err != nil { + c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) + } - got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - if got != want { - c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) - } + got, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + } - // Verify that the correct received TOS is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - testRead(c, flow, checker.ReceiveTOS(testTOS)) - }) + if got != want { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) + } + + // Verify that the correct received TOS or TClass is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + switch option { + case tcpip.ReceiveTClassOption: + testRead(c, flow, checker.ReceiveTClass(testTOS)) + case tcpip.ReceiveTOSOption: + testRead(c, flow, checker.ReceiveTOS(testTOS)) + default: + t.Fatalf("unknown test variant: %s", name) + } + }) + } } } diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD index ff8b9e91a..6c9ada9c7 100644 --- a/pkg/usermem/BUILD +++ b/pkg/usermem/BUILD @@ -25,7 +25,6 @@ go_library( "bytes_io_unsafe.go", "usermem.go", "usermem_arm64.go", - "usermem_unsafe.go", "usermem_x86.go", ], visibility = ["//:sandbox"], @@ -33,6 +32,7 @@ go_library( "//pkg/atomicbitops", "//pkg/binary", "//pkg/context", + "//pkg/gohacks", "//pkg/log", "//pkg/safemem", "//pkg/syserror", diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 71fd4e155..d2f4403b0 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt } end, ok := addr.AddLength(uint64(readlen)) if !ok { - return stringFromImmutableBytes(buf[:done]), syserror.EFAULT + return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT } // Shorten the read to avoid crossing page boundaries, since faulting // in a page unnecessarily is expensive. This also ensures that partial @@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt // Look for the terminating zero byte, which may have occurred before // hitting err. if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { - return stringFromImmutableBytes(buf[:done+i]), nil + return gohacks.StringFromImmutableBytes(buf[:done+i]), nil } done += n if err != nil { - return stringFromImmutableBytes(buf[:done]), err + return gohacks.StringFromImmutableBytes(buf[:done]), err } addr = end } - return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG + return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG } // CopyOutVec copies bytes from src to the memory mapped at ars in uio. The diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index ae4dd102a..26f68fe3d 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -19,7 +19,6 @@ go_library( "loader_amd64.go", "loader_arm64.go", "network.go", - "pprof.go", "strace.go", "user.go", ], @@ -91,6 +90,7 @@ go_library( "//pkg/usermem", "//runsc/boot/filter", "//runsc/boot/platforms", + "//runsc/boot/pprof", "//runsc/specutils", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 9c9e94864..17e774e0c 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" + "gvisor.dev/gvisor/runsc/boot/pprof" "gvisor.dev/gvisor/runsc/specutils" ) @@ -142,7 +143,7 @@ func newController(fd int, l *Loader) (*controller, error) { } srv.Register(manager) - if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok { + if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok { net := &Network{ Stack: eps.Stack, } @@ -341,7 +342,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { return fmt.Errorf("creating memory file: %v", err) } k.SetMemoryFile(mf) - networkStack := cm.l.k.NetworkStack() + networkStack := cm.l.k.RootNetworkNamespace().Stack() cm.l.k = k // Set up the restore environment. @@ -365,9 +366,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { } if cm.l.conf.ProfileEnable { - // initializePProf opens /proc/self/maps, so has to be - // called before installing seccomp filters. - initializePProf() + // pprof.Initialize opens /proc/self/maps, so has to be called before + // installing seccomp filters. + pprof.Initialize() } // Seccomp filters have to be applied before parsing the state file. diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index c69f4c602..a4627905e 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -229,7 +229,9 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_NANOSLEEP: {}, syscall.SYS_PPOLL: {}, syscall.SYS_PREAD64: {}, + syscall.SYS_PREADV: {}, syscall.SYS_PWRITE64: {}, + syscall.SYS_PWRITEV: {}, syscall.SYS_READ: {}, syscall.SYS_RECVMSG: []seccomp.Rule{ { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index eef43b9df..e7ca98134 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -49,6 +49,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -60,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/runsc/boot/filter" _ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms. + "gvisor.dev/gvisor/runsc/boot/pprof" "gvisor.dev/gvisor/runsc/specutils" // Include supported socket providers. @@ -230,11 +232,8 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("enabling strace: %v", err) } - // Create an empty network stack because the network namespace may be empty at - // this point. Netns is configured before Run() is called. Netstack is - // configured using a control uRPC message. Host network is configured inside - // Run(). - networkStack, err := newEmptyNetworkStack(args.Conf, k, k) + // Create root network namespace/stack. + netns, err := newRootNetworkNamespace(args.Conf, k, k) if err != nil { return nil, fmt.Errorf("creating network: %v", err) } @@ -277,7 +276,7 @@ func New(args Args) (*Loader, error) { FeatureSet: cpuid.HostFeatureSet(), Timekeeper: tk, RootUserNamespace: creds.UserNamespace, - NetworkStack: networkStack, + RootNetworkNamespace: netns, ApplicationCores: uint(args.NumCPU), Vdso: vdso, RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace), @@ -466,7 +465,7 @@ func (l *Loader) run() error { // Delay host network configuration to this point because network namespace // is configured after the loader is created and before Run() is called. log.Debugf("Configuring host network") - stack := l.k.NetworkStack().(*hostinet.Stack) + stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack) if err := stack.Configure(); err != nil { return err } @@ -485,7 +484,7 @@ func (l *Loader) run() error { // l.restore is set by the container manager when a restore call is made. if !l.restore { if l.conf.ProfileEnable { - initializePProf() + pprof.Initialize() } // Finally done with all configuration. Setup filters before user code @@ -908,48 +907,92 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { +func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) { + // Create an empty network stack because the network namespace may be empty at + // this point. Netns is configured before Run() is called. Netstack is + // configured using a control uRPC message. Host network is configured inside + // Run(). switch conf.Network { case NetworkHost: - return hostinet.NewStack(), nil + // No network namespacing support for hostinet yet, hence creator is nil. + return inet.NewRootNamespace(hostinet.NewStack(), nil), nil case NetworkNone, NetworkSandbox: - // NetworkNone sets up loopback using netstack. - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} - s := netstack.Stack{stack.New(stack.Options{ - NetworkProtocols: netProtos, - TransportProtocols: transProtos, - Clock: clock, - Stats: netstack.Metrics, - HandleLocal: true, - // Enable raw sockets for users with sufficient - // privileges. - RawFactory: raw.EndpointFactory{}, - UniqueID: uniqueID, - })} - - // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { - return nil, fmt.Errorf("failed to enable SACK: %v", err) + s, err := newEmptySandboxNetworkStack(clock, uniqueID) + if err != nil { + return nil, err } + creator := &sandboxNetstackCreator{ + clock: clock, + uniqueID: uniqueID, + } + return inet.NewRootNamespace(s, creator), nil - // Set default TTLs as required by socket/netstack. - s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) - s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + default: + panic(fmt.Sprintf("invalid network configuration: %v", conf.Network)) + } - // Enable Receive Buffer Auto-Tuning. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) - } +} - s.FillDefaultIPTables() +func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + s := netstack.Stack{stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + Clock: clock, + Stats: netstack.Metrics, + HandleLocal: true, + // Enable raw sockets for users with sufficient + // privileges. + RawFactory: raw.EndpointFactory{}, + UniqueID: uniqueID, + })} - return &s, nil + // Enable SACK Recovery. + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { + return nil, fmt.Errorf("failed to enable SACK: %v", err) + } - default: - panic(fmt.Sprintf("invalid network configuration: %v", conf.Network)) + // Set default TTLs as required by socket/netstack. + s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + + // Enable Receive Buffer Auto-Tuning. + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) + } + + s.FillDefaultIPTables() + + return &s, nil +} + +// sandboxNetstackCreator implements kernel.NetworkStackCreator. +// +// +stateify savable +type sandboxNetstackCreator struct { + clock tcpip.Clock + uniqueID stack.UniqueID +} + +// CreateStack implements kernel.NetworkStackCreator.CreateStack. +func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) { + s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID) + if err != nil { + return nil, err } + + // Setup loopback. + n := &Network{Stack: s.(*netstack.Stack).Stack} + nicID := tcpip.NICID(f.uniqueID.UniqueID()) + link := DefaultLoopbackLink + linkEP := loopback.New() + if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { + return nil, err + } + + return s, nil } // signal sends a signal to one or more processes in a container. If PID is 0, diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 6a8765ec8..bee6ee336 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -17,6 +17,7 @@ package boot import ( "fmt" "net" + "strings" "syscall" "gvisor.dev/gvisor/pkg/log" @@ -31,6 +32,32 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) +var ( + // DefaultLoopbackLink contains IP addresses and routes of "127.0.0.1/8" and + // "::1/8" on "lo" interface. + DefaultLoopbackLink = LoopbackLink{ + Name: "lo", + Addresses: []net.IP{ + net.IP("\x7f\x00\x00\x01"), + net.IPv6loopback, + }, + Routes: []Route{ + { + Destination: net.IPNet{ + IP: net.IPv4(0x7f, 0, 0, 0), + Mask: net.IPv4Mask(0xff, 0, 0, 0), + }, + }, + { + Destination: net.IPNet{ + IP: net.IPv6loopback, + Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)), + }, + }, + }, + } +) + // Network exposes methods that can be used to configure a network stack. type Network struct { Stack *stack.Stack diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD new file mode 100644 index 000000000..29cb42b2f --- /dev/null +++ b/runsc/boot/pprof/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pprof", + srcs = ["pprof.go"], + visibility = [ + "//runsc:__subpackages__", + ], +) diff --git a/runsc/boot/pprof.go b/runsc/boot/pprof/pprof.go index 463362f02..1ded20dee 100644 --- a/runsc/boot/pprof.go +++ b/runsc/boot/pprof/pprof.go @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package boot +// Package pprof provides a stub to initialize custom profilers. +package pprof -func initializePProf() { +// Initialize will be called at boot for initializing custom profilers. +func Initialize() { } diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 2a88b85a9..d0bb4613a 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -31,6 +31,7 @@ go_library( "spec.go", "start.go", "state.go", + "statefile.go", "syscalls.go", "wait.go", ], @@ -43,6 +44,8 @@ go_library( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/state", + "//pkg/state/statefile", "//pkg/sync", "//pkg/unet", "//pkg/urpc", diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go new file mode 100644 index 000000000..e6f1907da --- /dev/null +++ b/runsc/cmd/statefile.go @@ -0,0 +1,143 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "context" + "fmt" + "os" + + "github.com/google/subcommands" + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/statefile" + "gvisor.dev/gvisor/runsc/flag" +) + +// Statefile implements subcommands.Command for the "statefile" command. +type Statefile struct { + list bool + get string + key string + output string + html bool +} + +// Name implements subcommands.Command. +func (*Statefile) Name() string { + return "state" +} + +// Synopsis implements subcommands.Command. +func (*Statefile) Synopsis() string { + return "shows information about a statefile" +} + +// Usage implements subcommands.Command. +func (*Statefile) Usage() string { + return `statefile [flags] <statefile>` +} + +// SetFlags implements subcommands.Command. +func (s *Statefile) SetFlags(f *flag.FlagSet) { + f.BoolVar(&s.list, "list", false, "lists the metdata in the statefile.") + f.StringVar(&s.get, "get", "", "extracts the given metadata key.") + f.StringVar(&s.key, "key", "", "the integrity key for the file.") + f.StringVar(&s.output, "output", "", "target to write the result.") + f.BoolVar(&s.html, "html", false, "outputs in HTML format.") +} + +// Execute implements subcommands.Command.Execute. +func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + // Check arguments. + if s.list && s.get != "" { + Fatalf("error: can't specify -list and -get simultaneously.") + } + + // Setup output. + var output = os.Stdout // Default. + if s.output != "" { + f, err := os.OpenFile(s.output, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644) + if err != nil { + Fatalf("error opening output: %v", err) + } + defer func() { + if err := f.Close(); err != nil { + Fatalf("error flushing output: %v", err) + } + }() + output = f + } + + // Open the file. + if f.NArg() != 1 { + f.Usage() + return subcommands.ExitUsageError + } + input, err := os.Open(f.Arg(0)) + if err != nil { + Fatalf("error opening input: %v\n", err) + } + + if s.html { + fmt.Fprintf(output, "<html><body>\n") + defer fmt.Fprintf(output, "</body></html>\n") + } + + // Dump the full file? + if !s.list && s.get == "" { + var key []byte + if s.key != "" { + key = []byte(s.key) + } + rc, _, err := statefile.NewReader(input, key) + if err != nil { + Fatalf("error parsing statefile: %v", err) + } + if err := state.PrettyPrint(output, rc, s.html); err != nil { + Fatalf("error printing state: %v", err) + } + return subcommands.ExitSuccess + } + + // Load just the metadata. + metadata, err := statefile.MetadataUnsafe(input) + if err != nil { + Fatalf("error reading metadata: %v", err) + } + + // Is it a single key? + if s.get != "" { + val, ok := metadata[s.get] + if !ok { + Fatalf("metadata key %s: not found", s.get) + } + fmt.Fprintf(output, "%s\n", val) + return subcommands.ExitSuccess + } + + // List all keys. + if s.html { + fmt.Fprintf(output, " <ul>\n") + defer fmt.Fprintf(output, " </ul>\n") + } + for key := range metadata { + if s.html { + fmt.Fprintf(output, " <li>%s</li>\n", key) + } else { + fmt.Fprintf(output, "%s\n", key) + } + } + return subcommands.ExitSuccess +} diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 04a7dc237..bdd65b498 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -71,6 +71,7 @@ func waitForProcessCount(cont *Container, want int) error { return &backoff.PermanentError{Err: err} } if got := len(pss); got != want { + log.Infof("Waiting for process count to reach %d. Current: %d", want, got) return fmt.Errorf("wrong process count, got: %d, want: %d", got, want) } return nil diff --git a/runsc/main.go b/runsc/main.go index 762b0f801..af73bed97 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -116,8 +116,8 @@ func main() { subcommands.Register(new(cmd.Resume), "") subcommands.Register(new(cmd.Run), "") subcommands.Register(new(cmd.Spec), "") - subcommands.Register(new(cmd.Start), "") subcommands.Register(new(cmd.State), "") + subcommands.Register(new(cmd.Start), "") subcommands.Register(new(cmd.Wait), "") // Register internal commands with the internal group name. This causes @@ -127,6 +127,7 @@ func main() { subcommands.Register(new(cmd.Boot), internalGroup) subcommands.Register(new(cmd.Debug), internalGroup) subcommands.Register(new(cmd.Gofer), internalGroup) + subcommands.Register(new(cmd.Statefile), internalGroup) // All subcommands must be registered before flag parsing. flag.Parse() diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 99e143696..bc093fba5 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -21,7 +21,6 @@ import ( "path/filepath" "runtime" "strconv" - "strings" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -75,30 +74,8 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi } func createDefaultLoopbackInterface(conn *urpc.Client) error { - link := boot.LoopbackLink{ - Name: "lo", - Addresses: []net.IP{ - net.IP("\x7f\x00\x00\x01"), - net.IPv6loopback, - }, - Routes: []boot.Route{ - { - Destination: net.IPNet{ - - IP: net.IPv4(0x7f, 0, 0, 0), - Mask: net.IPv4Mask(0xff, 0, 0, 0), - }, - }, - { - Destination: net.IPNet{ - IP: net.IPv6loopback, - Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)), - }, - }, - }, - } if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{ - LoopbackLinks: []boot.LoopbackLink{link}, + LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink}, }, nil); err != nil { return fmt.Errorf("creating loopback link and routes: %v", err) } diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh new file mode 100644 index 000000000..a0317db02 --- /dev/null +++ b/scripts/benchmark.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Run in the root of the repo. +cd "$(dirname "$0")" + +KEY_PATH=${KEY_PATH:-"${KOKORO_KEYSTORE_DIR}/${KOKORO_SERVICE_ACCOUNT}"} + +gcloud auth activate-service-account --key-file "${KEY_PATH}" + +gcloud compute instances list + diff --git a/scripts/common_build.sh b/scripts/common_build.sh index ae8b67383..3be0bb21c 100755 --- a/scripts/common_build.sh +++ b/scripts/common_build.sh @@ -70,7 +70,9 @@ function collect_logs() { for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do junitparser merge `find $d -name test.xml` $d/test.xml cat $d/shard_*_of_*/test.log > $d/test.log - ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip + if ls -l $d/shard_*_of_*/test.outputs/outputs.zip 2>/dev/null; then + zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs/outputs.zip + fi done find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf # Move test logs to Kokoro directory. tar is used to conveniently perform @@ -90,7 +92,13 @@ function collect_logs() { echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" fi - tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . + time tar \ + --verbose \ + --create \ + --gzip \ + --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \ + --directory "${RUNSC_LOGS_DIR}" \ + . fi fi fi diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index e26d6a7d2..b2fb6401a 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -26,6 +26,7 @@ const ( acceptPort = 2402 sendloopDuration = 2 * time.Second network = "udp4" + chainName = "foochain" ) func init() { @@ -40,6 +41,12 @@ func init() { RegisterTestCase(FilterInputDefaultPolicyAccept{}) RegisterTestCase(FilterInputDefaultPolicyDrop{}) RegisterTestCase(FilterInputReturnUnderflow{}) + RegisterTestCase(FilterInputSerializeJump{}) + RegisterTestCase(FilterInputJumpBasic{}) + RegisterTestCase(FilterInputJumpReturn{}) + RegisterTestCase(FilterInputJumpReturnDrop{}) + RegisterTestCase(FilterInputJumpBuiltin{}) + RegisterTestCase(FilterInputJumpTwice{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -267,13 +274,12 @@ func (FilterInputMultiUDPRules) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"); err != nil { - return err + rules := [][]string{ + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, + {"-L"}, } - return filterTable("-L") + return filterTableRules(rules) } // LocalAction implements TestCase.LocalAction. @@ -314,14 +320,13 @@ func (FilterInputCreateUserChain) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { - // Create a chain. - const chainName = "foochain" - if err := filterTable("-N", chainName); err != nil { - return err + rules := [][]string{ + // Create a chain. + {"-N", chainName}, + // Add a simple rule to the chain. + {"-A", chainName, "-j", "DROP"}, } - - // Add a simple rule to the chain. - return filterTable("-A", chainName, "-j", "DROP") + return filterTableRules(rules) } // LocalAction implements TestCase.LocalAction. @@ -396,13 +401,12 @@ func (FilterInputReturnUnderflow) Name() string { func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { // Add a RETURN rule followed by an unconditional accept, and set the // default policy to DROP. - if err := filterTable("-A", "INPUT", "-j", "RETURN"); err != nil { - return err + rules := [][]string{ + {"-A", "INPUT", "-j", "RETURN"}, + {"-A", "INPUT", "-j", "DROP"}, + {"-P", "INPUT", "ACCEPT"}, } - if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil { - return err - } - if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + if err := filterTableRules(rules); err != nil { return err } @@ -415,3 +419,179 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } + +// FilterInputSerializeJump verifies that we can serialize jumps. +type FilterInputSerializeJump struct{} + +// Name implements TestCase.Name. +func (FilterInputSerializeJump) Name() string { + return "FilterInputSerializeJump" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSerializeJump) ContainerAction(ip net.IP) error { + // Write a JUMP rule, the serialize it with `-L`. + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-L"}, + } + return filterTableRules(rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSerializeJump) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// FilterInputJumpBasic jumps to a chain and executes a rule there. +type FilterInputJumpBasic struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpBasic) Name() string { + return "FilterInputJumpBasic" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBasic) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputJumpReturn jumps, returns, and executes a rule. +type FilterInputJumpReturn struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpReturn) Name() string { + return "FilterInputJumpReturn" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-N", chainName}, + {"-P", "INPUT", "ACCEPT"}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "RETURN"}, + {"-A", chainName, "-j", "DROP"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturn) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. +type FilterInputJumpReturnDrop struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpReturnDrop) Name() string { + return "FilterInputJumpReturnDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", "INPUT", "-j", "DROP"}, + {"-A", chainName, "-j", "RETURN"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. +type FilterInputJumpBuiltin struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpBuiltin) Name() string { + return "FilterInputJumpBuiltin" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil { + return fmt.Errorf("iptables should be unable to jump to a built-in chain") + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// FilterInputJumpTwice jumps twice, then returns twice and executes a rule. +type FilterInputJumpTwice struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpTwice) Name() string { + return "FilterInputJumpTwice" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { + const chainName2 = chainName + "2" + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-N", chainName2}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", chainName2}, + {"-A", "INPUT", "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // UDP packets should jump and return twice, eventually hitting the + // ACCEPT rule. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpTwice) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 7d061acba..29ad5932d 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -261,3 +261,39 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) { t.Fatal(err) } } + +func TestJumpSerialize(t *testing.T) { + if err := singleTest(FilterInputSerializeJump{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpBasic(t *testing.T) { + if err := singleTest(FilterInputJumpBasic{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpReturn(t *testing.T) { + if err := singleTest(FilterInputJumpReturn{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpReturnDrop(t *testing.T) { + if err := singleTest(FilterInputJumpReturnDrop{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpBuiltin(t *testing.T) { + if err := singleTest(FilterInputJumpBuiltin{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpTwice(t *testing.T) { + if err := singleTest(FilterInputJumpTwice{}); err != nil { + t.Fatal(err) + } +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 5c9199abf..32cf5a417 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -27,17 +27,16 @@ const iptablesBinary = "iptables" // filterTable calls `iptables -t filter` with the given args. func filterTable(args ...string) error { - args = append([]string{"-t", "filter"}, args...) - cmd := exec.Command(iptablesBinary, args...) - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) - } - return nil + return tableCmd("filter", args) } // natTable calls `iptables -t nat` with the given args. func natTable(args ...string) error { - args = append([]string{"-t", "nat"}, args...) + return tableCmd("nat", args) +} + +func tableCmd(table string, args []string) error { + args = append([]string{"-t", table}, args...) cmd := exec.Command(iptablesBinary, args...) if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) @@ -45,6 +44,16 @@ func natTable(args ...string) error { return nil } +// filterTableRules is like filterTable, but runs multiple iptables commands. +func filterTableRules(argsList [][]string) error { + for _, args := range argsList { + if err := filterTable(args...); err != nil { + return err + } + } + return nil +} + // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. func listenUDP(port int, timeout time.Duration) error { diff --git a/test/perf/BUILD b/test/perf/BUILD new file mode 100644 index 000000000..346a28e16 --- /dev/null +++ b/test/perf/BUILD @@ -0,0 +1,114 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + test = "//test/perf/linux:clock_getres_benchmark", +) + +syscall_test( + test = "//test/perf/linux:clock_gettime_benchmark", +) + +syscall_test( + test = "//test/perf/linux:death_benchmark", +) + +syscall_test( + test = "//test/perf/linux:epoll_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:fork_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:futex_benchmark", +) + +syscall_test( + size = "enormous", + test = "//test/perf/linux:getdents_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:getpid_benchmark", +) + +syscall_test( + size = "enormous", + test = "//test/perf/linux:gettid_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:mapping_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:open_benchmark", +) + +syscall_test( + test = "//test/perf/linux:pipe_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:randread_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:read_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:sched_yield_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:send_recv_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:seqwrite_benchmark", +) + +syscall_test( + size = "enormous", + test = "//test/perf/linux:signal_benchmark", +) + +syscall_test( + test = "//test/perf/linux:sleep_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:stat_benchmark", +) + +syscall_test( + size = "enormous", + add_overlay = True, + test = "//test/perf/linux:unlink_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:write_benchmark", +) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD new file mode 100644 index 000000000..b4e907826 --- /dev/null +++ b/test/perf/linux/BUILD @@ -0,0 +1,356 @@ +load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "getpid_benchmark", + testonly = 1, + srcs = [ + "getpid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "send_recv_benchmark", + testonly = 1, + srcs = [ + "send_recv_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "gettid_benchmark", + testonly = 1, + srcs = [ + "gettid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "sched_yield_benchmark", + testonly = 1, + srcs = [ + "sched_yield_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "clock_getres_benchmark", + testonly = 1, + srcs = [ + "clock_getres_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "clock_gettime_benchmark", + testonly = 1, + srcs = [ + "clock_gettime_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "open_benchmark", + testonly = 1, + srcs = [ + "open_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) + +cc_binary( + name = "read_benchmark", + testonly = 1, + srcs = [ + "read_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "randread_benchmark", + testonly = 1, + srcs = [ + "randread_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "write_benchmark", + testonly = 1, + srcs = [ + "write_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "seqwrite_benchmark", + testonly = 1, + srcs = [ + "seqwrite_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "pipe_benchmark", + testonly = 1, + srcs = [ + "pipe_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( + name = "fork_benchmark", + testonly = 1, + srcs = [ + "fork_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "futex_benchmark", + testonly = 1, + srcs = [ + "futex_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "epoll_benchmark", + testonly = 1, + srcs = [ + "epoll_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:epoll_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "death_benchmark", + testonly = 1, + srcs = [ + "death_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "mapping_benchmark", + testonly = 1, + srcs = [ + "mapping_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:memory_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "signal_benchmark", + testonly = 1, + srcs = [ + "signal_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "getdents_benchmark", + testonly = 1, + srcs = [ + "getdents_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "sleep_benchmark", + testonly = 1, + srcs = [ + "sleep_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "stat_benchmark", + testonly = 1, + srcs = [ + "stat_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "unlink_benchmark", + testonly = 1, + srcs = [ + "unlink_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc new file mode 100644 index 000000000..b051293ad --- /dev/null +++ b/test/perf/linux/clock_getres_benchmark.cc @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <time.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +// clock_getres(1) is very nearly a no-op syscall, but it does require copying +// out to a userspace struct. It thus provides a nice small copy-out benchmark. +void BM_ClockGetRes(benchmark::State& state) { + struct timespec ts; + for (auto _ : state) { + clock_getres(CLOCK_MONOTONIC, &ts); + } +} + +BENCHMARK(BM_ClockGetRes); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc new file mode 100644 index 000000000..6691bebd9 --- /dev/null +++ b/test/perf/linux/clock_gettime_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <pthread.h> +#include <time.h> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_ClockGettimeThreadCPUTime(benchmark::State& state) { + clockid_t clockid; + ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid)); + struct timespec tp; + + for (auto _ : state) { + clock_gettime(clockid, &tp); + } +} + +BENCHMARK(BM_ClockGettimeThreadCPUTime); + +void BM_VDSOClockGettime(benchmark::State& state) { + const clockid_t clock = state.range(0); + struct timespec tp; + absl::Time start = absl::Now(); + + // Don't benchmark the calibration phase. + while (absl::Now() < start + absl::Milliseconds(2100)) { + clock_gettime(clock, &tp); + } + + for (auto _ : state) { + clock_gettime(clock, &tp); + } +} + +BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc new file mode 100644 index 000000000..cb2b6fd07 --- /dev/null +++ b/test/perf/linux/death_benchmark.cc @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing +// the ability of gVisor (on whatever platform) to execute all the related +// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH. +TEST(DeathTest, ZeroEqualsOne) { + EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), ""); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc new file mode 100644 index 000000000..0b121338a --- /dev/null +++ b/test/perf/linux/epoll_benchmark.cc @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/epoll.h> +#include <sys/eventfd.h> + +#include <atomic> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <ctime> +#include <memory> + +#include "gtest/gtest.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "test/util/epoll_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Returns a new eventfd. +PosixErrorOr<FileDescriptor> NewEventFD() { + int fd = eventfd(0, /* flags = */ 0); + MaybeSave(); + if (fd < 0) { + return PosixError(errno, "eventfd"); + } + return FileDescriptor(fd); +} + +// Also stolen from epoll.cc unit tests. +void BM_EpollTimeout(benchmark::State& state) { + constexpr int kFDsPerEpoll = 3; + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < kFDsPerEpoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO( + RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); + } + + struct epoll_event result[kFDsPerEpoll]; + int timeout_ms = state.range(0); + + for (auto _ : state) { + EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms)); + } +} + +BENCHMARK(BM_EpollTimeout)->Range(0, 8); + +// Also stolen from epoll.cc unit tests. +void BM_EpollAllEvents(benchmark::State& state) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + const int fds_per_epoll = state.range(0); + constexpr uint64_t kEventVal = 5; + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < fds_per_epoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO( + RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); + + ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)), + SyscallSucceedsWithValue(sizeof(kEventVal))); + } + + std::vector<struct epoll_event> result(fds_per_epoll); + + for (auto _ : state) { + EXPECT_EQ(fds_per_epoll, + epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0)); + } +} + +BENCHMARK(BM_EpollAllEvents)->Range(2, 1024); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc new file mode 100644 index 000000000..84fdbc8a0 --- /dev/null +++ b/test/perf/linux/fork_benchmark.cc @@ -0,0 +1,350 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/synchronization/barrier.h" +#include "benchmark/benchmark.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBusyMax = 250; + +// Do some CPU-bound busy-work. +int busy(int max) { + // Prevent the compiler from optimizing this work away, + volatile int count = 0; + + for (int i = 1; i < max; i++) { + for (int j = 2; j < i / 2; j++) { + if (i % j == 0) { + count++; + } + } + } + + return count; +} + +void BM_CPUBoundUniprocess(benchmark::State& state) { + for (auto _ : state) { + busy(kBusyMax); + } +} + +BENCHMARK(BM_CPUBoundUniprocess); + +void BM_CPUBoundAsymmetric(benchmark::State& state) { + const size_t max = state.max_iterations; + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < max; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + ASSERT_TRUE(state.KeepRunningBatch(max)); + + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + ASSERT_FALSE(state.KeepRunning()); +} + +BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime(); + +void BM_CPUBoundSymmetric(benchmark::State& state) { + std::vector<pid_t> children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + ASSERT_FALSE(state.KeepRunning()); + }); + + const int processes = state.range(0); + for (int i = 0; i < processes; i++) { + size_t cur = (state.max_iterations + (processes - 1)) / processes; + if ((state.iterations() + cur) >= state.max_iterations) { + cur = state.max_iterations - state.iterations(); + } + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < cur; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + if (cur > 0) { + // We can have a zero cur here, depending. + ASSERT_TRUE(state.KeepRunningBatch(cur)); + } + children.push_back(child); + } +} + +BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime(); + +// Child routine for ProcessSwitch/ThreadSwitch. +// Reads from readfd and writes the result to writefd. +void SwitchChild(int readfd, int writefd) { + while (1) { + char buf; + int ret = ReadFd(readfd, &buf, 1); + if (ret == 0) { + break; + } + TEST_CHECK_MSG(ret == 1, "read failed"); + + ret = WriteFd(writefd, &buf, 1); + if (ret == -1) { + TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure"); + break; + } + TEST_CHECK_MSG(ret == 1, "write failed"); + } +} + +// Send bytes in a loop through a series of pipes, each passing through a +// different process. +// +// Proc 0 Proc 1 +// * ----------> * +// ^ Pipe 1 | +// | | +// | Pipe 0 | Pipe 2 +// | | +// | | +// | Pipe 3 v +// * <---------- * +// Proc 3 Proc 2 +// +// This exercises context switching through multiple processes. +void BM_ProcessSwitch(benchmark::State& state) { + // Code below assumes there are at least two processes. + const int num_processes = state.range(0); + ASSERT_GE(num_processes, 2); + + std::vector<pid_t> children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + }); + + // Must come after children, as the FDs must be closed before the children + // will exit. + std::vector<FileDescriptor> read_fds; + std::vector<FileDescriptor> write_fds; + + for (int i = 0; i < num_processes; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This process is one of the processes in the loop. It will be considered + // index 0. + for (int i = 1; i < num_processes; i++) { + // Read from current pipe index, write to next. + const int read_index = i; + const int read_fd = read_fds[read_index].get(); + + const int write_index = (i + 1) % num_processes; + const int write_fd = write_fds[write_index].get(); + + // std::vector isn't safe to use from the fork child. + FileDescriptor* read_array = read_fds.data(); + FileDescriptor* write_array = write_fds.data(); + + pid_t child = fork(); + if (!child) { + // Close all other FDs. + for (int j = 0; j < num_processes; j++) { + if (j != read_index) { + read_array[j].reset(); + } + if (j != write_index) { + write_array[j].reset(); + } + } + + SwitchChild(read_fd, write_fd); + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + children.push_back(child); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } +} + +BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime(); + +// Equivalent to BM_ThreadSwitch using threads instead of processes. +void BM_ThreadSwitch(benchmark::State& state) { + // Code below assumes there are at least two threads. + const int num_threads = state.range(0); + ASSERT_GE(num_threads, 2); + + // Must come after threads, as the FDs must be closed before the children + // will exit. + std::vector<std::unique_ptr<ScopedThread>> threads; + std::vector<FileDescriptor> read_fds; + std::vector<FileDescriptor> write_fds; + + for (int i = 0; i < num_threads; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This thread is one of the threads in the loop. It will be considered + // index 0. + for (int i = 1; i < num_threads; i++) { + // Read from current pipe index, write to next. + // + // Transfer ownership of the FDs to the thread. + const int read_index = i; + const int read_fd = read_fds[read_index].release(); + + const int write_index = (i + 1) % num_threads; + const int write_fd = write_fds[write_index].release(); + + threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] { + FileDescriptor read(read_fd); + FileDescriptor write(write_fd); + SwitchChild(read.get(), write.get()); + })); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } + + // The two FDs still owned by this thread are closed, causing the next thread + // to exit its loop and close its FDs, and so on until all threads exit. +} + +BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime(); + +void BM_ThreadStart(benchmark::State& state) { + const int num_threads = state.range(0); + + for (auto _ : state) { + state.PauseTiming(); + + auto barrier = new absl::Barrier(num_threads + 1); + std::vector<std::unique_ptr<ScopedThread>> threads; + + state.ResumeTiming(); + + for (size_t i = 0; i < num_threads; ++i) { + threads.emplace_back(std::make_unique<ScopedThread>([barrier] { + if (barrier->Block()) { + delete barrier; + } + })); + } + + if (barrier->Block()) { + delete barrier; + } + + state.PauseTiming(); + + for (const auto& thread : threads) { + thread->Join(); + } + + state.ResumeTiming(); + } +} + +BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime(); + +// Benchmark the complete fork + exit + wait. +void BM_ProcessLifecycle(benchmark::State& state) { + const int num_procs = state.range(0); + + std::vector<pid_t> pids(num_procs); + for (auto _ : state) { + for (size_t i = 0; i < num_procs; ++i) { + int pid = fork(); + if (pid == 0) { + _exit(0); + } + ASSERT_THAT(pid, SyscallSucceeds()); + pids[i] = pid; + } + + for (const int pid : pids) { + ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0), + SyscallSucceedsWithValue(pid)); + } + } +} + +BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc new file mode 100644 index 000000000..b349d50bf --- /dev/null +++ b/test/perf/linux/futex_benchmark.cc @@ -0,0 +1,248 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/futex.h> + +#include <atomic> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <ctime> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +inline int FutexWait(std::atomic<int32_t>* v, int32_t val) { + return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, nullptr); +} + +inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val, + const struct timespec* reltime) { + return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, reltime); +} + +inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val, + const struct timespec* abstime) { + return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, abstime); +} + +inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val, + int32_t bits, + const struct timespec* abstime) { + return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME, + val, abstime, nullptr, bits); +} + +inline int FutexWake(std::atomic<int32_t>* v, int32_t count) { + return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count); +} + +// This just uses FUTEX_WAKE on an address with nothing waiting, very simple. +void BM_FutexWakeNop(benchmark::State& state) { + std::atomic<int32_t> v(0); + + for (auto _ : state) { + EXPECT_EQ(0, FutexWake(&v, 1)); + } +} + +BENCHMARK(BM_FutexWakeNop); + +// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the +// syscall won't wait. +void BM_FutexWaitNop(benchmark::State& state) { + std::atomic<int32_t> v(0); + + for (auto _ : state) { + EXPECT_EQ(-EAGAIN, FutexWait(&v, 1)); + } +} + +BENCHMARK(BM_FutexWaitNop); + +// This uses FUTEX_WAIT with a timeout on an address whose value never +// changes, such that it always times out. Timeout overhead can be estimated by +// timer overruns for short timeouts. +void BM_FutexWaitTimeout(benchmark::State& state) { + const int timeout_ns = state.range(0); + std::atomic<int32_t> v(0); + auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns)); + + for (auto _ : state) { + EXPECT_EQ(-ETIMEDOUT, FutexWaitRelativeTimeout(&v, 0, &ts)); + } +} + +BENCHMARK(BM_FutexWaitTimeout) + ->Arg(1) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(10000); + +// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME. +void BM_FutexWaitBitset(benchmark::State& state) { + std::atomic<int32_t> v(0); + int timeout_ns = state.range(0); + auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns)); + for (auto _ : state) { + EXPECT_EQ(-ETIMEDOUT, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts)); + } +} + +BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000); + +int64_t GetCurrentMonotonicTimeNanos() { + struct timespec ts; + TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1); + return ts.tv_sec * 1000000000ULL + ts.tv_nsec; +} + +void SpinNanos(int64_t delay_ns) { + if (delay_ns <= 0) { + return; + } + const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns; + while (GetCurrentMonotonicTimeNanos() < end) { + // spin + } +} + +// Each iteration of FutexRoundtripDelayed involves a thread sending a futex +// wakeup to another thread, which spins for delay_us and then sends a futex +// wakeup back. The time per iteration is 2* (delay_us + kBeforeWakeDelayNs + +// futex/scheduling overhead). +void BM_FutexRoundtripDelayed(benchmark::State& state) { + const int delay_us = state.range(0); + + const int64_t delay_ns = delay_us * 1000; + // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce + // the probability that the wakeup comes before the wait, preventing the wait + // from ever taking effect and causing the benchmark to underestimate the + // actual wakeup time. + constexpr int64_t kBeforeWakeDelayNs = 500; + std::atomic<int32_t> v(0); + ScopedThread t([&] { + for (int i = 0; i < state.max_iterations; i++) { + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 0) { + FutexWait(&v, 0); + } + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(0, std::memory_order_release); + FutexWake(&v, 1); + } + }); + for (auto _ : state) { + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(1, std::memory_order_release); + FutexWake(&v, 1); + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 1) { + FutexWait(&v, 1); + } + } +} + +BENCHMARK(BM_FutexRoundtripDelayed) + ->Arg(0) + ->Arg(10) + ->Arg(20) + ->Arg(50) + ->Arg(100); + +// FutexLock is a simple, dumb futex based lock implementation. +// It will try to acquire the lock by atomically incrementing the +// lock word. If it did not increment the lock from 0 to 1, someone +// else has the lock, so it will FUTEX_WAIT until it is woken in +// the unlock path. +class FutexLock { + public: + FutexLock() : lock_word_(0) {} + + void lock(struct timespec* deadline) { + int32_t val; + while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) != + 1) { + // If we didn't get the lock by incrementing from 0 to 1, + // do a FUTEX_WAIT with the desired current value set to + // val. If val is no longer what the atomic increment returned, + // someone might have set it to 0 so we can try to acquire + // again. + int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline); + if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) { + continue; + } else { + FAIL() << "unexpected FUTEX_WAIT return: " << ret; + } + } + } + + void unlock() { + // Store 0 into the lock word and wake one waiter. We intentionally + // ignore the return value of the FUTEX_WAKE here, since there may be + // no waiters to wake anyway. + lock_word_.store(0, std::memory_order_release); + (void)FutexWake(&lock_word_, 1); + } + + private: + std::atomic<int32_t> lock_word_; +}; + +FutexLock* test_lock; // Used below. + +void FutexContend(benchmark::State& state, int thread_index, + struct timespec* deadline) { + int counter = 0; + if (thread_index == 0) { + test_lock = new FutexLock(); + } + for (auto _ : state) { + test_lock->lock(deadline); + counter++; + test_lock->unlock(); + } + if (thread_index == 0) { + delete test_lock; + } + state.SetItemsProcessed(state.iterations()); +} + +void BM_FutexContend(benchmark::State& state) { + FutexContend(state, state.thread_index, nullptr); +} + +BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime(); + +void BM_FutexDeadlineContend(benchmark::State& state) { + auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10)); + FutexContend(state, state.thread_index, &deadline); +} + +BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc new file mode 100644 index 000000000..afc599ad2 --- /dev/null +++ b/test/perf/linux/getdents_benchmark.cc @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +#ifndef SYS_getdents64 +#if defined(__x86_64__) +#define SYS_getdents64 217 +#elif defined(__aarch64__) +#define SYS_getdents64 217 +#else +#error "Unknown architecture" +#endif +#endif // SYS_getdents64 + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBufferSize = 16384; + +PosixErrorOr<TempPath> CreateDirectory(int count, + std::vector<std::string>* files) { + ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir()); + + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (int i = 0; i < count; i++) { + auto file = NewTempRelPath(); + auto res = MknodAt(dfd, file, S_IFREG | 0644, 0); + RETURN_IF_ERRNO(res); + files->push_back(file); + } + + return std::move(dir); +} + +PosixError CleanupDirectory(const TempPath& dir, + std::vector<std::string>* files) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (auto it = files->begin(); it != files->end(); ++it) { + auto res = UnlinkAt(dfd, *it, 0); + RETURN_IF_ERRNO(res); + } + return NoError(); +} + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a single FD. +void BM_GetdentsSameFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector<std::string> files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a new FD each time. +void BM_GetdentsNewFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector<std::string> files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc new file mode 100644 index 000000000..db74cb264 --- /dev/null +++ b/test/perf/linux/getpid_benchmark.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/syscall.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Getpid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_getpid); + } +} + +BENCHMARK(BM_Getpid); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc new file mode 100644 index 000000000..8f4961f5e --- /dev/null +++ b/test/perf/linux/gettid_benchmark.cc @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/syscall.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Gettid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_gettid); + } +} + +BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc new file mode 100644 index 000000000..39c30fe69 --- /dev/null +++ b/test/perf/linux/mapping_benchmark.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdlib.h> +#include <sys/mman.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Conservative value for /proc/sys/vm/max_map_count, which limits the number of +// VMAs, minus a safety margin for VMAs that already exist for the test binary. +// The default value for max_map_count is +// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530. +constexpr size_t kMaxVMAs = 64001; + +// Map then unmap pages without touching them. +void BM_MapUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map, touch, then unmap pages. +void BM_MapTouchUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast<char*>(addr); + char* end = c + pages * kPageSize; + while (c < end) { + *c = 42; + c += kPageSize; + } + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map and touch many pages, unmapping all at once. +// +// NOTE(b/111429208): This is a regression test to ensure performant mapping and +// allocation even with tons of mappings. +void BM_MapTouchMany(benchmark::State& state) { + // Number of pages to map. + const int page_count = state.range(0); + + while (state.KeepRunning()) { + std::vector<void*> pages; + + for (int i = 0; i < page_count; i++) { + void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast<char*>(addr); + *c = 42; + + pages.push_back(addr); + } + + for (void* addr : pages) { + int ret = munmap(addr, kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } + } + + state.SetBytesProcessed(kPageSize * page_count * state.iterations()); +} + +BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime(); + +void BM_PageFault(benchmark::State& state) { + // Map the region in which we will take page faults. To ensure that each page + // fault maps only a single page, each page we touch must correspond to a + // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However, + // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that + // if there are background threads running, they can't inadvertently creating + // mappings in our gaps that are unmapped when the test ends. + size_t test_pages = kMaxVMAs; + // Ensure that test_pages is odd, since we want the test region to both + // begin and end with a mapped page. + if (test_pages % 2 == 0) { + test_pages--; + } + const size_t test_region_bytes = test_pages * kPageSize; + // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on + // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE + // so that Linux pre-allocates the shmem file used to back the mapping. + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE)); + for (size_t i = 0; i < test_pages / 2; i++) { + ASSERT_THAT( + mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)), + kPageSize, PROT_NONE), + SyscallSucceeds()); + } + + const size_t mapped_pages = test_pages / 2 + 1; + // "Start" at the end of the mapped region to force the mapped region to be + // reset, since we mapped it with MAP_POPULATE. + size_t cur_page = mapped_pages; + for (auto _ : state) { + if (cur_page >= mapped_pages) { + // We've reached the end of our mapped region and have to reset it to + // incur page faults again. + state.PauseTiming(); + ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED), + SyscallSucceeds()); + cur_page = 0; + state.ResumeTiming(); + } + const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize); + const char c = *reinterpret_cast<volatile char*>(addr); + benchmark::DoNotOptimize(c); + cur_page++; + } +} + +BENCHMARK(BM_PageFault)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc new file mode 100644 index 000000000..68008f6d5 --- /dev/null +++ b/test/perf/linux/open_benchmark.cc @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Open(benchmark::State& state) { + const int size = state.range(0); + std::vector<TempPath> cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + cache.emplace_back(std::move(path)); + } + + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + close(fd); + } +} + +BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc new file mode 100644 index 000000000..8f5f6a2a3 --- /dev/null +++ b/test/perf/linux/pipe_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <cerrno> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Pipe(benchmark::State& state) { + int fds[2]; + TEST_CHECK(pipe(fds) == 0); + + const int size = state.range(0); + std::vector<char> wbuf(size); + std::vector<char> rbuf(size); + RandomizeBuffer(wbuf.data(), size); + + ScopedThread t([&] { + auto const fd = fds[1]; + for (int i = 0; i < state.max_iterations; i++) { + TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size); + } + }); + + for (auto _ : state) { + TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size); + } + + t.Join(); + + close(fds[0]); + close(fds[1]); + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc new file mode 100644 index 000000000..b0eb8c24e --- /dev/null +++ b/test/perf/linux/randread_benchmark.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/uio.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Create a 1GB file that will be read from at random positions. This should +// invalid any performance gains from caching. +const uint64_t kFileSize = 1ULL << 30; + +// How many bytes to write at once to initialize the file used to read from. +const uint32_t kWriteSize = 65536; + +// Largest benchmarked read unit. +const uint32_t kMaxRead = 1UL << 26; + +TempPath CreateFile(uint64_t file_size) { + auto path = TempPath::CreateFile().ValueOrDie(); + FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie(); + + // Try to minimize syscalls by using maximum size writev() requests. + std::vector<char> buffer(kWriteSize); + RandomizeBuffer(buffer.data(), buffer.size()); + const std::vector<std::vector<struct iovec>> iovecs_list = + GenerateIovecs(file_size, buffer.data(), buffer.size()); + for (const auto& iovecs : iovecs_list) { + TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0); + } + + return path; +} + +// Global test state, initialized once per process lifetime. +struct GlobalState { + const TempPath tmpfile; + explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {} +}; + +GlobalState& GetGlobalState() { + // This gets created only once throughout the lifetime of the process. + // Use a dynamically allocated object (that is never deleted) to avoid order + // of destruction of static storage variables issues. + static GlobalState* const state = + // The actual file size is the maximum random seek range (kFileSize) + the + // maximum read size so we can read that number of bytes at the end of the + // file. + new GlobalState(CreateFile(kFileSize + kMaxRead)); + return *state; +} + +void BM_RandRead(benchmark::State& state) { + const int size = state.range(0); + + GlobalState& global_state = GetGlobalState(); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY)); + std::vector<char> buf(size); + + unsigned int seed = 1; + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), + rand_r(&seed) % kFileSize) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc new file mode 100644 index 000000000..62445867d --- /dev/null +++ b/test/perf/linux/read_benchmark.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Read(benchmark::State& state) { + const int size = state.range(0); + const std::string contents(size, 0); + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY)); + + std::vector<char> buf(size); + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc new file mode 100644 index 000000000..6756b5575 --- /dev/null +++ b/test/perf/linux/sched_yield_benchmark.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sched.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Sched_yield(benchmark::State& state) { + for (auto ignored : state) { + TEST_CHECK(sched_yield() == 0); + } +} + +BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc new file mode 100644 index 000000000..d73e49523 --- /dev/null +++ b/test/perf/linux/send_recv_benchmark.cc @@ -0,0 +1,372 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> + +#include <cstring> + +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "benchmark/benchmark.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr ssize_t kMessageSize = 1024; + +class Message { + public: + explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {} + + explicit Message(int byte, int sz) : Message(byte, sz, 0) {} + + explicit Message(int byte, int sz, int cmsg_sz) + : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) { + iov_.iov_base = buffer_.data(); + iov_.iov_len = sz; + hdr_.msg_iov = &iov_; + hdr_.msg_iovlen = 1; + hdr_.msg_control = cmsg_buffer_.data(); + hdr_.msg_controllen = cmsg_sz; + } + + struct msghdr* header() { + return &hdr_; + } + + private: + std::vector<char> buffer_; + std::vector<char> cmsg_buffer_; + struct iovec iov_ = {}; + struct msghdr hdr_ = {}; +}; + +void BM_Recvmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvmsg)->UseRealTime(); + +void BM_Sendmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + recvmsg(recv_socket.get(), recv_msg.header(), 0); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendmsg(send_socket.get(), send_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendmsg)->UseRealTime(); + +void BM_Recvfrom(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&send_socket, &send_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + } + }); + + int bytes_received = 0; + for (auto ignored : state) { + int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvfrom)->UseRealTime(); + +void BM_Sendto(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&recv_socket, &recv_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendto)->UseRealTime(); + +PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate +// space for control messages. Note that we do not expect to receive any. +void BM_RecvmsgWithControlBuf(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + absl::Notification notification; + Message send_msg('a'); + // Create a msghdr with a buffer allocated for control messages. + Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24); + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime(); + +// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes. +// +// state.Args[0] indicates whether the underlying socket should be blocking or +// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking. +// state.Args[1] is the size of the payload to be used per sendmsg call. +void BM_SendmsgTCP(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // Check if we want to run the test w/ a blocking send socket + // or non-blocking. + const int blocking = state.range(0); + if (!blocking) { + // Set the send FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds()); + } + + absl::Notification notification; + + // Get the buffer size we should use for this iteration of the test. + const int buf_size = state.range(1); + Message send_msg('a', buf_size), recv_msg(0, buf_size); + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0); + } + }); + + int64_t bytes_sent = 0; + int ncalls = 0; + for (auto ignored : state) { + int sent = 0; + while (true) { + struct msghdr hdr = {}; + struct iovec iov = {}; + struct msghdr* snd_header = send_msg.header(); + iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent; + iov.iov_len = snd_header->msg_iov->iov_len - sent; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0); + ncalls++; + if (n > 0) { + sent += n; + if (sent == buf_size) { + break; + } + // n can be > 0 but less than requested size. In which case we don't + // poll. + continue; + } + // Poll the fd for it to become writable. + struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), + SyscallSucceedsWithValue(0)); + } + bytes_sent += static_cast<int64_t>(sent); + } + + notification.Notify(); + send_socket.reset(); + state.SetBytesProcessed(bytes_sent); +} + +void Args(benchmark::internal::Benchmark* benchmark) { + for (int blocking = 0; blocking < 2; blocking++) { + for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) { + benchmark->Args({blocking, buf_size}); + } + } +} + +BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc new file mode 100644 index 000000000..af49e4477 --- /dev/null +++ b/test/perf/linux/seqwrite_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// The maximum file size of the test file, when writes get beyond this point +// they wrap around. This should be large enough to blow away caches. +const uint64_t kMaxFile = 1 << 30; + +// Perform writes of various sizes sequentially to one file. Wraps around if it +// goes above a certain maximum file size. +void BM_SeqWrite(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector<char> buf(size); + RandomizeBuffer(buf.data(), buf.size()); + + // Start writes at offset 0. + uint64_t offset = 0; + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) == + buf.size()); + offset += buf.size(); + // Wrap around if going above the maximum file size. + if (offset >= kMaxFile) { + offset = 0; + } + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc new file mode 100644 index 000000000..a6928df58 --- /dev/null +++ b/test/perf/linux/signal_benchmark.cc @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> +#include <string.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void FixupHandler(int sig, siginfo_t* si, void* void_ctx) { + static unsigned int dataval = 0; + + // Skip the offending instruction. + ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx); + ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval); +} + +void BM_FaultSignalFixup(benchmark::State& state) { + // Set up the signal handler. + struct sigaction sa = {}; + sigemptyset(&sa.sa_mask); + sa.sa_sigaction = FixupHandler; + sa.sa_flags = SA_SIGINFO; + TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0); + + // Fault, fault, fault. + for (auto _ : state) { + register volatile unsigned int* ptr asm("rax"); + + // Trigger the segfault. + ptr = nullptr; + *ptr = 0; + } +} + +BENCHMARK(BM_FaultSignalFixup)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc new file mode 100644 index 000000000..99ef05117 --- /dev/null +++ b/test/perf/linux/sleep_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <sys/syscall.h> +#include <time.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Sleep for 'param' nanoseconds. +void BM_Sleep(benchmark::State& state) { + const int nanoseconds = state.range(0); + + for (auto _ : state) { + struct timespec ts; + ts.tv_sec = 0; + ts.tv_nsec = nanoseconds; + + int ret; + do { + ret = syscall(SYS_nanosleep, &ts, &ts); + if (ret < 0) { + TEST_CHECK(errno == EINTR); + } + } while (ret < 0); + } +} + +BENCHMARK(BM_Sleep) + ->Arg(0) + ->Arg(1) + ->Arg(1000) // 1us + ->Arg(1000 * 1000) // 1ms + ->Arg(10 * 1000 * 1000) // 10ms + ->Arg(50 * 1000 * 1000) // 50ms + ->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc new file mode 100644 index 000000000..f15424482 --- /dev/null +++ b/test/perf/linux/stat_benchmark.cc @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a file in a nested directory hierarchy at least `depth` directories +// deep, and stats that file multiple times. +void BM_Stat(benchmark::State& state) { + // Create nested directories with given depth. + int depth = state.range(0); + const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string dir_path = top_dir.path(); + + while (depth-- > 0) { + // Don't use TempPath because it will make paths too long to use. + // + // The top_dir destructor will clean up this whole tree. + dir_path = JoinPath(dir_path, absl::StrCat(depth)); + ASSERT_NO_ERRNO(Mkdir(dir_path, 0755)); + } + + // Create the file that will be stat'd. + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path)); + + struct stat st; + for (auto _ : state) { + ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds()); + } +} + +BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc new file mode 100644 index 000000000..92243a042 --- /dev/null +++ b/test/perf/linux/unlink_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a directory containing `files` files, and unlinks all the files. +void BM_Unlink(benchmark::State& state) { + // Create directory with given files. + const int file_count = state.range(0); + + // We unlink all files on each iteration, but report this as a "batch" + // iteration so that reported times are per file. + TempPath dir; + while (state.KeepRunningBatch(file_count)) { + state.PauseTiming(); + // N.B. dir is declared outside the loop so that destruction of the previous + // iteration's directory occurs here, inside of PauseTiming. + dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + std::vector<TempPath> files; + for (int i = 0; i < file_count; i++) { + TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + files.push_back(std::move(file)); + } + state.ResumeTiming(); + + while (!files.empty()) { + // Destructor unlinks. + files.pop_back(); + } + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc new file mode 100644 index 000000000..7b060c70e --- /dev/null +++ b/test/perf/linux/write_benchmark.cc @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Write(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector<char> buf(size); + RandomizeBuffer(buf.data(), size); + + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/runner/BUILD b/test/runner/BUILD new file mode 100644 index 000000000..9959ef9b0 --- /dev/null +++ b/test/runner/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["runner.go"], + data = [ + "//runsc", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//runsc/specutils", + "//runsc/testutil", + "//test/runner/gtest", + "//test/uds", + "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/syscalls/build_defs.bzl b/test/runner/defs.bzl index cbab85ef7..56743a526 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/runner/defs.bzl @@ -1,6 +1,119 @@ """Defines a rule for syscall test targets.""" -load("//tools:defs.bzl", "loopback") +load("//tools:defs.bzl", "default_platform", "loopback", "platforms") + +def _runner_test_impl(ctx): + # Generate a runner binary. + runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "set -euf -x -o pipefail", + "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then", + " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + "fi", + "exec %s %s %s\n" % ( + ctx.files.runner[0].short_path, + " ".join(ctx.attr.runner_args), + ctx.files.test[0].short_path, + ), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return with all transitive files. + runfiles = ctx.runfiles( + transitive_files = depset(transitive = [ + depset(target.data_runfiles.files) + for target in (ctx.attr.runner, ctx.attr.test) + if hasattr(target, "data_runfiles") + ]), + files = ctx.files.runner + ctx.files.test, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = runner, runfiles = runfiles)] + +_runner_test = rule( + attrs = { + "runner": attr.label( + default = "//test/runner:runner", + ), + "test": attr.label( + mandatory = True, + ), + "runner_args": attr.string_list(), + "data": attr.label_list( + allow_files = True, + ), + }, + test = True, + implementation = _runner_test_impl, +) + +def _syscall_test( + test, + shard_count, + size, + platform, + use_tmpfs, + tags, + network = "none", + file_access = "exclusive", + overlay = False, + add_uds_tree = False): + # Prepend "runsc" to non-native platform names. + full_platform = platform if platform == "native" else "runsc_" + platform + + # Name the test appropriately. + name = test.split(":")[1] + "_" + full_platform + if file_access == "shared": + name += "_shared" + if overlay: + name += "_overlay" + if network != "none": + name += "_" + network + "net" + + # Apply all tags. + if tags == None: + tags = [] + + # Add the full_platform and file access in a tag to make it easier to run + # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. + tags += [full_platform, "file_" + file_access] + + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + + # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until + # we figure out how to request ipv4 sockets on Guitar machines. + if network == "host": + tags.append("noguitar") + + # Disable off-host networking. + tags.append("requires-net:loopback") + + runner_args = [ + # Arguments are passed directly to runner binary. + "--platform=" + platform, + "--network=" + network, + "--use-tmpfs=" + str(use_tmpfs), + "--file-access=" + file_access, + "--overlay=" + str(overlay), + "--add-uds-tree=" + str(add_uds_tree), + ] + + # Call the rule above. + _runner_test( + name = name, + test = test, + runner_args = runner_args, + data = [loopback], + size = size, + tags = tags, + shard_count = shard_count, + ) def syscall_test( test, @@ -23,6 +136,8 @@ def syscall_test( add_hostinet: add a hostinet test. tags: starting test tags. """ + if not tags: + tags = [] _syscall_test( test = test, @@ -34,35 +149,26 @@ def syscall_test( tags = tags, ) - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "kvm", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) + for (platform, platform_tags) in platforms.items(): + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platform_tags + tags, + ) if add_overlay: _syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, overlay = True, ) @@ -72,10 +178,10 @@ def syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, file_access = "shared", ) @@ -84,97 +190,9 @@ def syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = use_tmpfs, network = "host", add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, ) - -def _syscall_test( - test, - shard_count, - size, - platform, - use_tmpfs, - tags, - network = "none", - file_access = "exclusive", - overlay = False, - add_uds_tree = False): - test_name = test.split(":")[1] - - # Prepend "runsc" to non-native platform names. - full_platform = platform if platform == "native" else "runsc_" + platform - - name = test_name + "_" + full_platform - if file_access == "shared": - name += "_shared" - if overlay: - name += "_overlay" - if network != "none": - name += "_" + network + "net" - - if tags == None: - tags = [] - - # Add the full_platform and file access in a tag to make it easier to run - # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. - tags += [full_platform, "file_" + file_access] - - # Hash this target into one of 15 buckets. This can be used to - # randomly split targets between different workflows. - hash15 = hash(native.package_name() + name) % 15 - tags.append("hash15:" + str(hash15)) - - # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until - # we figure out how to request ipv4 sockets on Guitar machines. - if network == "host": - tags.append("noguitar") - - # Disable off-host networking. - tags.append("requires-net:loopback") - - # Add tag to prevent the tests from running in a Bazel sandbox. - # TODO(b/120560048): Make the tests run without this tag. - tags.append("no-sandbox") - - # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is - # more stable. - if platform == "kvm": - tags.append("manual") - tags.append("requires-kvm") - - # TODO(b/112165693): Remove when tests pass reliably. - tags.append("notap") - - args = [ - # Arguments are passed directly to syscall_test_runner binary. - "--test-name=" + test_name, - "--platform=" + platform, - "--network=" + network, - "--use-tmpfs=" + str(use_tmpfs), - "--file-access=" + file_access, - "--overlay=" + str(overlay), - "--add-uds-tree=" + str(add_uds_tree), - ] - - sh_test( - srcs = ["syscall_test_runner.sh"], - name = name, - data = [ - ":syscall_test_runner", - loopback, - test, - ], - args = args, - size = size, - tags = tags, - shard_count = shard_count, - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) diff --git a/test/syscalls/gtest/BUILD b/test/runner/gtest/BUILD index de4b2727c..de4b2727c 100644 --- a/test/syscalls/gtest/BUILD +++ b/test/runner/gtest/BUILD diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go new file mode 100644 index 000000000..f96e2415e --- /dev/null +++ b/test/runner/gtest/gtest.go @@ -0,0 +1,167 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gtest contains helpers for running google-test tests from Go. +package gtest + +import ( + "fmt" + "os/exec" + "strings" +) + +var ( + // listTestFlag is the flag that will list tests in gtest binaries. + listTestFlag = "--gtest_list_tests" + + // filterTestFlag is the flag that will filter tests in gtest binaries. + filterTestFlag = "--gtest_filter" + + // listBechmarkFlag is the flag that will list benchmarks in gtest binaries. + listBenchmarkFlag = "--benchmark_list_tests" + + // filterBenchmarkFlag is the flag that will run specified benchmarks. + filterBenchmarkFlag = "--benchmark_filter" +) + +// TestCase is a single gtest test case. +type TestCase struct { + // Suite is the suite for this test. + Suite string + + // Name is the name of this individual test. + Name string + + // all indicates that this will run without flags. This takes + // precendence over benchmark below. + all bool + + // benchmark indicates that this is a benchmark. In this case, the + // suite will be empty, and we will use the appropriate test and + // benchmark flags. + benchmark bool +} + +// FullName returns the name of the test including the suite. It is suitable to +// pass to "-gtest_filter". +func (tc TestCase) FullName() string { + return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) +} + +// Args returns arguments to be passed when invoking the test. +func (tc TestCase) Args() []string { + if tc.all { + return []string{} // No arguments. + } + if tc.benchmark { + return []string{ + fmt.Sprintf("%s=^$", filterTestFlag), + fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + } + } + return []string{ + fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()), + fmt.Sprintf("%s=^$", filterBenchmarkFlag), + } +} + +// ParseTestCases calls a gtest test binary to list its test and returns a +// slice with the name and suite of each test. +// +// If benchmarks is true, then benchmarks will be included in the list of test +// cases provided. Note that this requires the binary to support the +// benchmarks_list_tests flag. +func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) { + // Run to extract test cases. + args := append([]string{listTestFlag}, extraArgs...) + cmd := exec.Command(testBin, args...) + out, err := cmd.Output() + if err != nil { + // We failed to list tests with the given flags. Just + // return something that will run the binary with no + // flags, which should execute all tests. + return []TestCase{ + TestCase{ + Suite: "Default", + Name: "All", + all: true, + }, + }, nil + } + + // Parse test output. + var t []TestCase + var suite string + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // New suite? + if !strings.HasPrefix(line, " ") { + suite = strings.TrimSuffix(strings.TrimSpace(line), ".") + continue + } + + // Individual test. + name := strings.TrimSpace(line) + + // Do we have a suite yet? + if suite == "" { + return nil, fmt.Errorf("test without a suite: %v", name) + } + + // Add this individual test. + t = append(t, TestCase{ + Suite: suite, + Name: name, + }) + } + + // Finished? + if !benchmarks { + return t, nil + } + + // Run again to extract benchmarks. + args = append([]string{listBenchmarkFlag}, extraArgs...) + cmd = exec.Command(testBin, args...) + out, err = cmd.Output() + if err != nil { + // We were able to enumerate tests above, but not benchmarks? + // We requested them, so we return an error in this case. + exitErr, ok := err.(*exec.ExitError) + if !ok { + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err) + } + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) + } + + // Parse benchmark output. + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // Single benchmark. + name := strings.TrimSpace(line) + + // Add the single benchmark. + t = append(t, TestCase{ + Suite: "Benchmarks", + Name: name, + benchmark: true, + }) + } + + return t, nil +} diff --git a/test/syscalls/syscall_test_runner.go b/test/runner/runner.go index ae342b68c..a78ef38e0 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/runner/runner.go @@ -34,15 +34,11 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/syscalls/gtest" + "gvisor.dev/gvisor/test/runner/gtest" "gvisor.dev/gvisor/test/uds" ) -// Location of syscall tests, relative to the repo root. -const testDir = "test/syscalls/linux" - var ( - testName = flag.String("test-name", "", "name of test binary to run") debug = flag.Bool("debug", false, "enable debug logs") strace = flag.Bool("strace", false, "enable strace logs") platform = flag.String("platform", "ptrace", "platform to run on") @@ -103,7 +99,7 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) } - cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + cmd := exec.Command(testBin, tc.Args()...) cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -296,7 +292,7 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Run a new container with the test executable and filter for the // given test suite and name. - spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. @@ -404,9 +400,10 @@ func matchString(a, b string) (bool, error) { func main() { flag.Parse() - if *testName == "" { - fatalf("test-name flag must be provided") + if flag.NArg() != 1 { + fatalf("test must be provided") } + testBin := flag.Args()[0] // Only argument. log.SetLevel(log.Info) if *debug { @@ -436,15 +433,8 @@ func main() { } } - // Get path to test binary. - fullTestName := filepath.Join(testDir, *testName) - testBin, err := testutil.FindFile(fullTestName) - if err != nil { - fatalf("FindFile(%q) failed: %v", fullTestName, err) - } - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin) + testCases, err := gtest.ParseTestCases(testBin, true) if err != nil { fatalf("ParseTestCases(%q) failed: %v", testBin, err) } @@ -455,14 +445,19 @@ func main() { fatalf("TestsForShard() failed: %v", err) } + // Resolve the absolute path for the binary. + testBin, err = filepath.Abs(testBin) + if err != nil { + fatalf("Abs() failed: %v", err) + } + // Run the tests. var tests []testing.InternalTest for _, tci := range indices { // Capture tc. tc := testCases[tci] - testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name) tests = append(tests, testing.InternalTest{ - Name: testName, + Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), F: func(t *testing.T) { if *parallel { t.Parallel() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 31d239c0e..3518e862d 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,5 +1,4 @@ -load("//tools:defs.bzl", "go_binary") -load("//test/syscalls:build_defs.bzl", "syscall_test") +load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) @@ -259,6 +258,8 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:munmap_test") +syscall_test(test = "//test/syscalls/linux:network_namespace_test") + syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_create_test", @@ -677,6 +678,8 @@ syscall_test( test = "//test/syscalls/linux:truncate_test", ) +syscall_test(test = "//test/syscalls/linux:tuntap_test") + syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( @@ -726,21 +729,3 @@ syscall_test(test = "//test/syscalls/linux:proc_net_unix_test") syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test") syscall_test(test = "//test/syscalls/linux:proc_net_udp_test") - -go_binary( - name = "syscall_test_runner", - testonly = 1, - srcs = ["syscall_test_runner.go"], - data = [ - "//runsc", - ], - deps = [ - "//pkg/log", - "//runsc/specutils", - "//runsc/testutil", - "//test/syscalls/gtest", - "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go deleted file mode 100644 index bdec8eb07..000000000 --- a/test/syscalls/gtest/gtest.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gtest contains helpers for running google-test tests from Go. -package gtest - -import ( - "fmt" - "os/exec" - "strings" -) - -var ( - // ListTestFlag is the flag that will list tests in gtest binaries. - ListTestFlag = "--gtest_list_tests" - - // FilterTestFlag is the flag that will filter tests in gtest binaries. - FilterTestFlag = "--gtest_filter" -) - -// TestCase is a single gtest test case. -type TestCase struct { - // Suite is the suite for this test. - Suite string - - // Name is the name of this individual test. - Name string -} - -// FullName returns the name of the test including the suite. It is suitable to -// pass to "-gtest_filter". -func (tc TestCase) FullName() string { - return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) -} - -// ParseTestCases calls a gtest test binary to list its test and returns a -// slice with the name and suite of each test. -func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) { - args := append([]string{ListTestFlag}, extraArgs...) - cmd := exec.Command(testBin, args...) - out, err := cmd.Output() - if err != nil { - exitErr, ok := err.(*exec.ExitError) - if !ok { - return nil, fmt.Errorf("could not enumerate gtest tests: %v", err) - } - return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr) - } - - var t []TestCase - var suite string - for _, line := range strings.Split(string(out), "\n") { - // Strip comments. - line = strings.Split(line, "#")[0] - - // New suite? - if !strings.HasPrefix(line, " ") { - suite = strings.TrimSuffix(strings.TrimSpace(line), ".") - continue - } - - // Individual test. - name := strings.TrimSpace(line) - - // Do we have a suite yet? - if suite == "" { - return nil, fmt.Errorf("test without a suite: %v", name) - } - - // Add this individual test. - t = append(t, TestCase{ - Suite: suite, - Name: name, - }) - - } - - if len(t) == 0 { - return nil, fmt.Errorf("no tests parsed from %v", testBin) - } - return t, nil -} diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc index c47a05181..3c825477c 100644 --- a/test/syscalls/linux/32bit.cc +++ b/test/syscalls/linux/32bit.cc @@ -74,7 +74,7 @@ void ExitGroup32(const char instruction[2], int code) { "int $3\n" : : [ code ] "m"(code), [ ip ] "d"(m.ptr()) - : "rax", "rbx", "rsp"); + : "rax", "rbx"); } constexpr int kExitCode = 42; diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index e7c82adfc..704bae17b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -12,8 +12,12 @@ exports_files( "socket_ip_loopback_blocking.cc", "socket_ip_tcp_generic_loopback.cc", "socket_ip_tcp_loopback.cc", + "socket_ip_tcp_loopback_blocking.cc", + "socket_ip_tcp_loopback_nonblock.cc", "socket_ip_tcp_udp_generic.cc", "socket_ip_udp_loopback.cc", + "socket_ip_udp_loopback_blocking.cc", + "socket_ip_udp_loopback_nonblock.cc", "socket_ip_unbound.cc", "socket_ipv4_tcp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", @@ -128,6 +132,17 @@ cc_library( ) cc_library( + name = "socket_netlink_route_util", + testonly = 1, + srcs = ["socket_netlink_route_util.cc"], + hdrs = ["socket_netlink_route_util.h"], + deps = [ + ":socket_netlink_util", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( name = "socket_test_util", testonly = 1, srcs = [ @@ -3426,6 +3441,25 @@ cc_binary( ], ) +cc_binary( + name = "tuntap_test", + testonly = 1, + srcs = ["tuntap.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/syscalls/linux:socket_netlink_route_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "udp_socket_test_cases", testonly = 1, @@ -3636,6 +3670,23 @@ cc_binary( ) cc_binary( + name = "network_namespace_test", + testonly = 1, + srcs = ["network_namespace.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/util:capability_util", + "//test/util:memory_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( name = "semaphore_test", testonly = 1, srcs = ["semaphore.cc"], diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc index d89269985..940c97285 100644 --- a/test/syscalls/linux/alarm.cc +++ b/test/syscalls/linux/alarm.cc @@ -188,6 +188,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4dd302eed..4e473268c 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } +TEST(DevTest, NetTunExists) { + struct stat statbuf = {}; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index b5e0a512b..07bd527e6 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -868,6 +868,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index 1c3d00287..7819f4ac3 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -33,7 +33,7 @@ namespace testing { namespace { int fallocate(int fd, int mode, off_t offset, off_t len) { - return syscall(__NR_fallocate, fd, mode, offset, len); + return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len); } class AllocateTest : public FileTest { @@ -47,27 +47,27 @@ TEST_F(AllocateTest, Fallocate) { EXPECT_EQ(buf.st_size, 0); // Grow to ten bytes. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 10); // Allocate to a smaller size should be noop. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 10); // Grow again. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 20); // Grow with offset. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 30); // Grow with offset beyond EOF. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 40); } diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 421c15b87..c7cc5816e 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -1128,5 +1128,5 @@ int main(int argc, char** argv) { exit(err); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 083ebbcf0..39fd6709d 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -84,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type); // SocketPairs created with AF_INET and the given type. SocketPairKind IPv4UDPUnboundSocketPair(int type); -// IPv4UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type. +// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_DGRAM, and the given type. SocketKind IPv4UDPUnboundSocket(int type); -// IPv6UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type. +// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_DGRAM, and the given type. SocketKind IPv6UDPUnboundSocket(int type); -// IPv4TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. +// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); -// IPv6TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_STREAM and the given type. SocketKind IPv6TCPUnboundSocket(int type); // IfAddrHelper is a helper class that determines the local interfaces present diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index b77e4cbd1..8b48f0804 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -349,6 +349,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc new file mode 100644 index 000000000..6ea48c263 --- /dev/null +++ b/test/syscalls/linux/network_namespace.cc @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <net/if.h> +#include <sched.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/memory_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +using TestFunc = std::function<PosixError()>; +using RunFunc = std::function<PosixError(TestFunc)>; + +struct NamespaceStrategy { + RunFunc run; + + static NamespaceStrategy Of(RunFunc run) { + NamespaceStrategy s; + s.run = run; + return s; + } +}; + +PosixError RunWithUnshare(TestFunc fn) { + PosixError err = PosixError(-1, "function did not return a value"); + ScopedThread t([&] { + if (unshare(CLONE_NEWNET) != 0) { + err = PosixError(errno); + return; + } + err = fn(); + }); + t.Join(); + return err; +} + +PosixError RunWithClone(TestFunc fn) { + struct Args { + absl::Notification n; + TestFunc fn; + PosixError err; + }; + Args args; + args.fn = fn; + args.err = PosixError(-1, "function did not return a value"); + + ASSIGN_OR_RETURN_ERRNO( + Mapping child_stack, + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + pid_t child = clone( + +[](void *arg) { + Args *args = reinterpret_cast<Args *>(arg); + args->err = args->fn(); + args->n.Notify(); + syscall(SYS_exit, 0); // Exit manually. No return address on stack. + return 0; + }, + reinterpret_cast<void *>(child_stack.addr() + kPageSize), + CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args); + if (child < 0) { + return PosixError(errno, "clone() failed"); + } + args.n.WaitForNotification(); + return args.err; +} + +class NetworkNamespaceTest + : public ::testing::TestWithParam<NamespaceStrategy> {}; + +TEST_P(NetworkNamespaceTest, LoopbackExists) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + EXPECT_NO_ERRNO(GetParam().run([]() { + // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists. + // Check loopback device exists. + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock < 0) { + return PosixError(errno, "socket() failed"); + } + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) { + return PosixError(errno, "ioctl() failed, lo cannot be found"); + } + return NoError(); + })); +} + +INSTANTIATE_TEST_SUITE_P( + AllNetworkNamespaceTest, NetworkNamespaceTest, + ::testing::Values(NamespaceStrategy::Of(RunWithUnshare), + NamespaceStrategy::Of(RunWithClone))); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index d07571a5f..04c5161f5 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -226,5 +226,5 @@ int main(int argc, char** argv) { prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc index 30f0d75b3..c4e9cf528 100644 --- a/test/syscalls/linux/prctl_setuid.cc +++ b/test/syscalls/linux/prctl_setuid.cc @@ -264,5 +264,5 @@ int main(int argc, char** argv) { prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index a23fdb58d..f91187e75 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -2076,5 +2076,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 4dd5cf27b..bfe3e2603 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -1208,5 +1208,5 @@ int main(int argc, char** argv) { gvisor::testing::RunExecveChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h index e3ff0579a..ca1d67691 100644 --- a/test/syscalls/linux/rseq/uapi.h +++ b/test/syscalls/linux/rseq/uapi.h @@ -15,14 +15,9 @@ #ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ #define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ -// User-kernel ABI for restartable sequences. +#include <stdint.h> -// Standard types. -// -// N.B. This header will be included in targets that do have the standard -// library, so we can't shadow the standard type names. -using __u32 = __UINT32_TYPE__; -using __u64 = __UINT64_TYPE__; +// User-kernel ABI for restartable sequences. #ifdef __x86_64__ // Syscall numbers. @@ -32,20 +27,20 @@ constexpr int kRseqSyscall = 334; #endif // __x86_64__ struct rseq_cs { - __u32 version; - __u32 flags; - __u64 start_ip; - __u64 post_commit_offset; - __u64 abort_ip; -} __attribute__((aligned(4 * sizeof(__u64)))); + uint32_t version; + uint32_t flags; + uint64_t start_ip; + uint64_t post_commit_offset; + uint64_t abort_ip; +} __attribute__((aligned(4 * sizeof(uint64_t)))); // N.B. alignment is enforced by the kernel. struct rseq { - __u32 cpu_id_start; - __u32 cpu_id; + uint32_t cpu_id_start; + uint32_t cpu_id; struct rseq_cs* rseq_cs; - __u32 flags; -} __attribute__((aligned(4 * sizeof(__u64)))); + uint32_t flags; +} __attribute__((aligned(4 * sizeof(uint64_t)))); constexpr int kRseqFlagUnregister = 1 << 0; diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc index 81d193ffd..ed27e2566 100644 --- a/test/syscalls/linux/rtsignal.cc +++ b/test/syscalls/linux/rtsignal.cc @@ -167,6 +167,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc index 2c947feb7..cf6499f8b 100644 --- a/test/syscalls/linux/seccomp.cc +++ b/test/syscalls/linux/seccomp.cc @@ -411,5 +411,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc index 4deb1ae95..6227774a4 100644 --- a/test/syscalls/linux/sigiret.cc +++ b/test/syscalls/linux/sigiret.cc @@ -132,6 +132,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 95be4b66c..389e5fca2 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -369,5 +369,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc index 7db57d968..b2fcedd62 100644 --- a/test/syscalls/linux/sigstop.cc +++ b/test/syscalls/linux/sigstop.cc @@ -147,5 +147,5 @@ int main(int argc, char** argv) { return 1; } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc index 1e5bf5942..4f8afff15 100644 --- a/test/syscalls/linux/sigtimedwait.cc +++ b/test/syscalls/linux/sigtimedwait.cc @@ -319,6 +319,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index db5663ecd..1c533fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -14,6 +14,7 @@ #include "test/syscalls/linux/socket_ip_udp_generic.h" +#include <errno.h> #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -209,46 +210,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } -// Ensure that Receiving TOS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS works as expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -401,5 +362,97 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { EXPECT_EQ(get_len, sizeof(get)); } +// Holds TOS or TClass information for IPv4 or IPv6 respectively. +struct RecvTosOption { + int level; + int option; +}; + +RecvTosOption GetRecvTosOption(int domain) { + TEST_CHECK(domain == AF_INET || domain == AF_INET6); + RecvTosOption opt; + switch (domain) { + case AF_INET: + opt.level = IPPROTO_IP; + opt.option = IP_RECVTOS; + break; + case AF_INET6: + opt.level = IPPROTO_IPV6; + opt.option = IPV6_RECVTCLASS; + break; + } + return opt; +} + +// Ensure that Receiving TOS or TCLASS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as +// expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this +// mirrors behavior in linux. +TEST_P(UDPSocketPairTest, TOSRecvMismatch) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(AF_INET); + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); +} + +// Test that an IPv4 socket does not support the IPv6 TClass option. +TEST_P(UDPSocketPairTest, TClassRecvMismatch) { + // This should only test AF_INET sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS, + &get, &get_len), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc new file mode 100644 index 000000000..53eb3b6b2 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_route_util.h" + +#include <linux/if.h> +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr uint32_t kSeq = 12345; + +} // namespace + +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function<void(const struct nlmsghdr* hdr)>& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + +PosixErrorOr<std::vector<Link>> DumpLinks() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + std::vector<Link> links; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + links.emplace_back(); + links.back().index = msg->ifi_index; + links.back().type = msg->ifi_type; + links.back().name = + std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); + })); + return links; +} + +PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.type == ARPHRD_LOOPBACK) { + return absl::optional<Link>(link); + } + } + return absl::optional<Link>(); +} + +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char pad[NLMSG_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + req.ifinfo.ifi_flags = flags; + req.ifinfo.ifi_change = change; + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFLA_ADDRESS; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h new file mode 100644 index 000000000..2c018e487 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ + +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { + +struct Link { + int index; + int16_t type; + std::string name; +}; + +PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, + const std::function<void(const struct nlmsghdr* hdr)>& fn); + +PosixErrorOr<std::vector<Link>> DumpLinks(); + +PosixErrorOr<absl::optional<Link>> FindLoopbackLink(); + +// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkChangeFlags changes interface flags. E.g. IFF_UP. +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); + +// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 2f92c27da..4b3c44527 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -658,5 +658,5 @@ int main(int argc, char** argv) { } } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc new file mode 100644 index 000000000..f6ac9d7b8 --- /dev/null +++ b/test/syscalls/linux/tuntap.cc @@ -0,0 +1,346 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_ether.h> +#include <linux/if_tun.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr int kIPLen = 4; + +constexpr const char kDevNetTun[] = "/dev/net/tun"; +constexpr const char kTapName[] = "tap0"; + +constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; +constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; + +PosixErrorOr<std::set<std::string>> DumpLinkNames() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + std::set<std::string> names; + for (const auto& link : links) { + names.emplace(link.name); + } + return names; +} + +PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.name == name) { + return absl::optional<Link>(link); + } + } + return absl::optional<Link>(); +} + +struct pihdr { + uint16_t pi_flags; + uint16_t pi_protocol; +} __attribute__((packed)); + +struct ping_pkt { + pihdr pi; + struct ethhdr eth; + struct iphdr ip; + struct icmphdr icmp; + char payload[64]; +} __attribute__((packed)); + +ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + ping_pkt pkt = {}; + + pkt.pi.pi_protocol = htons(ETH_P_IP); + + memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); + memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); + pkt.eth.h_proto = htons(ETH_P_IP); + + pkt.ip.ihl = 5; + pkt.ip.version = 4; + pkt.ip.tos = 0; + pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + + sizeof(pkt.payload)); + pkt.ip.id = 1; + pkt.ip.frag_off = 1 << 6; // Do not fragment + pkt.ip.ttl = 64; + pkt.ip.protocol = IPPROTO_ICMP; + inet_pton(AF_INET, dstip, &pkt.ip.daddr); + inet_pton(AF_INET, srcip, &pkt.ip.saddr); + pkt.ip.check = IPChecksum(pkt.ip); + + pkt.icmp.type = ICMP_ECHO; + pkt.icmp.code = 0; + pkt.icmp.checksum = 0; + pkt.icmp.un.echo.sequence = 1; + pkt.icmp.un.echo.id = 1; + + strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); + pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); + + return pkt; +} + +struct arp_pkt { + pihdr pi; + struct ethhdr eth; + struct arphdr arp; + uint8_t arp_sha[ETH_ALEN]; + uint8_t arp_spa[kIPLen]; + uint8_t arp_tha[ETH_ALEN]; + uint8_t arp_tpa[kIPLen]; +} __attribute__((packed)); + +std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + std::string buffer; + buffer.resize(sizeof(arp_pkt)); + + arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]); + { + pkt->pi.pi_protocol = htons(ETH_P_ARP); + + memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); + memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); + pkt->eth.h_proto = htons(ETH_P_ARP); + + pkt->arp.ar_hrd = htons(ARPHRD_ETHER); + pkt->arp.ar_pro = htons(ETH_P_IP); + pkt->arp.ar_hln = ETH_ALEN; + pkt->arp.ar_pln = kIPLen; + pkt->arp.ar_op = htons(ARPOP_REPLY); + + memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); + inet_pton(AF_INET, srcip, pkt->arp_spa); + memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); + inet_pton(AF_INET, dstip, pkt->arp_tpa); + } + return buffer; +} + +} // namespace + +class TuntapTest : public ::testing::Test { + protected: + void TearDown() override { + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { + // Bring back capability if we had dropped it in test case. + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); + } + } +}; + +TEST_F(TuntapTest, CreateInterfaceNoCap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(TuntapTest, CreateFixedNameInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_expect = ifr_set; + // See __tun_chr_ioctl() in net/drivers/tun.c. + ifr_expect.ifr_flags |= IFF_NOFILTER; + + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); + EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); +} + +TEST_F(TuntapTest, CreateInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + // Empty ifr.ifr_name. Let kernel assign. + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + std::string ifname = ifr_get.ifr_name; + EXPECT_THAT(ifname, ::testing::StartsWith("tap")); + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(ifname))); +} + +TEST_F(TuntapTest, InvalidReadWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + char buf[128] = {}; + EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); +} + +TEST_F(TuntapTest, WriteToDownDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + // Device created should be down by default. + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + char buf[128] = {}; + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); +} + +// This test sets up a TAP device and pings kernel by sending ICMP echo request. +// +// It works as the following: +// * Open /dev/net/tun, and create kTapName interface. +// * Use rtnetlink to do initial setup of the interface: +// * Assign IP address 10.0.0.1/24 to kernel. +// * MAC address: kMacA +// * Bring up the interface. +// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. +// * Loop to receive packets from TAP device/fd: +// * If packet is an ICMP echo reply, it stops and passes the test. +// * If packet is an ARP request, it responds with canned reply and resends +// the +// ICMP request packet. +TEST_F(TuntapTest, PingKernel) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Interface creation. + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + absl::optional<Link> link = + ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName)); + ASSERT_TRUE(link.has_value()); + + // Interface setup. + struct in_addr addr; + inet_pton(AF_INET, "10.0.0.1", &addr); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, + &addr, sizeof(addr))); + + if (!IsRunningOnGvisor()) { + // FIXME: gVisor doesn't support setting MAC address on interfaces yet. + EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); + } + + ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + + // Send ping, this would trigger an ARP request on Linux. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + + // Receive loop to process inbound packets. + struct inpkt { + union { + pihdr pi; + ping_pkt ping; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + // Process ARP packet. + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + // Respond with canned ARP reply. + EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), + SyscallSucceedsWithValue(arp_rep.size())); + // First ping request might have been dropped due to mac address not in + // ARP cache. Send it again. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + } + + // Process ping response packet. + if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && + r.ping.ip.protocol == ping_req.ip.protocol && + !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && + !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && + r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { + // Ends and passes the test. + break; + } + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 9f8de6b48..740c7986d 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -21,6 +21,10 @@ #include <sys/socket.h> #include <sys/types.h> +#ifndef SIOCGSTAMP +#include <linux/sockios.h> +#endif + #include "gtest/gtest.h" #include "absl/base/macros.h" #include "absl/time/clock.h" @@ -1349,9 +1353,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. - SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && - !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1422,7 +1423,6 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc index 0aaba482d..19d05998e 100644 --- a/test/syscalls/linux/vfork.cc +++ b/test/syscalls/linux/vfork.cc @@ -191,5 +191,5 @@ int main(int argc, char** argv) { return gvisor::testing::RunChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh deleted file mode 100755 index 864bb2de4..000000000 --- a/test/syscalls/syscall_test_runner.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# syscall_test_runner.sh is a simple wrapper around the go syscall test runner. -# It exists so that we can build the syscall test runner once, and use it for -# all syscall tests, rather than build it for each test run. - -set -euf -x -o pipefail - -echo -- "$@" - -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Get location of syscall_test_runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner) - -# Pass the arguments of this script directly to the runner. -exec "${runner}" "$@" diff --git a/test/util/BUILD b/test/util/BUILD index 1f22ebe29..8b5a0f25c 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -260,6 +260,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", gtest, + gbenchmark, ], ) diff --git a/test/util/test_main.cc b/test/util/test_main.cc index 5c7ee0064..1f389e58f 100644 --- a/test/util/test_main.cc +++ b/test/util/test_main.cc @@ -16,5 +16,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/util/test_util.h b/test/util/test_util.h index 2d22b0eb8..c5cb9d6d6 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -771,6 +771,7 @@ std::string RunfilePath(std::string path); #endif void TestInit(int* argc, char*** argv); +int RunAllTests(void); } // namespace testing } // namespace gvisor diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc index ba7c0a85b..7e1ad9e66 100644 --- a/test/util/test_util_impl.cc +++ b/test/util/test_util_impl.cc @@ -17,8 +17,12 @@ #include "gtest/gtest.h" #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "benchmark/benchmark.h" #include "test/util/logging.h" +extern bool FLAGS_benchmark_list_tests; +extern std::string FLAGS_benchmark_filter; + namespace gvisor { namespace testing { @@ -26,6 +30,7 @@ void SetupGvisorDeathTest() {} void TestInit(int* argc, char*** argv) { ::testing::InitGoogleTest(argc, *argv); + benchmark::Initialize(argc, *argv); ::absl::ParseCommandLine(*argc, *argv); // Always mask SIGPIPE as it's common and tests aren't expected to handle it. @@ -34,5 +39,14 @@ void TestInit(int* argc, char*** argv) { TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); } +int RunAllTests() { + if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") { + benchmark::RunSpecifiedBenchmarks(); + return 0; + } else { + return RUN_ALL_TESTS(); + } +} + } // namespace testing } // namespace gvisor diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 6798362dc..905b16d41 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -8,7 +8,6 @@ load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") load("@pydeps//:requirements.bzl", _py_requirement = "requirement") -load("//tools/bazeldefs:tags.bzl", _go_suffixes = "go_suffixes") container_image = _container_image cc_binary = _cc_binary @@ -19,8 +18,8 @@ cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" go_image = _go_image go_embed_data = _go_embed_data -go_suffixes = _go_suffixes gtest = "@com_google_googletest//:gtest" +gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" proto_library = native.proto_library pkg_deb = _pkg_deb diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl new file mode 100644 index 000000000..92b0b5fc0 --- /dev/null +++ b/tools/bazeldefs/platforms.bzl @@ -0,0 +1,17 @@ +"""List of platforms.""" + +# Platform to associated tags. +platforms = { + "ptrace": [ + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], + "kvm": [ + "manual", + "local", + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], +} + +default_platform = "ptrace" diff --git a/tools/defs.bzl b/tools/defs.bzl index 39f035f12..15a310403 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,9 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") +load("//tools/bazeldefs:tags.bzl", "go_suffixes") # Delegate directly. cc_binary = _cc_binary @@ -21,6 +23,7 @@ go_image = _go_image go_test = _go_test go_tool_library = _go_tool_library gtest = _gtest +gbenchmark = _gbenchmark pkg_deb = _pkg_deb pkg_tar = _pkg_tar py_library = _py_library @@ -32,6 +35,8 @@ select_system = _select_system loopback = _loopback default_installer = _default_installer default_net_util = _default_net_util +platforms = _platforms +default_platform = _default_platform def go_binary(name, **kwargs): """Wraps the standard go_binary. @@ -83,7 +88,7 @@ def go_imports(name, src, out): cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), ) -def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. The recommended way is to use this rule with mostly identical configuration as the native @@ -106,6 +111,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F imports: imports required for stateify. stateify: whether statify is enabled (default: true). marshal: whether marshal is enabled (default: false). + marshal_debug: whether the gomarshal tools emits debugging output (default: false). **kwargs: standard go_library arguments. """ all_srcs = srcs @@ -144,7 +150,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F go_marshal( name = name + suffix + "_abi_autogen", srcs = src_subset, - debug = False, + debug = select({ + "//tools/go_marshal:marshal_config_verbose": True, + "//conditions:default": marshal_debug, + }), imports = imports, package = name, ) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index 80d9c0504..be49cf9c8 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -12,3 +12,8 @@ go_binary( "//tools/go_marshal/gomarshal", ], ) + +config_setting( + name = "marshal_config_verbose", + values = {"define": "gomarshal=verbose"}, +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0294ba5ba..d365a1f3c 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -44,7 +44,8 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", + "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } @@ -193,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } -// collectMarshallabeTypes walks the parsed AST and collects a list of type +// collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { var types []*ast.TypeSpec for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -222,14 +223,22 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as continue } for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier. + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. t := spec.(*ast.TypeSpec) - if _, ok := t.Type.(*ast.StructType); ok { - debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) types = append(types, t) continue } - debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } } return types @@ -269,12 +278,20 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - // We're guaranteed to have only struct type specs by now. See - // Generator.collectMarshallabeTypes. i := newInterfaceGenerator(t, fset) - i.validate() - i.emitMarshallable() - return i + switch ty := t.Type.(type) { + case *ast.StructType: + i.validateStruct() + i.emitMarshallableForStruct() + return i + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype() + return i + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } } // generateOneTestSuite generates a test suite for the automatically generated @@ -320,7 +337,7 @@ func (g *Generator) Run() error { for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. - for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref, _ := range impl.ms { @@ -338,17 +355,6 @@ func (g *Generator) Run() error { } } - // Tool was invoked with input files with no data structures marked for code - // generation. This is probably not what the user intended. - if len(impls) == 0 { - var buf bytes.Buffer - fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") - for _, i := range g.inputs { - fmt.Fprintf(&buf, " %s\n", i) - } - abort(buf.String()) - } - // Write output file header. These include things like package name and // import statements. if err := g.writeHeader(); err != nil { @@ -391,6 +397,26 @@ func (g *Generator) writeTests(ts []*testGenerator) error { } // Write test functions. + + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func ExampleEmptyTestSuite() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 3aa299ccd..ea1af998e 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string { // newinterfaceGenerator creates a new interface generator. func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &interfaceGenerator{ t: t, r: receiverName(t), @@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -// validate ensures the type we're working with can be marshalled. These checks -// are done ahead of time and in one place so we can make assumptions later. -func (g *interfaceGenerator) validate() { +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct() { g.forEachField(func(f *ast.Field) { if len(f.Names) == 0 { g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") @@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() { g.forEachField(func(f *ast.Field) { fieldDispatcher{ primitive: func(_, t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we - // provide suggestions for some disallowed types and reject - // them, then attempt to marshal any remaining types by - // invoking the marshal.Marshallable interface on them. If - // these types don't actually implement - // marshal.Marshallable, compilation of the generated code - // will fail with an appropriate error message. - return - case "int": - g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } + g.validatePrimitiveNewtype(t) }, selector: func(_, _, _ *ast.Ident) { // No validation to perform on selector fields. However this @@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { +// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice. +func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) } } -func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { +// unmarshalStructFieldScalar reads a single scalar field from a struct, from a +// byte slice. +func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { - case "int8": - g.emit("%s = int8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) - case "uint8": - g.emit("%s = uint8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) g.shift(bufVar, 1) - - case "int16": - g.recordUsedImport("usermem") - g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) - g.shift(bufVar, 2) - case "uint16": + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) g.shift(bufVar, 2) - - case "int32": - g.recordUsedImport("usermem") - g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) - g.shift(bufVar, 4) - case "uint32": + case "int32", "uint32": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) g.shift(bufVar, 4) - - case "int64": - g.recordUsedImport("usermem") - g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) - g.shift(bufVar, 8) - case "uint64": + case "int64", "uint64": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) @@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string } } +// marshalPrimitiveScalar writes a single primitive variable to a byte slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + // areFieldsPackedExpression returns a go expression checking whether g.t's fields are // packed. Returns "", false if g.t has no fields that may be potentially // packed, otherwise returns <clause>, true, where <clause> is an expression @@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { return strings.Join(cs, " && "), true } -func (g *interfaceGenerator) emitMarshallable() { +func (g *interfaceGenerator) emitMarshallableForStruct() { // Is g.t a packed struct without consideing field types? thisPacked := true g.forEachField(func(f *ast.Field) { @@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst") }, selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") + g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") }) g.emit("}\n") }, @@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src") }, selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") + g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") }) g.emit("}\n") }, @@ -507,13 +536,14 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("return task.CopyOutBytes(addr, buf)\n") + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") } if thisPacked { g.recordUsedImport("reflect") @@ -539,11 +569,11 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("hdr.Len = %s.SizeBytes()\n", g.r) g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - g.emit("len, err := task.CopyOutBytes(addr, buf)\n") + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) g.emit("// must live until after the CopyOutBytes.\n") g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return len, err\n") + g.emit("return err\n") } else { fallback() } @@ -553,20 +583,20 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("n, err := task.CopyInBytes(addr, buf)\n") + g.emit("_, err := task.CopyInBytes(addr, buf)\n") g.emit("if err != nil {\n") g.inIndent(func() { - g.emit("return n, err\n") + g.emit("return err\n") }) g.emit("}\n") g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return n, nil\n") + g.emit("return nil\n") } if thisPacked { g.recordUsedImport("reflect") @@ -592,11 +622,11 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("hdr.Len = %s.SizeBytes()\n", g.r) g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - g.emit("len, err := task.CopyInBytes(addr, buf)\n") + g.emit("_, err := task.CopyInBytes(addr, buf)\n") g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) g.emit("// must live until after the CopyInBytes.\n") g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return len, err\n") + g.emit("return err\n") } else { fallback() } @@ -649,3 +679,144 @@ func (g *interfaceGenerator) emitMarshallable() { }) g.emit("}\n\n") } + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + nt := g.t.Type.(*ast.Ident) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") + +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 8c28b00d0..8ba47eb67 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -49,9 +49,6 @@ type testGenerator struct { } func newTestGenerator(t *ast.TypeSpec) *testGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &testGenerator{ t: t, r: receiverName(t), @@ -69,14 +66,6 @@ func (g *testGenerator) typeName() string { return g.t.Name.Name } -func (g *testGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - func (g *testGenerator) testFuncName(base string) string { return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) } @@ -89,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) { func (g *testGenerator) emitTestNonZeroSize() { g.inTestFunction("TestSizeNonZero", func() { - g.emit("x := &%s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") }) g.emit("}\n") }) @@ -100,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() { func (g *testGenerator) emitTestSuspectAlignment() { g.inTestFunction("TestSuspectAlignment", func() { - g.emit("x := %s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") }) } diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index 20353850d..f129788e0 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -91,12 +91,12 @@ type Marshallable interface { // marshalled does not escape. The implementation should avoid creating // extra copies in memory by directly deserializing to the object's // underlying memory. - CopyIn(task Task, addr usermem.Addr) (int, error) + CopyIn(task Task, addr usermem.Addr) error // CopyOut serializes a Marshallable type to a task's memory. This may only // be called from a task goroutine. This is more efficient than calling // MarshalUnsafe on Marshallable.Packed types, as the type being serialized // does not escape. The implementation should avoid creating extra copies in // memory by directly serializing from the object's underlying memory. - CopyOut(task Task, addr usermem.Addr) (int, error) + CopyOut(task Task, addr usermem.Addr) error } diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index 8de02d707..93229dedb 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -103,3 +103,13 @@ type Stat struct { CTime Timespec _ [3]int64 } + +// SignalSet is an example marshallable newtype on a primitive. +// +// +marshal +type SignalSet uint64 + +// SignalSetAlias is an example newtype on another marshallable type. +// +// +marshal +type SignalSetAlias SignalSet diff --git a/tools/installers/master.sh b/tools/installers/master.sh index 7b1956454..52f9734a6 100755 --- a/tools/installers/master.sh +++ b/tools/installers/master.sh @@ -15,6 +15,21 @@ # limitations under the License. # Install runsc from the master branch. +set -e + curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" -apt-get update && apt-get install -y runsc +while true; do + if apt-get update; then + apt-get install -y runsc + break + fi + result=$? + # Check if apt update failed to aquire the file lock. + if [[ $result -ne 100 ]]; then + exit $result + fi +done +runsc install +service docker restart + |