diff options
757 files changed, 41745 insertions, 6890 deletions
@@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Build with C++17. +build --cxxopt=-std=c++17 + # Display the current git revision in the info block. build --stamp --workspace_status_command tools/workspace_status.sh @@ -23,7 +23,18 @@ go_path( "//runsc", # Packages that are not dependencies of //runsc. + "//pkg/sentry/kernel/memevent", + "//pkg/tcpip/adapters/gonet", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/muxed", + "//pkg/tcpip/link/sharedmem", + "//pkg/tcpip/link/sharedmem/pipe", + "//pkg/tcpip/link/sharedmem/queue", + "//pkg/tcpip/link/tun", + "//pkg/tcpip/link/waitable", + "//pkg/tcpip/sample/tun_tcp_connect", + "//pkg/tcpip/sample/tun_tcp_echo", + "//pkg/tcpip/transport/tcpconntrack", ], ) diff --git a/Dockerfile b/Dockerfile index 6e9d870db..738623023 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ FROM ubuntu:bionic -RUN apt-get update && apt-get install -y curl gnupg2 git python3 +RUN apt-get update && apt-get install -y curl gnupg2 git python python3 python3-distutils python3-pip RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ - curl https://bazel.build/bazel-release.pub.gpg | apt-key add - + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - RUN apt-get update && apt-get install -y bazel && apt-get clean WORKDIR /gvisor @@ -22,7 +22,7 @@ bazel-server-start: docker-build --privileged \ gvisor-bazel \ sh -c "while :; do sleep 100; done" && \ - docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --gid $(GID) -d $(HOME) gvisor" + docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --non-unique --gid $(GID) -d $(HOME) gvisor" bazel-server: docker exec gvisor-bazel true || \ @@ -48,9 +48,10 @@ Make sure the following dependencies are installed: * Linux 4.14.77+ ([older linux][old-linux]) * [git][git] -* [Bazel][bazel] 0.28.0+ +* [Bazel][bazel] 1.2+ * [Python][python] * [Docker version 17.09.0 or greater][docker] +* C++ toolchain supporting C++17 (GCC 7+, Clang 5+) * Gold linker (e.g. `binutils-gold` package on Ubuntu) ### Building @@ -1,21 +1,22 @@ -# Load go bazel rules and gazelle. load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") +# Load go bazel rules and gazelle. http_archive( name = "io_bazel_rules_go", - sha256 = "078f2a9569fa9ed846e60805fb5fb167d6f6c4ece48e6d409bf5fb2154eaf0d8", + sha256 = "b9aa86ec08a292b97ec4591cf578e020b35f98e12173bbd4a921f84f583aebd9", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.0/rules_go-v0.20.0.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.20.0/rules_go-v0.20.0.tar.gz", + "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.2/rules_go-v0.20.2.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.20.2/rules_go-v0.20.2.tar.gz", ], ) http_archive( name = "bazel_gazelle", - sha256 = "41bff2a0b32b02f20c227d234aa25ef3783998e5453f7eade929704dcff7cd4b", + sha256 = "86c6d481b3f7aedc1d60c1c211c6f76da282ae197c3b3160f54bd3a8f847896f", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.0/bazel-gazelle-v0.19.0.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.0/bazel-gazelle-v0.19.0.tar.gz", + "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.1/bazel-gazelle-v0.19.1.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.1/bazel-gazelle-v0.19.1.tar.gz", ], ) @@ -24,7 +25,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to go_rules_dependencies() go_register_toolchains( - go_version = "1.13.1", + go_version = "1.13.4", nogo = "@//:nogo", ) @@ -58,6 +59,26 @@ load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") protobuf_deps() +# Load python dependencies. +git_repository( + name = "rules_python", + commit = "94677401bc56ed5d756f50b441a6a5c7f735a6d4", + remote = "https://github.com/bazelbuild/rules_python.git", + shallow_since = "1573842889 -0500", +) + +load("@rules_python//python:pip.bzl", "pip_import") + +pip_import( + name = "pydeps", + python_interpreter = "python3", + requirements = "//benchmarks:requirements.txt", +) + +load("@pydeps//:requirements.bzl", "pip_install") + +pip_install() + # Load bazel_toolchain to support Remote Build Execution. # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( @@ -85,6 +106,45 @@ load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") rules_pkg_dependencies() +# Container rules. +http_archive( + name = "io_bazel_rules_docker", + sha256 = "14ac30773fdb393ddec90e158c9ec7ebb3f8a4fd533ec2abbfd8789ad81a284b", + strip_prefix = "rules_docker-0.12.1", + urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.12.1/rules_docker-v0.12.1.tar.gz"], +) + +load( + "@io_bazel_rules_docker//repositories:repositories.bzl", + container_repositories = "repositories", +) + +container_repositories() + +load("@io_bazel_rules_docker//repositories:deps.bzl", container_deps = "deps") + +container_deps() + +load( + "@io_bazel_rules_docker//container:container.bzl", + "container_pull", +) + +# This container is built from the Dockerfile in test/iptables/runner. +container_pull( + name = "iptables-test", + registry = "gcr.io", + repository = "gvisor-presubmit/iptables-test", + digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65", +) + +load( + "@io_bazel_rules_docker//go:image.bzl", + _go_image_repos = "repositories", +) + +_go_image_repos() + # External repositories, in sorted order. go_repository( name = "com_github_cenkalti_backoff", diff --git a/benchmarks/BUILD b/benchmarks/BUILD new file mode 100644 index 000000000..dbadeeaf2 --- /dev/null +++ b/benchmarks/BUILD @@ -0,0 +1,9 @@ +package(licenses = ["notice"]) + +py_binary( + name = "benchmarks", + srcs = ["run.py"], + main = "run.py", + python_version = "PY3", + deps = ["//benchmarks/runner"], +) diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..ad44cd6ac --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,172 @@ +# Benchmark tools + +These scripts are tools for collecting performance data for Docker-based tests. + +## Setup + +The scripts assume the following: + +* You have a local machine with bazel installed. +* You have some machine(s) with docker installed. These machines will be + refered to as the "Environment". +* Environment machines have the runtime(s) under test installed, such that you + can run docker with a command like: `docker run --runtime=$RUNTIME + your/image`. +* You are able to login to machines in the environment with the local machine + via ssh and the user for ssh can run docker commands without using `sudo`. +* The docker daemon on each of your environment machines is listening on + `unix:///var/run/docker.sock` (docker's default). + +For configuring the environment manually, consult the +[dockerd documentation][dockerd]. + +## Environment + +All benchmarks require a user defined yaml file describe the environment. These +files are of the form: + +```yaml +machine1: local +machine2: + hostname: 100.100.100.100 + username: username + key_path: ~/private_keyfile + key_password: passphrase +machine3: + hostname: 100.100.100.101 + username: username + key_path: ~/private_keyfile + key_password: passphrase +``` + +The yaml file defines an environment with three machines named `machine1`, +`machine2` and `machine3`. `machine1` is the local machine, `machine2` and +`machine3` are remote machines. Both `machine2` and `machine3` should be +reachable by `ssh`. For example, the command `ssh -i ~/private_keyfile +username@100.100.100.100` (using the passphrase `passphrase`) should connect to +`machine2`. + +The above is an example only. Machines should be uniform, since they are treated +as such by the tests. Machines must also be accessible to each other via their +default routes. Furthermore, some benchmarks will meaningless if running on the +local machine, such as density. + +For remote machines, `hostname`, `key_path`, and `username` are required and +others are optional. In addition key files must be generated +[using the instrcutions below](#generating-ssh-keys). + +The above yaml file can be checked for correctness with the `validate` command +in the top level perf.py script: + +`bazel run :benchmarks -- validate $PWD/examples/localhost.yaml` + +## Running benchmarks + +To list available benchmarks, use the `list` commmand: + +```bash +bazel run :benchmarks -- list + +... +Benchmark: sysbench.cpu +Metrics: events_per_second + Run sysbench CPU test. Additional arguments can be provided for sysbench. + + :param max_prime: The maximum prime number to search. +``` + +To run benchmarks, use the `run` command. For example, to run the sysbench +benchmark above: + +```bash +bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml sysbench.cpu +``` + +You can run parameterized benchmarks, for example to run with different +runtimes: + +```bash +bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --runtime=runc --runtime=runsc sysbench.cpu +``` + +Or with different parameters: + +```bash +bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --max_prime=10 --max_prime=100 sysbench.cpu +``` + +## Writing benchmarks + +To write new benchmarks, you should familiarize yourself with the structure of +the repository. There are three key components. + +## Harness + +The harness makes use of the [docker py SDK][docker-py]. It is advisable that +you familiarize yourself with that API when making changes, specifically: + +* clients +* containers +* images + +In general, benchmarks need only interact with the `Machine` objects provided to +the benchmark function, which are the machines defined in the environment. These +objects allow the benchmark to define the relationships between different +containers, and parse the output. + +## Workloads + +The harness requires workloads to run. These are all available in the +`workloads` directory. + +In general, a workload consists of a Dockerfile to build it (while these are not +hermetic, in general they should be as fixed and isolated as possible), some +parses for output if required, parser tests and sample data. Provided the test +is named after the workload package and contains a function named `sample`, this +variable will be used to automatically mock workload output when the `--mock` +flag is provided to the main tool. + +## Writing benchmarks + +Benchmarks define the tests themselves. All benchmarks have the following +function signature: + +```python +def my_func(output) -> float: + return float(output) + +@benchmark(metrics = my_func, machines = 1) +def my_benchmark(machine: machine.Machine, arg: str): + return "3.4432" +``` + +Each benchmark takes a variable amount of position arguments as +`harness.Machine` objects and some set of keyword arguments. It is recommended +that you accept arbitrary keyword arguments and pass them through when +constructing the container under test. + +To write a new benchmark, open a module in the `suites` directory and use the +above signature. You should add a descriptive doc string to describe what your +benchmark is and any test centric arguments. + +## Generating SSH Keys + +The scripts only support RSA Keys, and ssh library used in paramiko. Paramiko +only supports RSA keys that look like the following (PEM format): + +```bash +$ cat /path/to/ssh/key + +-----BEGIN RSA PRIVATE KEY----- +...private key text... +-----END RSA PRIVATE KEY----- + +``` + +To generate ssh keys in PEM format, use the [`-t rsa -m PEM -b 4096`][RSA-keys]. +option. + +[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/ +[docker-py]: https://docker-py.readthedocs.io/en/stable/ +[paramiko]: http://docs.paramiko.org/en/2.4/api/client.html +[RSA-keys]: https://serverfault.com/questions/939909/ssh-keygen-does-not-create-rsa-private-key diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl new file mode 100644 index 000000000..79e6cdbc8 --- /dev/null +++ b/benchmarks/defs.bzl @@ -0,0 +1,18 @@ +"""Provides python helper functions.""" + +load("@pydeps//:requirements.bzl", _requirement = "requirement") + +def filter_deps(deps = None): + if deps == None: + deps = [] + return [dep for dep in deps if dep] + +def py_library(deps = None, **kwargs): + return native.py_library(deps = filter_deps(deps), **kwargs) + +def py_test(deps = None, **kwargs): + return native.py_test(deps = filter_deps(deps), **kwargs) + +def requirement(name, direct = True): + """ requirement returns the required dependency. """ + return _requirement(name) diff --git a/benchmarks/examples/localhost.yaml b/benchmarks/examples/localhost.yaml new file mode 100644 index 000000000..f70fe0fb7 --- /dev/null +++ b/benchmarks/examples/localhost.yaml @@ -0,0 +1,2 @@ +client: localhost +server: localhost diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD new file mode 100644 index 000000000..9546220c4 --- /dev/null +++ b/benchmarks/harness/BUILD @@ -0,0 +1,89 @@ +load("//benchmarks:defs.bzl", "py_library", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "harness", + srcs = ["__init__.py"], +) + +py_library( + name = "benchmark_driver", + srcs = ["benchmark_driver.py"], + deps = [ + "//benchmarks/harness/machine_mocks", + "//benchmarks/harness/machine_producers:machine_producer", + "//benchmarks/suites", + ], +) + +py_library( + name = "container", + srcs = ["container.py"], + deps = [ + requirement("asn1crypto", False), + requirement("chardet", False), + requirement("certifi", False), + requirement("docker", True), + requirement("docker-pycreds", False), + requirement("idna", False), + requirement("ptyprocess", False), + requirement("requests", False), + requirement("urllib3", False), + requirement("websocket-client", False), + ], +) + +py_library( + name = "machine", + srcs = ["machine.py"], + deps = [ + "//benchmarks/harness", + "//benchmarks/harness:container", + "//benchmarks/harness:ssh_connection", + "//benchmarks/harness:tunnel_dispatcher", + requirement("asn1crypto", False), + requirement("chardet", False), + requirement("certifi", False), + requirement("docker", True), + requirement("docker-pycreds", False), + requirement("idna", False), + requirement("ptyprocess", False), + requirement("requests", False), + requirement("urllib3", False), + requirement("websocket-client", False), + ], +) + +py_library( + name = "ssh_connection", + srcs = ["ssh_connection.py"], + deps = [ + "//benchmarks/harness", + requirement("bcrypt", False), + requirement("cffi", False), + requirement("paramiko", True), + requirement("cryptography", False), + ], +) + +py_library( + name = "tunnel_dispatcher", + srcs = ["tunnel_dispatcher.py"], + deps = [ + requirement("asn1crypto", False), + requirement("chardet", False), + requirement("certifi", False), + requirement("docker", True), + requirement("docker-pycreds", False), + requirement("idna", False), + requirement("pexpect", True), + requirement("ptyprocess", False), + requirement("requests", False), + requirement("urllib3", False), + requirement("websocket-client", False), + ], +) diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py new file mode 100644 index 000000000..a7f34da9e --- /dev/null +++ b/benchmarks/harness/__init__.py @@ -0,0 +1,25 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core benchmark utilities.""" + +import os + +# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a +# format string that accepts a single string parameter. +LOCAL_WORKLOADS_PATH = os.path.join( + os.path.dirname(__file__), "../workloads/{}") + +# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the +# remote host. This is a format string that accepts a single string parameter. +REMOTE_WORKLOADS_PATH = "workloads/{}" diff --git a/benchmarks/harness/benchmark_driver.py b/benchmarks/harness/benchmark_driver.py new file mode 100644 index 000000000..9abc21b54 --- /dev/null +++ b/benchmarks/harness/benchmark_driver.py @@ -0,0 +1,85 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Main driver for benchmarks.""" + +import copy +import statistics +import threading +import types + +from benchmarks import suites +from benchmarks.harness.machine_producers import machine_producer + + +# pylint: disable=too-many-instance-attributes +class BenchmarkDriver: + """Allocates machines and invokes a benchmark method.""" + + def __init__(self, + producer: machine_producer.MachineProducer, + method: types.FunctionType, + runs: int = 1, + **kwargs): + + self._producer = producer + self._method = method + self._kwargs = copy.deepcopy(kwargs) + self._threads = [] + self.lock = threading.RLock() + self._runs = runs + self._metric_results = {} + + def start(self): + """Starts a benchmark thread.""" + for _ in range(self._runs): + thread = threading.Thread(target=self._run_method) + thread.start() + self._threads.append(thread) + + def join(self): + """Joins the thread.""" + # pylint: disable=expression-not-assigned + [t.join() for t in self._threads] + + def _run_method(self): + """Runs all benchmarks.""" + machines = self._producer.get_machines( + suites.benchmark_machines(self._method)) + try: + result = self._method(*machines, **self._kwargs) + for name, res in result: + with self.lock: + if name in self._metric_results: + self._metric_results[name].append(res) + else: + self._metric_results[name] = [res] + finally: + # Always release. + self._producer.release_machines(machines) + + def median(self): + """Returns the median result, after join is finished.""" + for key, value in self._metric_results.items(): + yield key, [statistics.median(value)] + + def all(self): + """Returns all results.""" + for key, value in self._metric_results.items(): + yield key, value + + def meanstd(self): + """Returns all results.""" + for key, value in self._metric_results.items(): + mean = statistics.mean(value) + yield key, [mean, statistics.stdev(value, xbar=mean)] diff --git a/benchmarks/harness/container.py b/benchmarks/harness/container.py new file mode 100644 index 000000000..585436e20 --- /dev/null +++ b/benchmarks/harness/container.py @@ -0,0 +1,181 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Container definitions.""" + +import contextlib +import logging +import pydoc +import types +from typing import Tuple + +import docker +import docker.errors + +from benchmarks import workloads + + +class Container: + """Abstract container. + + Must be a context manager. + + Usage: + + with Container(client, image, ...): + ... + """ + + def run(self, **env) -> str: + """Run the container synchronously.""" + raise NotImplementedError + + def detach(self, **env): + """Run the container asynchronously.""" + raise NotImplementedError + + def address(self) -> Tuple[str, int]: + """Return the bound address for the container.""" + raise NotImplementedError + + def get_names(self) -> types.GeneratorType: + """Return names of all containers.""" + raise NotImplementedError + + +# pylint: disable=too-many-instance-attributes +class DockerContainer(Container): + """Class that handles creating a docker container.""" + + # pylint: disable=too-many-arguments + def __init__(self, + client: docker.DockerClient, + host: str, + image: str, + count: int = 1, + runtime: str = "runc", + port: int = 0, + **kwargs): + """Trys to setup "count" containers. + + Args: + client: A docker client from dockerpy. + host: The host address the image is running on. + image: The name of the image to run. + count: The number of containers to setup. + runtime: The container runtime to use. + port: The port to reserve. + **kwargs: Additional container options. + """ + assert count >= 1 + assert port == 0 or count == 1 + self._client = client + self._host = host + self._containers = [] + self._count = count + self._image = image + self._runtime = runtime + self._port = port + self._kwargs = kwargs + if port != 0: + self._ports = {"%d/tcp" % port: None} + else: + self._ports = {} + + @contextlib.contextmanager + def detach(self, **env): + env = ["%s=%s" % (key, value) for (key, value) in env.items()] + # Start all containers. + for _ in range(self._count): + try: + # Start the container in a detached mode. + container = self._client.containers.run( + self._image, + detach=True, + remove=True, + runtime=self._runtime, + ports=self._ports, + environment=env, + **self._kwargs) + logging.info("Started detached container %s -> %s", self._image, + container.attrs["Id"]) + self._containers.append(container) + except Exception as exc: + self._clean_containers() + raise exc + try: + # Wait for all containers to be up. + for container in self._containers: + while not container.attrs["State"]["Running"]: + container = self._client.containers.get(container.attrs["Id"]) + yield self + finally: + self._clean_containers() + + def address(self) -> Tuple[str, int]: + assert self._count == 1 + assert self._port != 0 + container = self._client.containers.get(self._containers[0].attrs["Id"]) + port = container.attrs["NetworkSettings"]["Ports"][ + "%d/tcp" % self._port][0]["HostPort"] + return (self._host, port) + + def get_names(self) -> types.GeneratorType: + for container in self._containers: + yield container.name + + def run(self, **env) -> str: + env = ["%s=%s" % (key, value) for (key, value) in env.items()] + return self._client.containers.run( + self._image, + runtime=self._runtime, + ports=self._ports, + remove=True, + environment=env, + **self._kwargs).decode("utf-8") + + def _clean_containers(self): + """Kills all containers.""" + for container in self._containers: + try: + container.kill() + except docker.errors.NotFound: + pass + + +class MockContainer(Container): + """Mock of Container.""" + + def __init__(self, workload: str): + self._workload = workload + + def __enter__(self): + return self + + def run(self, **env): + # Lookup sample data if any exists for the workload module. We use a + # well-defined test locate and a well-defined sample function. + mod = pydoc.locate(workloads.__name__ + "." + self._workload) + if hasattr(mod, "sample"): + return mod.sample(**env) + return "" # No output. + + def address(self) -> Tuple[str, int]: + return ("example.com", 80) + + def get_names(self) -> types.GeneratorType: + yield "mock" + + @contextlib.contextmanager + def detach(self, **env): + yield self diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py new file mode 100644 index 000000000..66b719b63 --- /dev/null +++ b/benchmarks/harness/machine.py @@ -0,0 +1,224 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Machine abstraction passed to benchmarks to run docker containers. + +Abstraction for interacting with test machines. Machines are produced +by Machine producers and represent a local or remote machine. Benchmark +methods in /benchmarks/suite are passed the required number of machines in order +to run the benchmark. Machines contain methods to run commands via bash, +possibly over ssh. Machines also hold a connection to the docker UNIX socket +to run contianers. + + Typical usage example: + + machine = Machine() + machine.run(cmd) + machine.pull(path) + container = machine.container() +""" + +import logging +import re +import subprocess +import time +from typing import Tuple + +import docker + +from benchmarks import harness +from benchmarks.harness import container +from benchmarks.harness import machine_mocks +from benchmarks.harness import ssh_connection +from benchmarks.harness import tunnel_dispatcher + + +class Machine(object): + """The machine object is the primary object for benchmarks. + + Machine objects are passed to each metric function call and benchmarks use + machines to access real connections to those machines. + + Attributes: + _name: Name as a string + """ + _name = "" + + def run(self, cmd: str) -> Tuple[str, str]: + """Convenience method for running a bash command on a machine object. + + Some machines may point to the local machine, and thus, do not have ssh + connections. Run runs a command either local or over ssh and returns the + output stdout and stderr as strings. + + Args: + cmd: The command to run as a string. + + Returns: + The command output. + """ + raise NotImplementedError + + def read(self, path: str) -> str: + """Reads the contents of some file. + + This will be mocked. + + Args: + path: The path to the file to be read. + + Returns: + The file contents. + """ + raise NotImplementedError + + def pull(self, workload: str) -> str: + """Send the given workload to the machine, build and tag it. + + All images must be defined by the workloads directory. + + Args: + workload: The workload name. + + Returns: + The workload tag. + """ + raise NotImplementedError + + def container(self, image: str, **kwargs) -> container.Container: + """Returns a container object. + + Args: + image: The pulled image tag. + **kwargs: Additional container options. + + Returns: + :return: a container.Container object. + """ + raise NotImplementedError + + def sleep(self, amount: float): + """Sleeps the given amount of time.""" + time.sleep(amount) + + def __str__(self): + return self._name + + +class MockMachine(Machine): + """A mocked machine.""" + _name = "mock" + + def run(self, cmd: str) -> Tuple[str, str]: + return "", "" + + def read(self, path: str) -> str: + return machine_mocks.Readfile(path) + + def pull(self, workload: str) -> str: + return workload # Workload is the tag. + + def container(self, image: str, **kwargs) -> container.Container: + return container.MockContainer(image) + + def sleep(self, amount: float): + pass + + +def get_address(machine: Machine) -> str: + """Return a machine's default address.""" + default_route, _ = machine.run("ip route get 8.8.8.8") + return re.search(" src ([0-9.]+) ", default_route).group(1) + + +class LocalMachine(Machine): + """The local machine. + + Attributes: + _name: Name as a string + _docker_client: a pythonic connection to to the local dockerd unix socket. + See: https://github.com/docker/docker-py + """ + + def __init__(self, name): + self._name = name + self._docker_client = docker.from_env() + + def run(self, cmd: str) -> Tuple[str, str]: + process = subprocess.Popen( + cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + return stdout.decode("utf-8"), stderr.decode("utf-8") + + def read(self, path: str) -> str: + # Read the exact path locally. + return open(path, "r").read() + + def pull(self, workload: str) -> str: + # Run the docker build command locally. + logging.info("Building %s@%s locally...", workload, self._name) + self.run("docker build --tag={} {}".format( + workload, harness.LOCAL_WORKLOADS_PATH.format(workload))) + return workload # Workload is the tag. + + def container(self, image: str, **kwargs) -> container.Container: + # Return a local docker container directly. + return container.DockerContainer(self._docker_client, get_address(self), + image, **kwargs) + + def sleep(self, amount: float): + time.sleep(amount) + + +class RemoteMachine(Machine): + """Remote machine accessible via an SSH connection. + + Attributes: + _name: Name as a string + _ssh_connection: a paramiko backed ssh connection which can be used to run + commands on this machine + _tunnel: a python wrapper around a port forwarded ssh connection between a + local unix socket and the remote machine's dockerd unix socket. + _docker_client: a pythonic wrapper backed by the _tunnel. Allows sending + docker commands: see https://github.com/docker/docker-py + """ + + def __init__(self, name, **kwargs): + self._name = name + self._ssh_connection = ssh_connection.SSHConnection(name, **kwargs) + self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs) + self._tunnel.connect() + self._docker_client = self._tunnel.get_docker_client() + + def run(self, cmd: str) -> Tuple[str, str]: + return self._ssh_connection.run(cmd) + + def read(self, path: str) -> str: + # Just cat remotely. + stdout, stderr = self._ssh_connection.run("cat '{}'".format(path)) + return stdout + stderr + + def pull(self, workload: str) -> str: + # Push to the remote machine and build. + logging.info("Building %s@%s remotely...", workload, self._name) + remote_path = self._ssh_connection.send_workload(workload) + self.run("docker build --tag={} {}".format(workload, remote_path)) + return workload # Workload is the tag. + + def container(self, image: str, **kwargs) -> container.Container: + # Return a remote docker container. + return container.DockerContainer(self._docker_client, get_address(self), + image, **kwargs) + + def sleep(self, amount: float): + time.sleep(amount) diff --git a/benchmarks/harness/machine_mocks/BUILD b/benchmarks/harness/machine_mocks/BUILD new file mode 100644 index 000000000..c8ec4bc79 --- /dev/null +++ b/benchmarks/harness/machine_mocks/BUILD @@ -0,0 +1,9 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "machine_mocks", + srcs = ["__init__.py"], +) diff --git a/benchmarks/harness/machine_mocks/__init__.py b/benchmarks/harness/machine_mocks/__init__.py new file mode 100644 index 000000000..00f0085d7 --- /dev/null +++ b/benchmarks/harness/machine_mocks/__init__.py @@ -0,0 +1,81 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Machine mock files.""" + +MEMINFO = """\ +MemTotal: 7652344 kB +MemFree: 7174724 kB +MemAvailable: 7152008 kB +Buffers: 7544 kB +Cached: 178856 kB +SwapCached: 0 kB +Active: 270928 kB +Inactive: 68436 kB +Active(anon): 153124 kB +Inactive(anon): 880 kB +Active(file): 117804 kB +Inactive(file): 67556 kB +Unevictable: 0 kB +Mlocked: 0 kB +SwapTotal: 0 kB +SwapFree: 0 kB +Dirty: 900 kB +Writeback: 0 kB +AnonPages: 153000 kB +Mapped: 129120 kB +Shmem: 1044 kB +Slab: 60864 kB +SReclaimable: 22792 kB +SUnreclaim: 38072 kB +KernelStack: 2672 kB +PageTables: 5756 kB +NFS_Unstable: 0 kB +Bounce: 0 kB +WritebackTmp: 0 kB +CommitLimit: 3826172 kB +Committed_AS: 663836 kB +VmallocTotal: 34359738367 kB +VmallocUsed: 0 kB +VmallocChunk: 0 kB +HardwareCorrupted: 0 kB +AnonHugePages: 0 kB +ShmemHugePages: 0 kB +ShmemPmdMapped: 0 kB +CmaTotal: 0 kB +CmaFree: 0 kB +HugePages_Total: 0 +HugePages_Free: 0 +HugePages_Rsvd: 0 +HugePages_Surp: 0 +Hugepagesize: 2048 kB +DirectMap4k: 94196 kB +DirectMap2M: 4624384 kB +DirectMap1G: 3145728 kB +""" + +CONTENTS = { + "/proc/meminfo": MEMINFO, +} + + +def Readfile(path: str) -> str: + """Reads a mock file. + + Args: + path: The target path. + + Returns: + Mocked file contents or None. + """ + return CONTENTS.get(path, None) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD new file mode 100644 index 000000000..a48da02a1 --- /dev/null +++ b/benchmarks/harness/machine_producers/BUILD @@ -0,0 +1,40 @@ +load("//benchmarks:defs.bzl", "py_library", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "harness", + srcs = ["__init__.py"], +) + +py_library( + name = "machine_producer", + srcs = ["machine_producer.py"], +) + +py_library( + name = "mock_producer", + srcs = ["mock_producer.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/harness/machine_producers:machine_producer", + ], +) + +py_library( + name = "yaml_producer", + srcs = ["yaml_producer.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/harness/machine_producers:machine_producer", + requirement("PyYAML", False), + ], +) + +py_library( + name = "gcloud_mock_recorder", + srcs = ["gcloud_mock_recorder.py"], +) diff --git a/benchmarks/harness/machine_producers/__init__.py b/benchmarks/harness/machine_producers/__init__.py new file mode 100644 index 000000000..634ef4843 --- /dev/null +++ b/benchmarks/harness/machine_producers/__init__.py @@ -0,0 +1,13 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/benchmarks/harness/machine_producers/gcloud_mock_recorder.py b/benchmarks/harness/machine_producers/gcloud_mock_recorder.py new file mode 100644 index 000000000..fd9837a37 --- /dev/null +++ b/benchmarks/harness/machine_producers/gcloud_mock_recorder.py @@ -0,0 +1,97 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A recorder and replay for testing the GCloudProducer. + +MockPrinter and MockReader handle printing and reading mock data for the +purposes of testing. MockPrinter is passed to GCloudProducer objects. The user +can then run scenarios and record them for playback in tests later. + +MockReader is passed to MockGcloudProducer objects and handles reading the +previously recorded mock data. + +It is left to the user to check if data printed is properly redacted for their +own use. The intended usecase for this class is data coming from gcloud +commands, which will contain public IPs and other instance data. + +The data format is json and printed/read from the ./test_data directory. The +data is the output of subprocess.CompletedProcess objects in json format. + + Typical usage example: + + recorder = MockPrinter() + producer = GCloudProducer(args, recorder) + machines = producer.get_machines(1) + with open("my_file.json") as fd: + recorder.write_out(fd) + + reader = MockReader(filename) + producer = MockGcloudProducer(args, mock) + machines = producer.get_machines(1) + assert len(machines) == 1 +""" + +import io +import json +import subprocess + + +class MockPrinter(object): + """Handles printing Mock data for MockGcloudProducer. + + Attributes: + _records: list of json object records for printing + """ + + def __init__(self): + self._records = [] + + def record(self, entry: subprocess.CompletedProcess): + """Records data and strips out ip addresses.""" + + record = { + "args": entry.args, + "stdout": entry.stdout.decode("utf-8"), + "returncode": str(entry.returncode) + } + self._records.append(record) + + def write_out(self, fd: io.FileIO): + """Prints out the data into the given filepath.""" + fd.write(json.dumps(self._records, indent=4)) + + +class MockReader(object): + """Handles reading Mock data for MockGcloudProducer. + + Attributes: + _records: List[json] records read from the passed in file. + """ + + def __init__(self, filepath: str): + with open(filepath, "rb") as file: + self._records = json.loads(file.read()) + self._i = 0 + + def __iter__(self): + return self + + def __next__(self, args) -> subprocess.CompletedProcess: + """Returns the next record as a CompletedProcess.""" + if self._i < len(self._records): + record = self._records[self._i] + stdout = record["stdout"].encode("ascii") + returncode = int(record["returncode"]) + return subprocess.CompletedProcess( + args=args, returncode=returncode, stdout=stdout, stderr=b"") + raise StopIteration() diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py new file mode 100644 index 000000000..124ee14cc --- /dev/null +++ b/benchmarks/harness/machine_producers/machine_producer.py @@ -0,0 +1,30 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract types.""" + +from typing import List + +from benchmarks.harness import machine + + +class MachineProducer: + """Abstract Machine producer.""" + + def get_machines(self, num_machines: int) -> List[machine.Machine]: + """Returns the requested number of machines.""" + raise NotImplementedError + + def release_machines(self, machine_list: List[machine.Machine]): + """Releases the given set of machines.""" + raise NotImplementedError diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py new file mode 100644 index 000000000..4f29ad53f --- /dev/null +++ b/benchmarks/harness/machine_producers/mock_producer.py @@ -0,0 +1,31 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Producers of mocks.""" + +from typing import List + +from benchmarks.harness import machine +from benchmarks.harness.machine_producers import machine_producer + + +class MockMachineProducer(machine_producer.MachineProducer): + """Produces MockMachine objects.""" + + def get_machines(self, num_machines: int) -> List[machine.MockMachine]: + """Returns the request number of MockMachines.""" + return [machine.MockMachine() for i in range(num_machines)] + + def release_machines(self, machine_list: List[machine.MockMachine]): + """No-op.""" + return diff --git a/benchmarks/harness/machine_producers/yaml_producer.py b/benchmarks/harness/machine_producers/yaml_producer.py new file mode 100644 index 000000000..5d334e480 --- /dev/null +++ b/benchmarks/harness/machine_producers/yaml_producer.py @@ -0,0 +1,106 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Producers based on yaml files.""" + +import os +import threading +from typing import Dict +from typing import List + +import yaml + +from benchmarks.harness import machine +from benchmarks.harness.machine_producers import machine_producer + + +class YamlMachineProducer(machine_producer.MachineProducer): + """Loads machines from a yaml file.""" + + def __init__(self, path: str): + self.machines = build_machines(path) + self.max_machines = len(self.machines) + self.machine_condition = threading.Condition() + + def get_machines(self, num_machines: int) -> List[machine.Machine]: + if num_machines > self.max_machines: + raise ValueError( + "Insufficient Ammount of Machines. {ask} asked for and have {max_num} max." + .format(ask=num_machines, max_num=self.max_machines)) + + with self.machine_condition: + while not self._enough_machines(num_machines): + self.machine_condition.wait(timeout=1) + return [self.machines.pop(0) for _ in range(num_machines)] + + def release_machines(self, machine_list: List[machine.Machine]): + with self.machine_condition: + while machine_list: + next_machine = machine_list.pop() + self.machines.append(next_machine) + self.machine_condition.notify() + + def _enough_machines(self, ask: int): + return ask <= len(self.machines) + + +def build_machines(path: str, num_machines: str = -1) -> List[machine.Machine]: + """Builds machine objects defined by the yaml file "path". + + Args: + path: The path to a yaml file which defines machines. + num_machines: Optional limit on how many machine objects to build. + + Returns: + Machine objects in a list. + + If num_machines is set, len(machines) <= num_machines. + """ + data = parse_yaml(path) + machines = [] + for key, value in data.items(): + if len(machines) == num_machines: + return machines + if isinstance(value, dict): + machines.append(machine.RemoteMachine(key, **value)) + else: + machines.append(machine.LocalMachine(key)) + return machines + + +def parse_yaml(path: str) -> Dict[str, Dict[str, str]]: + """Parse the yaml file pointed by path. + + Args: + path: The path to yaml file. + + Returns: + The contents of the yaml file as a dictionary. + """ + data = get_file_contents(path) + return yaml.load(data, Loader=yaml.Loader) + + +def get_file_contents(path: str) -> str: + """Dumps the file contents to a string and returns them. + + Args: + path: The path to dump. + + Returns: + The file contents as a string. + """ + if not os.path.isabs(path): + path = os.path.abspath(path) + with open(path) as input_file: + return input_file.read() diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py new file mode 100644 index 000000000..fcbfbcdb2 --- /dev/null +++ b/benchmarks/harness/ssh_connection.py @@ -0,0 +1,111 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SSHConnection handles the details of SSH connections.""" + +import os +import warnings + +import paramiko + +from benchmarks import harness + +# Get rid of paramiko Cryptography Warnings. +warnings.filterwarnings(action="ignore", module=".*paramiko.*") + + +def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str): + """Sends a single file via an SSH client. + + Args: + client: The existing SSH client. + path: The local path. + remote_dir: The remote directory. + """ + filename = path.split("/").pop() + client.exec_command("mkdir -p " + remote_dir) + with client.open_sftp() as ftp_client: + ftp_client.put(path, os.path.join(remote_dir, filename)) + + +class SSHConnection: + """SSH connection to a remote machine.""" + + def __init__(self, name: str, hostname: str, key_path: str, username: str, + **kwargs): + """Sets up a paramiko ssh connection to the given hostname.""" + self._name = name # Unused. + self._hostname = hostname + self._username = username + self._key_path = key_path # RSA Key path + self._kwargs = kwargs + # SSHConnection wraps paramiko. paramiko supports RSA, ECDSA, and Ed25519 + # keys, and we've chosen to only suport and require RSA keys. paramiko + # supports RSA keys that begin with '----BEGIN RSAKEY----'. + # https://stackoverflow.com/questions/53600581/ssh-key-generated-by-ssh-keygen-is-not-recognized-by-paramiko + self.rsa_key = self._rsa() + self.run("true") # Validate. + + def _client(self) -> paramiko.SSHClient: + """Returns a connected SSH client.""" + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + hostname=self._hostname, + port=22, + username=self._username, + pkey=self.rsa_key, + allow_agent=False, + look_for_keys=False) + return client + + def _rsa(self): + if "key_password" in self._kwargs: + password = self._kwargs["key_password"] + else: + password = None + rsa = paramiko.RSAKey.from_private_key_file(self._key_path, password) + return rsa + + def run(self, cmd: str) -> (str, str): + """Runs a command via ssh. + + Args: + cmd: The shell command to run. + + Returns: + The contents of stdout and stderr. + """ + with self._client() as client: + _, stdout, stderr = client.exec_command(command=cmd) + stdout.channel.recv_exit_status() + stdout = stdout.read().decode("utf-8") + stderr = stderr.read().decode("utf-8") + return stdout, stderr + + def send_workload(self, name: str) -> str: + """Sends a workload to the remote machine. + + Args: + name: The workload name. + + Returns: + The remote path. + """ + with self._client() as client: + for dirpath, _, filenames in os.walk( + harness.LOCAL_WORKLOADS_PATH.format(name)): + for filename in filenames: + send_one_file(client, os.path.join(dirpath, filename), + harness.REMOTE_WORKLOADS_PATH.format(name)) + return harness.REMOTE_WORKLOADS_PATH.format(name) diff --git a/benchmarks/harness/tunnel_dispatcher.py b/benchmarks/harness/tunnel_dispatcher.py new file mode 100644 index 000000000..c56fd022a --- /dev/null +++ b/benchmarks/harness/tunnel_dispatcher.py @@ -0,0 +1,122 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tunnel handles setting up connections to remote machines. + +Tunnel dispatcher is a wrapper around the connection from a local UNIX socket +and a remote UNIX socket via SSH with port forwarding. This is done to +initialize the pythonic dockerpy client to run containers on the remote host by +connecting to /var/run/docker.sock (where Docker is listening). Tunnel +dispatcher sets up the local UNIX socket and calls the `ssh` command as a +subprocess, and holds a reference to that subprocess. It manages clean-up on +exit as best it can by killing the ssh subprocess and deleting the local UNIX +socket,stored in /tmp for easy cleanup in most systems if this fails. + + Typical usage example: + + t = Tunnel(name, **kwargs) + t.connect() + client = t.get_docker_client() # + client.containers.run("ubuntu", "echo hello world") + +""" + +import os +import tempfile +import time + +import docker +import pexpect + +SSH_TUNNEL_COMMAND = """ssh + -o GlobalKnownHostsFile=/dev/null + -o UserKnownHostsFile=/dev/null + -o StrictHostKeyChecking=no + -o IdentitiesOnly=yes + -nNT -L {filename}:/var/run/docker.sock + -i {key_path} + {username}@{hostname}""" + + +class Tunnel(object): + """The tunnel object represents the tunnel via ssh. + + This connects a local unix domain socket with a remote socket. + + Attributes: + _filename: a temporary name of the UNIX socket prefixed by the name + argument. + _hostname: the IP or resolvable hostname of the remote host. + _username: the username of the ssh_key used to run ssh. + _key_path: path to a valid key. + _key_password: optional password to the ssh key in _key_path + _process: holds reference to the ssh subprocess created. + + Returns: + The new minimum port. + + Raises: + ConnectionError: If no available port is found. + """ + + def __init__(self, + name: str, + hostname: str, + username: str, + key_path: str, + key_password: str = "", + **kwargs): + self._filename = tempfile.NamedTemporaryFile(prefix=name).name + self._hostname = hostname + self._username = username + self._key_path = key_path + self._key_password = key_password + self._kwargs = kwargs + self._process = None + + def connect(self): + """Connects the SSH tunnel and stores the subprocess reference in _process.""" + cmd = SSH_TUNNEL_COMMAND.format( + filename=self._filename, + key_path=self._key_path, + username=self._username, + hostname=self._hostname) + self._process = pexpect.spawn(cmd, timeout=10) + + # If given a password, assume we'll be asked for it. + if self._key_password: + self._process.expect(["Enter passphrase for key .*: "]) + self._process.sendline(self._key_password) + + while True: + # Wait for the tunnel to appear. + if self._process.exitstatus is not None: + raise ConnectionError("Error in setting up ssh tunnel") + if os.path.exists(self._filename): + return + time.sleep(0.1) + + def path(self): + """Return the socket file.""" + return self._filename + + def get_docker_client(self): + """Returns a docker client for this Tunnel.""" + return docker.DockerClient(base_url="unix:/" + self._filename) + + def __del__(self): + """Closes the ssh connection process and deletes the socket file.""" + if self._process: + self._process.close() + if os.path.exists(self._filename): + os.remove(self._filename) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000..577eb1a2e --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,32 @@ +asn1crypto==1.2.0 +atomicwrites==1.3.0 +attrs==19.3.0 +bcrypt==3.1.7 +certifi==2019.9.11 +cffi==1.13.2 +chardet==3.0.4 +Click==7.0 +cryptography==2.8 +docker==3.7.0 +docker-pycreds==0.4.0 +idna==2.8 +importlib-metadata==0.23 +more-itertools==7.2.0 +packaging==19.2 +paramiko==2.6.0 +pathlib2==2.3.5 +pexpect==4.7.0 +pluggy==0.9.0 +ptyprocess==0.6.0 +py==1.8.0 +pycparser==2.19 +PyNaCl==1.3.0 +pyparsing==2.4.5 +pytest==4.3.0 +PyYAML==5.1.2 +requests==2.22.0 +six==1.13.0 +urllib3==1.25.7 +wcwidth==0.1.7 +websocket-client==0.56.0 +zipp==0.6.0 diff --git a/benchmarks/run.py b/benchmarks/run.py new file mode 100644 index 000000000..a22eb8641 --- /dev/null +++ b/benchmarks/run.py @@ -0,0 +1,19 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Benchmark runner.""" + +from benchmarks import runner + +if __name__ == "__main__": + runner.runner() diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD new file mode 100644 index 000000000..de24824cc --- /dev/null +++ b/benchmarks/runner/BUILD @@ -0,0 +1,53 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package(licenses = ["notice"]) + +py_library( + name = "runner", + srcs = ["__init__.py"], + data = [ + "//benchmarks/workloads:files", + ], + visibility = ["//benchmarks:__pkg__"], + deps = [ + "//benchmarks/harness:benchmark_driver", + "//benchmarks/harness/machine_producers:mock_producer", + "//benchmarks/harness/machine_producers:yaml_producer", + "//benchmarks/suites", + "//benchmarks/suites:absl", + "//benchmarks/suites:density", + "//benchmarks/suites:fio", + "//benchmarks/suites:helpers", + "//benchmarks/suites:http", + "//benchmarks/suites:media", + "//benchmarks/suites:ml", + "//benchmarks/suites:network", + "//benchmarks/suites:redis", + "//benchmarks/suites:startup", + "//benchmarks/suites:sysbench", + "//benchmarks/suites:syscall", + requirement("click", True), + ], +) + +py_test( + name = "runner_test", + srcs = ["runner_test.py"], + python_version = "PY3", + tags = [ + "local", + "manual", + ], + deps = [ + ":runner", + requirement("click", True), + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py new file mode 100644 index 000000000..9bf9cfd65 --- /dev/null +++ b/benchmarks/runner/__init__.py @@ -0,0 +1,301 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""High-level benchmark utility.""" + +import copy +import csv +import logging +import pkgutil +import pydoc +import re +import sys +import types +from typing import List +from typing import Tuple + +import click + +from benchmarks import suites +from benchmarks.harness import benchmark_driver +from benchmarks.harness.machine_producers import mock_producer +from benchmarks.harness.machine_producers import yaml_producer + + +@click.group() +@click.option( + "--verbose/--no-verbose", default=False, help="Enable verbose logging.") +@click.option("--debug/--no-debug", default=False, help="Enable debug logging.") +def runner(verbose: bool = False, debug: bool = False): + """Run distributed benchmarks. + + See the run and list commands for details. + + Args: + verbose: Enable verbose logging. + debug: Enable debug logging (supercedes verbose). + """ + if debug: + logging.basicConfig(level=logging.DEBUG) + elif verbose: + logging.basicConfig(level=logging.INFO) + + +def find_benchmarks( + regex: str) -> List[Tuple[str, types.ModuleType, types.FunctionType]]: + """Finds all available benchmarks. + + Args: + regex: A regular expression to match. + + Returns: + A (short_name, module, function) tuple for each match. + """ + pkgs = pkgutil.walk_packages(suites.__path__, suites.__name__ + ".") + found = [] + for _, name, _ in pkgs: + mod = pydoc.locate(name) + funcs = [ + getattr(mod, x) + for x in dir(mod) + if suites.is_benchmark(getattr(mod, x)) + ] + for func in funcs: + # Use the short_name with the benchmarks. prefix stripped. + prefix_len = len(suites.__name__ + ".") + short_name = mod.__name__[prefix_len:] + "." + func.__name__ + # Add to the list if a pattern is provided. + if re.compile(regex).match(short_name): + found.append((short_name, mod, func)) + return found + + +@runner.command("list") +@click.argument("method", nargs=-1) +def list_all(method): + """Lists available benchmarks.""" + if not method: + method = ".*" + else: + method = "(" + ",".join(method) + ")" + for (short_name, _, func) in find_benchmarks(method): + print("Benchmark %s:" % short_name) + metrics = suites.benchmark_metrics(func) + if func.__doc__: + print(" " + func.__doc__.lstrip().rstrip()) + if metrics: + print("\n Metrics:") + for metric in metrics: + print("\t{name}: {doc}".format(name=metric[0], doc=metric[1])) + print("\n") + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-branches +# pylint: disable=too-many-locals +@runner.command( + context_settings=dict(ignore_unknown_options=True, allow_extra_args=True)) +@click.pass_context +@click.argument("method") +@click.option("--mock/--no-mock", default=False, help="Mock the machines.") +@click.option("--env", default=None, help="Specify a yaml file with machines.") +@click.option( + "--runtime", default=["runc"], help="The runtime to use.", multiple=True) +@click.option("--metric", help="The metric to extract.", multiple=True) +@click.option( + "--runs", default=1, help="The number of times to run each benchmark.") +@click.option( + "--stat", + default="median", + help="How to aggregate the data from all runs." + "\nmedian - returns the median of all runs (default)" + "\nall - returns all results comma separated" + "\nmeanstd - returns result as mean,std") +# pylint: disable=too-many-statements +def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str], + metric: List[str], stat: str, **kwargs): + """Runs arbitrary benchmarks. + + All unknown command line flags are passed through to the underlying benchmark + method. Flags may be specified multiple times, in which case it is considered + a "dimension" for the test, and a comma-separated table will be emitted + instead of a single result. + + See the output of list to see available metrics for any given benchmark + method. The method parameter is a regular expression that will match against + available benchmarks. If multiple benchmarks match, then that is considered a + distinct "dimension" for the test. + + All benchmarks are run in parallel where possible, but have exclusive + ownership over the individual machines. + + Exactly one of the --mock and --env flag must be specified. + + Every benchmark method will be run the times indicated by --runs. + + Args: + ctx: Click context. + method: A regular expression for methods to be run. + runs: Number of runs. + env: Environment to use. + mock: If true, use mocked environment (supercedes env). + runtime: A list of runtimes to test. + metric: A list of metrics to extract. + stat: The class of statistics to extract. + **kwargs: Dimensions to test. + """ + # First, calculate additional arguments. + # + # This essentially calculates any arguments that appear multiple times, and + # moves those to the "dimensions" dictionary, which maps to lists. These + # dimensions are then iterated over to generate the relevant csv output. + dimensions = {} + + if stat not in ["median", "all", "meanstd"]: + raise ValueError("Illegal value for --result, see help.") + + def squish(key: str, value: str): + """Collapse an argument into kwargs or dimensions.""" + if key in dimensions: + # Extend an existing dimension. + dimensions[key].append(value) + elif key in kwargs: + # Create a new dimension. + dimensions[key] = [kwargs[key], value] + del kwargs[key] + else: + # A single value. + kwargs[key] = value + + for item in ctx.args: + if "=" in method: + # This must be the method. The method is simply set to the first + # non-matching argument, which we're also parsing here. + item, method = method, item + if "=" not in item: + logging.error("illegal argument: %s", item) + sys.exit(1) + (key, value) = item.lstrip("-").split("=", 1) + squish(key, value) + + # Convert runtime and metric to dimensions. + # + # They exist only in the arguments above for documentation purposes. + # Essentially here we are treating them like anything else. Note however, + # that an empty set here will result in a dimension. This is important for + # metrics, where an empty set actually means all metrics. + def fold(key: str, value, allow_flatten=False): + """Collapse a list value into kwargs or dimensions.""" + if len(value) == 1 and allow_flatten: + kwargs[key] = value[0] + else: + dimensions[key] = value + + fold("runtime", runtime, allow_flatten=True) + fold("metric", metric) + + # Lookup the methods. + # + # We match the method parameter to a regular expression. This allows you to + # do things like `run --mock .*` for a broad test. Note that we track the + # short_names in the dimensions here, and look up again in the recursion. + methods = { + short_name: func for (short_name, _, func) in find_benchmarks(method) + } + if not methods: + # Must match at least one method. + logging.error("no matching benchmarks for %s: try list.", method) + sys.exit(1) + fold("method", list(methods.keys()), allow_flatten=True) + + # Construct the environment. + if mock and env: + # You can't provide both. + logging.error("both --mock and --env are set: which one is it?") + sys.exit(1) + elif mock: + producer = mock_producer.MockMachineProducer() + elif env: + producer = yaml_producer.YamlMachineProducer(env) + else: + # You must provide one of mock or env. + logging.error("no enviroment provided: use --mock or --env.") + sys.exit(1) + + # Spin up the drivers. + # + # We ensure that metric is the last entry, because we have special behavior. + # They actually run the test once and the benchmark is a generator that + # produces all viable metrics. + dimension_keys = list(dimensions.keys()) + if "metric" in dimension_keys: + dimension_keys.remove("metric") + dimension_keys.append("metric") + drivers = [] + + def _start(keywords, finished, left): + """Runs a test across dimensions recursively.""" + # Resolve the method fully, it starts as a string. + if "method" in keywords and isinstance(keywords["method"], str): + keywords["method"] = methods[keywords["method"]] + # Is this a non-recursive case? + if not left: + driver = benchmark_driver.BenchmarkDriver(producer, runs=runs, **keywords) + driver.start() + drivers.append((finished, driver)) + else: + # Recurse on the next dimension. + current, left = left[0], left[1:] + keywords = copy.deepcopy(keywords) + if current == "metric": + # We use a generator, popped below. Note that metric is + # guaranteed to be the last element here, and we will provide + # the value for 'done' below when generating the csv. + keywords[current] = dimensions[current] + _start(keywords, finished, left) + else: + # Generate manually. + for value in dimensions[current]: + keywords[current] = value + _start(keywords, finished + [value], left) + + # Start all the drivers, recursively. + _start(kwargs, [], dimension_keys) + + # Finish all tests, write results. + output = csv.writer(sys.stdout) + output.writerow(dimension_keys + ["result"]) + for (done, driver) in drivers: + driver.join() + for (metric_name, result) in getattr(driver, stat)(): + output.writerow([ # Collapse the method name. + hasattr(x, "__name__") and x.__name__ or x for x in done + ] + [metric_name] + result) + + +@runner.command() +@click.argument("env") +@click.option( + "--cmd", default="uname -a", help="command to run on all found machines") +@click.option( + "--workload", default="true", help="workload to run all found machines") +def validate(env, cmd, workload): + """Validates an environment described by yaml file.""" + producer = yaml_producer.YamlMachineProducer(env) + for machine in producer.machines: + print("Machine %s:" % machine) + stdout, _ = machine.run(cmd) + print(" Output of '%s': %s" % (cmd, stdout.lstrip().rstrip())) + image = machine.pull(workload) + stdout = machine.container(image).run() + print(" Container %s: %s" % (workload, stdout.lstrip().rstrip())) diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py new file mode 100644 index 000000000..5719c2838 --- /dev/null +++ b/benchmarks/runner/runner_test.py @@ -0,0 +1,59 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Top-level tests.""" + +import os +import subprocess +import sys + +from click import testing +import pytest + +from benchmarks import runner + + +def _get_locale(): + output = subprocess.check_output(["locale", "-a"]) + locales = output.split() + if b"en_US.utf8" in locales: + return "en_US.UTF-8" + else: + return "C.UTF-8" + + +def _set_locale(): + locale = _get_locale() + if os.getenv("LANG") != locale: + os.environ["LANG"] = locale + os.environ["LC_ALL"] = locale + os.execv("/proc/self/exe", ["python"] + sys.argv) + + +def test_list(): + cli_runner = testing.CliRunner() + result = cli_runner.invoke(runner.runner, ["list"]) + print(result.output) + assert result.exit_code == 0 + + +def test_run(): + cli_runner = testing.CliRunner() + result = cli_runner.invoke(runner.runner, ["run", "--mock", "."]) + print(result.output) + assert result.exit_code == 0 + + +if __name__ == "__main__": + _set_locale() + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/suites/BUILD b/benchmarks/suites/BUILD new file mode 100644 index 000000000..04fc23261 --- /dev/null +++ b/benchmarks/suites/BUILD @@ -0,0 +1,130 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "suites", + srcs = ["__init__.py"], +) + +py_library( + name = "absl", + srcs = ["absl.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/absl", + ], +) + +py_library( + name = "density", + srcs = ["density.py"], + deps = [ + "//benchmarks/harness:container", + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + ], +) + +py_library( + name = "fio", + srcs = ["fio.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + "//benchmarks/workloads/fio", + ], +) + +py_library( + name = "helpers", + srcs = ["helpers.py"], + deps = ["//benchmarks/harness:machine"], +) + +py_library( + name = "http", + srcs = ["http.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/ab", + ], +) + +py_library( + name = "media", + srcs = ["media.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + "//benchmarks/workloads/ffmpeg", + ], +) + +py_library( + name = "ml", + srcs = ["ml.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:startup", + "//benchmarks/workloads/tensorflow", + ], +) + +py_library( + name = "network", + srcs = ["network.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + "//benchmarks/workloads/iperf", + ], +) + +py_library( + name = "redis", + srcs = ["redis.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/redisbenchmark", + ], +) + +py_library( + name = "startup", + srcs = ["startup.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + ], +) + +py_library( + name = "sysbench", + srcs = ["sysbench.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/sysbench", + ], +) + +py_library( + name = "syscall", + srcs = ["syscall.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/syscall", + ], +) diff --git a/benchmarks/suites/__init__.py b/benchmarks/suites/__init__.py new file mode 100644 index 000000000..360736cc3 --- /dev/null +++ b/benchmarks/suites/__init__.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core benchmark annotations.""" + +import functools +import inspect +import types +from typing import List +from typing import Tuple + +BENCHMARK_METRICS = '__benchmark_metrics__' +BENCHMARK_MACHINES = '__benchmark_machines__' + + +def is_benchmark(func: types.FunctionType) -> bool: + """Returns true if the given function is a benchmark.""" + return isinstance(func, types.FunctionType) and \ + hasattr(func, BENCHMARK_METRICS) and \ + hasattr(func, BENCHMARK_MACHINES) + + +def benchmark_metrics(func: types.FunctionType) -> List[Tuple[str, str]]: + """Returns the list of available metrics.""" + return [(metric.__name__, metric.__doc__) + for metric in getattr(func, BENCHMARK_METRICS)] + + +def benchmark_machines(func: types.FunctionType) -> int: + """Returns the number of machines required.""" + return getattr(func, BENCHMARK_MACHINES) + + +# pylint: disable=unused-argument +def default(value, **kwargs): + """Returns the passed value.""" + return value + + +def benchmark(metrics: List[types.FunctionType] = None, + machines: int = 1) -> types.FunctionType: + """Define a benchmark function with metrics. + + Args: + metrics: A list of metric functions. + machines: The number of machines required. + + Returns: + A function that accepts the given number of machines, and iteratively + returns a set of (metric_name, metric_value) pairs when called repeatedly. + """ + if not metrics: + # The default passes through. + metrics = [default] + + def decorator(func: types.FunctionType) -> types.FunctionType: + """Decorator function.""" + # Every benchmark should accept at least two parameters: + # runtime: The runtime to use for the benchmark (str, required). + # metrics: The metrics to use, if not the default (str, optional). + @functools.wraps(func) + def wrapper(*args, runtime: str, metric: list = None, **kwargs): + """Wrapper function.""" + # First -- ensure that we marshall all types appropriately. In + # general, we will call this with only strings. These strings will + # need to be converted to their underlying types/classes. + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.annotation != inspect.Parameter.empty and \ + param.name in kwargs and not isinstance(kwargs[param.name], param.annotation): + try: + # Marshall to the appropriate type. + kwargs[param.name] = param.annotation(kwargs[param.name]) + except Exception as exc: + raise ValueError( + 'illegal type for %s(%s=%s): %s' % + (func.__name__, param.name, kwargs[param.name], exc)) + elif param.default != inspect.Parameter.empty and \ + param.name not in kwargs: + # Ensure that we have the value set, because it will + # be passed to the metric function for evaluation. + kwargs[param.name] = param.default + + # Next, figure out how to apply a metric. We do this prior to + # running the underlying function to prevent having to wait a few + # minutes for a result just to see some error. + if not metric: + # Return all metrics in the iterator. + result = func(*args, runtime=runtime, **kwargs) + for metric_func in metrics: + yield (metric_func.__name__, metric_func(result, **kwargs)) + else: + result = None + for single_metric in metric: + for metric_func in metrics: + # Is this a function that matches the name? + # Apply this function to the result. + if metric_func.__name__ == single_metric: + if not result: + # Lazy evaluation: only if metric matches. + result = func(*args, runtime=runtime, **kwargs) + yield single_metric, metric_func(result, **kwargs) + + # Set metadata on the benchmark (used above). + setattr(wrapper, BENCHMARK_METRICS, metrics) + setattr(wrapper, BENCHMARK_MACHINES, machines) + return wrapper + + return decorator diff --git a/benchmarks/suites/absl.py b/benchmarks/suites/absl.py new file mode 100644 index 000000000..5d9b57a09 --- /dev/null +++ b/benchmarks/suites/absl.py @@ -0,0 +1,37 @@ +# python3 +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""absl build benchmark.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import absl + + +@suites.benchmark(metrics=[absl.elapsed_time], machines=1) +def build(target: machine.Machine, **kwargs) -> str: + """Runs the absl workload and report the absl build time. + + Runs the 'bazel build //absl/...' in a clean bazel directory and + monitors time elapsed. + + Args: + target: A machine object. + **kwargs: Additional container options. + + Returns: + Container output. + """ + image = target.pull("absl") + return target.container(image, **kwargs).run() diff --git a/benchmarks/suites/density.py b/benchmarks/suites/density.py new file mode 100644 index 000000000..89d29fb26 --- /dev/null +++ b/benchmarks/suites/density.py @@ -0,0 +1,121 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Density tests.""" + +import re +import types + +from benchmarks import suites +from benchmarks.harness import container +from benchmarks.harness import machine +from benchmarks.suites import helpers + + +# pylint: disable=unused-argument +def memory_usage(value, **kwargs): + """Returns the passed value.""" + return value + + +def density(target: machine.Machine, + workload: str, + count: int = 50, + wait: float = 0, + load_func: types.FunctionType = None, + **kwargs): + """Calculate the average memory usage per container. + + Args: + target: A machine object. + workload: The workload to run. + count: The number of containers to start. + wait: The time to wait after starting. + load_func: Callback that is called after count images have been started on + the given machine. + **kwargs: Additional container options. + + Returns: + The average usage in Kb per container. + """ + count = int(count) + + # Drop all caches. + helpers.drop_caches(target) + before = target.read("/proc/meminfo") + + # Load the workload. + image = target.pull(workload) + + with target.container( + image=image, count=count, **kwargs).detach() as containers: + # Call the optional load function callback if given. + if load_func: + load_func(target, containers) + # Wait 'wait' time before taking a measurement. + target.sleep(wait) + + # Drop caches again. + helpers.drop_caches(target) + after = target.read("/proc/meminfo") + + # Calculate the memory used. + available_re = re.compile(r"MemAvailable:\s*(\d+)\skB\n") + before_available = available_re.findall(before) + after_available = available_re.findall(after) + return 1024 * float(int(before_available[0]) - + int(after_available[0])) / float(count) + + +def load_redis(target: machine.Machine, containers: container.Container): + """Use redis-benchmark "LPUSH" to load each container with 1G of data. + + Args: + target: A machine object. + containers: A set of containers. + """ + target.pull("redisbenchmark") + for name in containers.get_names(): + flags = "-d 10000 -t LPUSH" + target.container( + "redisbenchmark", links={ + name: name + }).run( + host=name, flags=flags) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def empty(target: machine.Machine, **kwargs) -> float: + """Run trivial containers in a density test.""" + return density(target, workload="sleep", wait=1.0, **kwargs) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def node(target: machine.Machine, **kwargs) -> float: + """Run node containers in a density test.""" + return density(target, workload="node", wait=3.0, **kwargs) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def ruby(target: machine.Machine, **kwargs) -> float: + """Run ruby containers in a density test.""" + return density(target, workload="ruby", wait=3.0, **kwargs) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def redis(target: machine.Machine, **kwargs) -> float: + """Run redis containers in a density test.""" + if "count" not in kwargs: + kwargs["count"] = 5 + return density( + target, workload="redis", wait=3.0, load_func=load_redis, **kwargs) diff --git a/benchmarks/suites/fio.py b/benchmarks/suites/fio.py new file mode 100644 index 000000000..2171790c5 --- /dev/null +++ b/benchmarks/suites/fio.py @@ -0,0 +1,165 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File I/O tests.""" + +import os + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers +from benchmarks.workloads import fio + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +def run_fio(target: machine.Machine, + test: str, + ioengine: str = "sync", + size: int = 1024 * 1024 * 1024, + iodepth: int = 4, + blocksize: int = 1024 * 1024, + time: int = -1, + mount_dir: str = "", + filename: str = "file.dat", + tmpfs: bool = False, + ramp_time: int = 0, + **kwargs) -> str: + """FIO benchmarks. + + For more on fio see: + https://media.readthedocs.org/pdf/fio/latest/fio.pdf + + Args: + target: A machine object. + test: The test to run (read, write, randread, randwrite, etc.) + ioengine: The engine for I/O. + size: The size of the generated file in bytes (if an integer) or 5g, 16k, + etc. + iodepth: The I/O for certain engines. + blocksize: The blocksize for reads and writes in bytes (if an integer) or + 4k, etc. + time: If test is time based, how long to run in seconds. + mount_dir: The absolute path on the host to mount a bind mount. + filename: The name of the file to creat inside container. For a path of + /dir/dir/file, the script setup a volume like 'docker run -v + mount_dir:/dir/dir fio' and fio will create (and delete) the file + /dir/dir/file. If tmpfs is set, this /dir/dir will be a tmpfs. + tmpfs: If true, mount on tmpfs. + ramp_time: The time to run before recording statistics + **kwargs: Additional container options. + + Returns: + The output of fio as a string. + """ + # Pull the image before dropping caches. + image = target.pull("fio") + + if not mount_dir: + stdout, _ = target.run("pwd") + mount_dir = stdout.rstrip() + + # Setup the volumes. + volumes = {mount_dir: {"bind": "/disk", "mode": "rw"}} if not tmpfs else None + tmpfs = {"/disk": ""} if tmpfs else None + + # Construct a file in the volume. + filepath = os.path.join("/disk", filename) + + # If we are running a read test, us fio to write a file and then flush file + # data from memory. + if "read" in test: + target.container( + image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( + test="write", + ioengine="sync", + size=size, + iodepth=iodepth, + blocksize=blocksize, + path=filepath) + helpers.drop_caches(target) + + # Run the test. + time_str = "--time_base --runtime={time}".format( + time=time) if int(time) > 0 else "" + res = target.container( + image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( + test=test, + ioengine=ioengine, + size=size, + iodepth=iodepth, + blocksize=blocksize, + time=time_str, + path=filepath, + ramp_time=ramp_time) + + target.run( + "rm {path}".format(path=os.path.join(mount_dir.rstrip(), filename))) + + return res + + +@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) +def read(*args, **kwargs): + """Read test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="read", **kwargs) + + +@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) +def randread(*args, **kwargs): + """Random read test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="randread", **kwargs) + + +@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) +def write(*args, **kwargs): + """Write test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="write", **kwargs) + + +@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) +def randwrite(*args, **kwargs): + """Random write test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="randwrite", **kwargs) diff --git a/benchmarks/suites/helpers.py b/benchmarks/suites/helpers.py new file mode 100644 index 000000000..b3c7360ab --- /dev/null +++ b/benchmarks/suites/helpers.py @@ -0,0 +1,57 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Benchmark helpers.""" + +import datetime +from benchmarks.harness import machine + + +class Timer: + """Helper to time runtime of some call. + + Usage: + + with Timer as t: + # do something. + t.get_time_in_seconds() + """ + + def __init__(self): + self._start = datetime.datetime.now() + + def __enter__(self): + self.start() + return self + + def start(self): + """Starts the timer.""" + self._start = datetime.datetime.now() + + def elapsed(self) -> float: + """Returns the elapsed time in seconds.""" + return (datetime.datetime.now() - self._start).total_seconds() + + def __exit__(self, exception_type, exception_value, exception_traceback): + pass + + +def drop_caches(target: machine.Machine): + """Drops caches on the machine. + + Args: + target: A machine object. + """ + target.run("sudo sync") + target.run("sudo sysctl vm.drop_caches=3") + target.run("sudo sysctl vm.drop_caches=3") diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py new file mode 100644 index 000000000..ea9024e43 --- /dev/null +++ b/benchmarks/suites/http.py @@ -0,0 +1,138 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""HTTP benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import ab + + +# pylint: disable=too-many-arguments +def http(server: machine.Machine, + client: machine.Machine, + workload: str, + requests: int = 5000, + connections: int = 10, + port: int = 80, + path: str = "notfound", + **kwargs) -> str: + """Run apachebench (ab) against an http server. + + Args: + server: A machine object. + client: A machine object. + workload: The http-serving workload. + requests: Number of requests to send the server. Default is 5000. + connections: Number of concurent connections to use. Default is 10. + port: The port to access in benchmarking. + path: File to download, generally workload-specific. + **kwargs: Additional container options. + + Returns: + The full apachebench output. + """ + # Pull the client & server. + apachebench = client.pull("ab") + netcat = client.pull("netcat") + image = server.pull(workload) + + with server.container(image, port=port, **kwargs).detach() as container: + (host, port) = container.address() + # Wait for the server to come up. + client.container(netcat).run(host=host, port=port) + # Run the benchmark, no arguments. + return client.container(apachebench).run( + host=host, + port=port, + requests=requests, + connections=connections, + path=path) + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +def http_app(server: machine.Machine, + client: machine.Machine, + workload: str, + requests: int = 5000, + connections: int = 10, + port: int = 80, + path: str = "notfound", + **kwargs) -> str: + """Run apachebench (ab) against an http application. + + Args: + server: A machine object. + client: A machine object. + workload: The http-serving workload. + requests: Number of requests to send the server. Default is 5000. + connections: Number of concurent connections to use. Default is 10. + port: The port to use for benchmarking. + path: File to download, generally workload-specific. + **kwargs: Additional container options. + + Returns: + The full apachebench output. + """ + # Pull the client & server. + apachebench = client.pull("ab") + netcat = client.pull("netcat") + server_netcat = server.pull("netcat") + redis = server.pull("redis") + image = server.pull(workload) + redis_port = 6379 + redis_name = "redis_server" + + with server.container(redis, name=redis_name).detach(): + server.container(server_netcat, links={redis_name: redis_name})\ + .run(host=redis_name, port=redis_port) + with server.container(image, port=port, links={redis_name: redis_name}, **kwargs)\ + .detach(host=redis_name) as container: + (host, port) = container.address() + # Wait for the server to come up. + client.container(netcat).run(host=host, port=port) + # Run the benchmark, no arguments. + return client.container(apachebench).run( + host=host, + port=port, + requests=requests, + connections=connections, + path=path) + + +@suites.benchmark(metrics=[ab.transfer_rate, ab.latency], machines=2) +def httpd(*args, **kwargs) -> str: + """Apache2 benchmark.""" + return http(*args, workload="httpd", port=80, **kwargs) + + +@suites.benchmark( + metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) +def nginx(*args, **kwargs) -> str: + """Nginx benchmark.""" + return http(*args, workload="nginx", port=80, **kwargs) + + +@suites.benchmark( + metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) +def node(*args, **kwargs) -> str: + """Node benchmark.""" + return http_app(*args, workload="node_template", path="", port=8080, **kwargs) + + +@suites.benchmark( + metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) +def ruby(*args, **kwargs) -> str: + """Ruby benchmark.""" + return http_app(*args, workload="ruby_template", path="", port=9292, **kwargs) diff --git a/benchmarks/suites/media.py b/benchmarks/suites/media.py new file mode 100644 index 000000000..9cbffdaa1 --- /dev/null +++ b/benchmarks/suites/media.py @@ -0,0 +1,42 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Media processing benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers +from benchmarks.workloads import ffmpeg + + +@suites.benchmark(metrics=[ffmpeg.run_time], machines=1) +def transcode(target: machine.Machine, **kwargs) -> float: + """Runs a video transcoding workload and times it. + + Args: + target: A machine object. + **kwargs: Additional container options. + + Returns: + Total workload runtime. + """ + # Load before timing. + image = target.pull("ffmpeg") + + # Drop caches. + helpers.drop_caches(target) + + # Time startup + transcoding. + with helpers.Timer() as timer: + target.container(image, **kwargs).run() + return timer.elapsed() diff --git a/benchmarks/suites/ml.py b/benchmarks/suites/ml.py new file mode 100644 index 000000000..a394d1f69 --- /dev/null +++ b/benchmarks/suites/ml.py @@ -0,0 +1,33 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Machine Learning tests.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import startup +from benchmarks.workloads import tensorflow + + +@suites.benchmark(metrics=[tensorflow.run_time], machines=1) +def train(target: machine.Machine, **kwargs): + """Run the tensorflow benchmark and return the runtime in seconds of workload. + + Args: + target: A machine object. + **kwargs: Additional container options. + + Returns: + The total runtime. + """ + return startup.startup(target, workload="tensorflow", count=1, **kwargs) diff --git a/benchmarks/suites/network.py b/benchmarks/suites/network.py new file mode 100644 index 000000000..f973cf3f1 --- /dev/null +++ b/benchmarks/suites/network.py @@ -0,0 +1,101 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Network microbenchmarks.""" + +from typing import Dict + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers +from benchmarks.workloads import iperf + + +def run_iperf(client: machine.Machine, + server: machine.Machine, + client_kwargs: Dict[str, str] = None, + server_kwargs: Dict[str, str] = None) -> str: + """Measure iperf performance. + + Args: + client: A machine object. + server: A machine object. + client_kwargs: Additional client container options. + server_kwargs: Additional server container options. + + Returns: + The output of iperf. + """ + if not client_kwargs: + client_kwargs = dict() + if not server_kwargs: + server_kwargs = dict() + + # Pull images. + netcat = client.pull("netcat") + iperf_client_image = client.pull("iperf") + iperf_server_image = server.pull("iperf") + + # Set this due to a bug in the kernel that resets connections. + client.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") + server.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") + + with server.container( + iperf_server_image, port=5001, **server_kwargs).detach() as iperf_server: + (host, port) = iperf_server.address() + # Wait until the service is available. + client.container(netcat).run(host=host, port=port) + # Run a warm-up run. + client.container( + iperf_client_image, stderr=True, **client_kwargs).run( + host=host, port=port) + # Run the client with relevant arguments. + res = client.container(iperf_client_image, stderr=True, **client_kwargs)\ + .run(host=host, port=port) + helpers.drop_caches(client) + return res + + +@suites.benchmark(metrics=[iperf.bandwidth], machines=2) +def upload(client: machine.Machine, server: machine.Machine, **kwargs) -> str: + """Measure upload performance. + + Args: + client: A machine object. + server: A machine object. + **kwargs: Client container options. + + Returns: + The output of iperf. + """ + if kwargs["runtime"] == "runc": + kwargs["network_mode"] = "host" + return run_iperf(client, server, client_kwargs=kwargs) + + +@suites.benchmark(metrics=[iperf.bandwidth], machines=2) +def download(client: machine.Machine, server: machine.Machine, **kwargs) -> str: + """Measure download performance. + + Args: + client: A machine object. + server: A machine object. + **kwargs: Server container options. + + Returns: + The output of iperf. + """ + + client_kwargs = {"network_mode": "host"} + return run_iperf( + client, server, client_kwargs=client_kwargs, server_kwargs=kwargs) diff --git a/benchmarks/suites/redis.py b/benchmarks/suites/redis.py new file mode 100644 index 000000000..b84dd073d --- /dev/null +++ b/benchmarks/suites/redis.py @@ -0,0 +1,46 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Redis benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import redisbenchmark + + +@suites.benchmark(metrics=list(redisbenchmark.METRICS.values()), machines=2) +def redis(server: machine.Machine, + client: machine.Machine, + flags: str = "", + **kwargs) -> str: + """Run redis-benchmark on client pointing at server machine. + + Args: + server: A machine object. + client: A machine object. + flags: Flags to pass redis-benchmark. + **kwargs: Additional container options. + + Returns: + Output from redis-benchmark. + """ + redis_server = server.pull("redis") + redis_client = client.pull("redisbenchmark") + netcat = client.pull("netcat") + with server.container( + redis_server, port=6379, **kwargs).detach() as container: + (host, port) = container.address() + # Wait for the container to be up. + client.container(netcat).run(host=host, port=port) + # Run all redis benchmarks. + return client.container(redis_client).run(host=host, port=port, flags=flags) diff --git a/benchmarks/suites/startup.py b/benchmarks/suites/startup.py new file mode 100644 index 000000000..a1b6c5753 --- /dev/null +++ b/benchmarks/suites/startup.py @@ -0,0 +1,110 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Start-up benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers + + +# pylint: disable=unused-argument +def startup_time_ms(value, **kwargs): + """Returns average startup time per container in milliseconds. + + Args: + value: The floating point time in seconds. + **kwargs: Ignored. + + Returns: + The time given in milliseconds. + """ + return value * 1000 + + +def startup(target: machine.Machine, + workload: str, + count: int = 5, + port: int = 0, + **kwargs): + """Time the startup of some workload. + + Args: + target: A machine object. + workload: The workload to run. + count: Number of containers to start. + port: The port to check for liveness, if provided. + **kwargs: Additional container options. + + Returns: + The mean start-up time in seconds. + """ + # Load before timing. + image = target.pull(workload) + netcat = target.pull("netcat") + count = int(count) + port = int(port) + + with helpers.Timer() as timer: + for _ in range(count): + if not port: + # Run the container synchronously. + target.container(image, **kwargs).run() + else: + # Run a detached container until httpd available. + with target.container(image, port=port, **kwargs).detach() as server: + (server_host, server_port) = server.address() + target.container(netcat).run(host=server_host, port=server_port) + return timer.elapsed() / float(count) + + +@suites.benchmark(metrics=[startup_time_ms], machines=1) +def empty(target: machine.Machine, **kwargs) -> float: + """Time the startup of a trivial container. + + Args: + target: A machine object. + **kwargs: Additional startup options. + + Returns: + The time to run the container. + """ + return startup(target, workload="true", **kwargs) + + +@suites.benchmark(metrics=[startup_time_ms], machines=1) +def node(target: machine.Machine, **kwargs) -> float: + """Time the startup of the node container. + + Args: + target: A machine object. + **kwargs: Additional statup options. + + Returns: + The time to run the container. + """ + return startup(target, workload="node", port=8080, **kwargs) + + +@suites.benchmark(metrics=[startup_time_ms], machines=1) +def ruby(target: machine.Machine, **kwargs) -> float: + """Time the startup of the ruby container. + + Args: + target: A machine object. + **kwargs: Additional startup options. + + Returns: + The time to run the container. + """ + return startup(target, workload="ruby", port=3000, **kwargs) diff --git a/benchmarks/suites/sysbench.py b/benchmarks/suites/sysbench.py new file mode 100644 index 000000000..2a6e2126c --- /dev/null +++ b/benchmarks/suites/sysbench.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sysbench-based benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import sysbench + + +def run_sysbench(target: machine.Machine, + test: str = "cpu", + threads: int = 8, + time: int = 5, + options: str = "", + **kwargs) -> str: + """Run sysbench container with arguments. + + Args: + target: A machine object. + test: Relevant sysbench test to run (e.g. cpu, memory). + threads: The number of threads to use for tests. + time: The time to run tests. + options: Additional sysbench options. + **kwargs: Additional container options. + + Returns: + The output of the command as a string. + """ + image = target.pull("sysbench") + return target.container(image, **kwargs).run( + test=test, threads=threads, time=time, options=options) + + +@suites.benchmark(metrics=[sysbench.cpu_events_per_second], machines=1) +def cpu(target: machine.Machine, max_prime: int = 5000, **kwargs) -> str: + """Run sysbench CPU test. + + Additional arguments can be provided for sysbench. + + Args: + target: A machine object. + max_prime: The maximum prime number to search. + **kwargs: + - threads: The number of threads to use for tests. + - time: The time to run tests. + - options: Additional sysbench options. See sysbench tool: + https://github.com/akopytov/sysbench + + Returns: + Sysbench output. + """ + options = kwargs.pop("options", "") + options += " --cpu-max-prime={}".format(max_prime) + return run_sysbench(target, test="cpu", options=options, **kwargs) + + +@suites.benchmark(metrics=[sysbench.memory_ops_per_second], machines=1) +def memory(target: machine.Machine, **kwargs) -> str: + """Run sysbench memory test. + + Additional arguments can be provided per sysbench. + + Args: + target: A machine object. + **kwargs: + - threads: The number of threads to use for tests. + - time: The time to run tests. + - options: Additional sysbench options. See sysbench tool: + https://github.com/akopytov/sysbench + + Returns: + Sysbench output. + """ + return run_sysbench(target, test="memory", **kwargs) + + +@suites.benchmark( + metrics=[ + sysbench.mutex_time, sysbench.mutex_latency, sysbench.mutex_deviation + ], + machines=1) +def mutex(target: machine.Machine, + locks: int = 4, + count: int = 10000000, + threads: int = 8, + **kwargs) -> str: + """Run sysbench mutex test. + + Additional arguments can be provided per sysbench. + + Args: + target: A machine object. + locks: The number of locks to use. + count: The number of mutexes. + threads: The number of threads to use for tests. + **kwargs: + - time: The time to run tests. + - options: Additional sysbench options. See sysbench tool: + https://github.com/akopytov/sysbench + + Returns: + Sysbench output. + """ + options = kwargs.pop("options", "") + options += " --mutex-loops=1 --mutex-locks={} --mutex-num={}".format( + count, locks) + return run_sysbench( + target, test="mutex", options=options, threads=threads, **kwargs) diff --git a/benchmarks/suites/syscall.py b/benchmarks/suites/syscall.py new file mode 100644 index 000000000..fa7665b00 --- /dev/null +++ b/benchmarks/suites/syscall.py @@ -0,0 +1,37 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Syscall microbenchmark.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads.syscall import syscall_time_ns + + +@suites.benchmark(metrics=[syscall_time_ns], machines=1) +def syscall(target: machine.Machine, count: int = 1000000, **kwargs) -> str: + """Runs the syscall workload and report the syscall time. + + Runs the syscall 'SYS_gettimeofday(0,0)' 'count' times and monitors time + elapsed based on the runtime's MONOTONIC clock. + + Args: + target: A machine object. + count: The number of syscalls to execute. + **kwargs: Additional container options. + + Returns: + Container output. + """ + image = target.pull("syscall") + return target.container(image, **kwargs).run(count=count) diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD new file mode 100644 index 000000000..735d7127f --- /dev/null +++ b/benchmarks/tcp/BUILD @@ -0,0 +1,41 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("@rules_cc//cc:defs.bzl", "cc_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "tcp_proxy", + srcs = ["tcp_proxy.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +# nsjoin is a trivial replacement for nsenter. This is used because nsenter is +# not available on all systems where this benchmark is run (and we aim to +# minimize external dependencies.) + +cc_binary( + name = "nsjoin", + srcs = ["nsjoin.c"], + visibility = ["//:sandbox"], +) + +sh_binary( + name = "tcp_benchmark", + srcs = ["tcp_benchmark.sh"], + data = [ + ":nsjoin", + ":tcp_proxy", + ], + visibility = ["//:sandbox"], +) diff --git a/benchmarks/tcp/README.md b/benchmarks/tcp/README.md new file mode 100644 index 000000000..38e6e69f0 --- /dev/null +++ b/benchmarks/tcp/README.md @@ -0,0 +1,87 @@ +# TCP Benchmarks + +This directory contains a standardized TCP benchmark. This helps to evaluate the +performance of netstack and native networking stacks under various conditions. + +## `tcp_benchmark` + +This benchmark allows TCP throughput testing under various conditions. The setup +consists of an iperf client, a client proxy, a server proxy and an iperf server. +The client proxy and server proxy abstract the network mechanism used to +communicate between the iperf client and server. + +The setup looks like the following: + +``` + +--------------+ (native) +--------------+ + | iperf client |[lo @ 10.0.0.1]------>| client proxy | + +--------------+ +--------------+ + [client.0 @ 10.0.0.2] + (netstack) | | (native) + +------+-----+ + | + [br0] + | + Network emulation applied ---> [wan.0:wan.1] + | + [br1] + | + +------+-----+ + (netstack) | | (native) + [server.0 @ 10.0.0.3] + +--------------+ +--------------+ + | iperf server |<------[lo @ 10.0.0.4]| server proxy | + +--------------+ (native) +--------------+ +``` + +Different configurations can be run using different arguments. For example: + +* Native test under normal internet conditions: `tcp_benchmark` +* Native test under ideal conditions: `tcp_benchmark --ideal` +* Netstack client under ideal conditions: `tcp_benchmark --client --ideal` +* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss + 5` + +Use `tcp_benchmark --help` for full arguments. + +This tool may be used to easily generate data for graphing. For example, to +generate a CSV for various latencies, you might do: + +``` +rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv +latencies=$(seq 0 5 50; + seq 60 10 100; + seq 125 25 250; + seq 300 50 500) +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv +done +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv +done +``` + +Similarly, to generate a CSV for various levels of packet loss, the following +would be appropriate: + +``` +rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv +losses=$(seq 0 0.1 1.0; + seq 1.2 0.2 2.0; + seq 2.5 0.5 5.0; + seq 6.0 1.0 10.0) +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv +done +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv +done +``` diff --git a/benchmarks/tcp/nsjoin.c b/benchmarks/tcp/nsjoin.c new file mode 100644 index 000000000..524b4d549 --- /dev/null +++ b/benchmarks/tcp/nsjoin.c @@ -0,0 +1,47 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include <errno.h> +#include <fcntl.h> +#include <sched.h> +#include <stdio.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +int main(int argc, char** argv) { + if (argc <= 2) { + fprintf(stderr, "error: must provide a namespace file.\n"); + fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]); + return 1; + } + + int fd = open(argv[1], O_RDONLY); + if (fd < 0) { + fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno)); + return 1; + } + if (setns(fd, 0) < 0) { + fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno)); + return 1; + } + + execvp(argv[2], &argv[2]); + return 1; +} diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh new file mode 100755 index 000000000..69344c9c3 --- /dev/null +++ b/benchmarks/tcp/tcp_benchmark.sh @@ -0,0 +1,369 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TCP benchmark; see README.md for documentation. + +# Fixed parameters. +iperf_port=45201 # Not likely to be privileged. +proxy_port=44000 # Ditto. +client_addr=10.0.0.1 +client_proxy_addr=10.0.0.2 +server_proxy_addr=10.0.0.3 +server_addr=10.0.0.4 +mask=8 + +# Defaults; this provides a reasonable approximation of a decent internet link. +# Parameters can be varied independently from this set to see response to +# various changes in the kind of link available. +client=false +server=false +verbose=false +gso=0 +swgso=false +mtu=1280 # 1280 is a reasonable lowest-common-denominator. +latency=10 # 10ms approximates a fast, dedicated connection. +latency_variation=1 # +/- 1ms is a relatively low amount of jitter. +loss=0.1 # 0.1% loss is non-zero, but not extremely high. +duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses. +duration=30 # 30s is enough time to consistent results (experimentally). +helper_dir=$(dirname $0) +netstack_opts= + +# Check for netem support. +lsmod_output=$(lsmod | grep sch_netem) +if [ "$?" != "0" ]; then + echo "warning: sch_netem may not be installed." >&2 +fi + +while [ $# -gt 0 ]; do + case "$1" in + --client) + client=true + ;; + --client_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -client_tcp_probe_file=$1" + ;; + --server) + server=true + ;; + --verbose) + verbose=true + ;; + --gso) + shift + gso=$1 + ;; + --swgso) + swgso=true + ;; + --server_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -server_tcp_probe_file=$1" + ;; + --ideal) + mtu=1500 # Standard ethernet. + latency=0 # No latency. + latency_variation=0 # No jitter. + loss=0 # No loss. + duplicate=0 # No duplicates. + ;; + --mtu) + shift + [ "$#" -le 0 ] && echo "no mtu provided" && exit 1 + mtu=$1 + ;; + --sack) + netstack_opts="${netstack_opts} -sack" + ;; + --cubic) + netstack_opts="${netstack_opts} -cubic" + ;; + --duration) + shift + [ "$#" -le 0 ] && echo "no duration provided" && exit 1 + duration=$1 + ;; + --latency) + shift + [ "$#" -le 0 ] && echo "no latency provided" && exit 1 + latency=$1 + ;; + --latency-variation) + shift + [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1 + latency_variation=$1 + ;; + --loss) + shift + [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1 + loss=$1 + ;; + --duplicate) + shift + [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1 + duplicate=$1 + ;; + --cpuprofile) + shift + netstack_opts="${netstack_opts} -cpuprofile=$1" + ;; + --memprofile) + shift + netstack_opts="${netstack_opts} -memprofile=$1" + ;; + --helpers) + shift + [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1 + helper_dir=$1 + ;; + *) + echo "usage: $0 [options]" + echo "options:" + echo " --help show this message" + echo " --verbose verbose output" + echo " --client use netstack as the client" + echo " --ideal reset all network emulation" + echo " --server use netstack as the server" + echo " --mtu set the mtu (bytes)" + echo " --sack enable SACK support" + echo " --cubic enable CUBIC congestion control for Netstack" + echo " --duration set the test duration (s)" + echo " --latency set the latency (ms)" + echo " --latency-variation set the latency variation" + echo " --loss set the loss probability (%)" + echo " --duplicate set the duplicate probability (%)" + echo " --helpers set the helper directory" + echo "" + echo "The output will of the script will be:" + echo " <throughput> <client-cpu-usage> <server-cpu-usage>" + exit 1 + esac + shift +done + +if [ ${verbose} == "true" ]; then + set -x +fi + +# Latency needs to be halved, since it's applied on both ways. +half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}') +half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}') +half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}') +helper_dir=${helper_dir#$(pwd)/} # Use relative paths. +proxy_binary=${helper_dir}/tcp_proxy +nsjoin_binary=${helper_dir}/nsjoin + +if [ ! -e ${proxy_binary} ]; then + echo "Could not locate ${proxy_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ ! -e ${nsjoin_binary} ]; then + echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then + # As long as there's some jitter, then we use the paretonormal distribution. + # This will preserve the minimum RTT, but add a realistic amount of jitter to + # the connection and cause re-ordering, etc. The regular pareto distribution + # appears to an unreasonable level of delay (we want only small spikes.) + distribution="distribution paretonormal" +else + distribution="" +fi + +# Client proxy that will listen on the client's iperf target forward traffic +# using the host networking stack. +client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}" +if ${client}; then + # Client proxy that will listen on the client's iperf target + # and forward traffic using netstack. + client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\ + -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\ + -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}" +fi + +# Server proxy that will listen on the proxy port and forward to the server's +# iperf server using the host networking stack. +server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}" +if ${server}; then + # Server proxy that will listen on the proxy port and forward to the servers' + # iperf server using netstack. + server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\ + -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\ + -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}" +fi + +# Specify loss and duplicate parameters only if they are non-zero +loss_opt="" +if [ "$(echo $half_loss | bc -q)" != "0" ]; then + loss_opt="loss random ${half_loss}%" +fi +duplicate_opt="" +if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then + duplicate_opt="duplicate ${half_duplicate}%" +fi + +exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF +set -e -m + +if [ ${verbose} == "true" ]; then + set -x +fi + +mount -t tmpfs netstack-bench /tmp + +# We may have reset the path in the unshare if the shell loaded some public +# profiles. Ensure that tools are discoverable via the parent's PATH. +export PATH=${PATH} + +# Add client, server interfaces. +ip link add client.0 type veth peer name client.1 +ip link add server.0 type veth peer name server.1 + +# Add network emulation devices. +ip link add wan.0 type veth peer name wan.1 +ip link set wan.0 up +ip link set wan.1 up + +# Enroll on the bridge. +ip link add name br0 type bridge +ip link add name br1 type bridge +ip link set client.1 master br0 +ip link set server.1 master br1 +ip link set wan.0 master br0 +ip link set wan.1 master br1 +ip link set br0 up +ip link set br1 up + +# Set the MTU appropriately. +ip link set client.0 mtu ${mtu} +ip link set server.0 mtu ${mtu} +ip link set wan.0 mtu ${mtu} +ip link set wan.1 mtu ${mtu} + +# Add appropriate latency, loss and duplication. +# +# This is added in at the point of bridge connection. +for device in wan.0 wan.1; do + # NOTE: We don't support a loss correlation as testing has shown that it + # actually doesn't work. The man page actually has a small comment about this + # "It is also possible to add a correlation, but this option is now deprecated + # due to the noticed bad behavior." For more information see netem(8). + tc qdisc add dev \$device root netem \\ + delay ${half_latency}ms ${latency_variation}ms ${distribution} \\ + ${loss_opt} ${duplicate_opt} +done + +# Start a client proxy. +touch /tmp/client.netns +unshare -n mount --bind /proc/self/ns/net /tmp/client.netns + +# Move the endpoint into the namespace. +while ip link | grep client.0 > /dev/null; do + ip link set dev client.0 netns /tmp/client.netns +done + +if ! ${client}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0 +fi + +# Start a server proxy. +touch /tmp/server.netns +unshare -n mount --bind /proc/self/ns/net /tmp/server.netns +# Move the endpoint into the namespace. +while ip link | grep server.0 > /dev/null; do + ip link set dev server.0 netns /tmp/server.netns +done +if ! ${server}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0 +fi + +# Add client and server addresses, and bring everything up. +${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0 +${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0 +${nsjoin_binary} /tmp/client.netns ip link set client.0 up +${nsjoin_binary} /tmp/client.netns ip link set lo up +${nsjoin_binary} /tmp/server.netns ip link set server.0 up +${nsjoin_binary} /tmp/server.netns ip link set lo up +ip link set dev client.1 up +ip link set dev server.1 up + +${nsjoin_binary} /tmp/client.netns ${client_args} & +client_pid=\$! +${nsjoin_binary} /tmp/server.netns ${server_args} & +server_pid=\$! + +# Start the iperf server. +${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 & +iperf_pid=\$! + +# Show traffic information. +if ! ${client} && ! ${server}; then + ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true +fi + +results_file=\$(mktemp) +function cleanup { + rm -f \$results_file + kill -TERM \$client_pid + kill -TERM \$server_pid + wait \$client_pid + wait \$server_pid + kill -9 \$iperf_pid 2>/dev/null +} + +# Allow failure from this point. +set +e +trap cleanup EXIT + +# Run the benchmark, recording the results file. +while ${nsjoin_binary} /tmp/client.netns iperf \\ + -p ${proxy_port} -c ${client_addr} -t ${duration} -f m 2>&1 \\ + | tee \$results_file \\ + | grep "connect failed" >/dev/null; do + sleep 0.1 # Wait for all services. +done + +# Unlink all relevant devices from the bridge. This is because when the bridge +# is deleted, the kernel may hang. It appears that this problem is fixed in +# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a. +ip link set client.1 nomaster +ip link set server.1 nomaster +ip link set wan.0 nomaster +ip link set wan.1 nomaster + +# Emit raw results. +cat \$results_file >&2 + +# Emit a useful result (final throughput). +mbits=\$(grep Mbits/sec \$results_file \\ + | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p') +client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\ + | awk '{print (\$14+\$15);}') +server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\ + | awk '{print (\$14+\$15);}') +ticks_per_sec=\$(getconf CLK_TCK) +client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration}) +server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration}) +echo \$mbits \$client_cpu_load \$server_cpu_load +EOF diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go new file mode 100644 index 000000000..361a56755 --- /dev/null +++ b/benchmarks/tcp/tcp_proxy.go @@ -0,0 +1,436 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary tcp_proxy is a simple TCP proxy. +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "io" + "log" + "math/rand" + "net" + "os" + "os/signal" + "regexp" + "runtime" + "runtime/pprof" + "strconv" + "syscall" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +var ( + port = flag.Int("port", 0, "bind port (all addresses)") + forward = flag.String("forward", "", "forwarding target") + client = flag.Bool("client", false, "use netstack for listen") + server = flag.Bool("server", false, "use netstack for dial") + + // Netstack-specific options. + mtu = flag.Int("mtu", 1280, "mtu for network stack") + addr = flag.String("addr", "", "address for tap-based netstack") + mask = flag.Int("mask", 8, "mask size for address") + iface = flag.String("iface", "", "network interface name to bind for netstack") + sack = flag.Bool("sack", false, "enable SACK support for netstack") + cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack") + gso = flag.Int("gso", 0, "GSO maximum size") + swgso = flag.Bool("swgso", false, "software-level GSO") + clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.") + memprofile = flag.String("memprofile", "", "write memory profile to the specified file.") +) + +type impl interface { + dial(address string) (net.Conn, error) + listen(port int) (net.Listener, error) + printStats() +} + +type netImpl struct{} + +func (netImpl) dial(address string) (net.Conn, error) { + return net.Dial("tcp", address) +} + +func (netImpl) listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + +func (netImpl) printStats() { +} + +const ( + nicID = 1 // Fixed. + rcvBufSize = 1 << 20 // 1MB. +) + +type netstackImpl struct { + s *stack.Stack + addr tcpip.Address + mode string +} + +func setupNetwork(ifaceName string) (fd int, err error) { + // Get all interfaces in the namespace. + ifaces, err := net.Interfaces() + if err != nil { + return -1, fmt.Errorf("querying interfaces: %v", err) + } + + for _, iface := range ifaces { + if iface.Name != ifaceName { + continue + } + // Create the socket. + const protocol = 0x0300 // htons(ETH_P_ALL) + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + if err != nil { + return -1, fmt.Errorf("unable to create raw socket: %v", err) + } + + // Bind to the appropriate device. + ll := syscall.SockaddrLinklayer{ + Protocol: protocol, + Ifindex: iface.Index, + Pkttype: syscall.PACKET_HOST, + } + if err := syscall.Bind(fd, &ll); err != nil { + return -1, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) + } + + // RAW Sockets by default have a very small SO_RCVBUF of 256KB, + // up it to at least 1MB to reduce packet drops. + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil { + return -1, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } + + if !*swgso && *gso != 0 { + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + return -1, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + } + } + return fd, nil + } + return -1, fmt.Errorf("failed to find interface: %v", ifaceName) +} + +func newNetstackImpl(mode string) (impl, error) { + fd, err := setupNetwork(*iface) + if err != nil { + return nil, err + } + + // Parse details. + parsedAddr := tcpip.Address(net.ParseIP(*addr).To4()) + parsedDest := tcpip.Address("") // Filled in below. + parsedMask := tcpip.AddressMask("") // Filled in below. + switch *mask { + case 8: + parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0}) + case 16: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0}) + case 24: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0}) + default: + // This is just laziness; we don't expect a different mask. + return nil, fmt.Errorf("mask %d not supported", mask) + } + + // Create a new network stack. + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} + s := stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + }) + + // Generate a new mac for the eth device. + mac := make(net.HardwareAddr, 6) + rand.Read(mac) // Fill with random data. + mac[0] &^= 0x1 // Clear multicast bit. + mac[0] |= 0x2 // Set local assignment bit (IEEE802). + ep, err := fdbased.New(&fdbased.Options{ + FDs: []int{fd}, + MTU: uint32(*mtu), + EthernetHeader: true, + Address: tcpip.LinkAddress(mac), + // Enable checksum generation as we need to generate valid + // checksums for the veth device to deliver our packets to the + // peer. But we do want to disable checksum verification as veth + // devices do perform GRO and the linux host kernel may not + // regenerate valid checksums after GRO. + TXChecksumOffload: false, + RXChecksumOffload: true, + PacketDispatchMode: fdbased.RecvMMsg, + GSOMaxSize: uint32(*gso), + SoftwareGSOEnabled: *swgso, + }) + if err != nil { + return nil, fmt.Errorf("failed to create FD endpoint: %v", err) + } + if err := s.CreateNIC(nicID, ep); err != nil { + return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil { + return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err) + } + + subnet, err := tcpip.NewSubnet(parsedDest, parsedMask) + if err != nil { + return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err) + } + // Add default route; we only support + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + + // Set protocol options. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %v", err) + } + + // Set Congestion Control to cubic if requested. + if *cubic { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %v", err) + } + } + + return netstackImpl{ + s: s, + addr: parsedAddr, + mode: mode, + }, nil +} + +func (n netstackImpl) dial(address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + if host == "" { + // A host must be provided for the dial. + return nil, fmt.Errorf("no host provided") + } + portNumber, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + addr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(net.ParseIP(host).To4()), + Port: uint16(portNumber), + } + conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return conn, nil +} + +func (n netstackImpl) listen(port int) (net.Listener, error) { + addr := tcpip.FullAddress{ + NIC: nicID, + Port: uint16(port), + } + listener, err := gonet.NewListener(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return listener, nil +} + +var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`) + +func (n netstackImpl) printStats() { + // Don't show zero fields. + stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "") + log.Printf("netstack %s Stats: %+v\n", n.mode, stats) +} + +// installProbe installs a TCP Probe function that will dump endpoint +// state to the specified file. It also returns a close func() that +// can be used to close the probeFile. +func (n netstackImpl) installProbe(probeFileName string) (close func()) { + // Install Probe to dump out end point state. + probeFile, err := os.Create(probeFileName) + if err != nil { + log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err) + } + probeEncoder := gob.NewEncoder(probeFile) + // Install a TCP Probe. + n.s.AddTCPProbe(func(state stack.TCPEndpointState) { + probeEncoder.Encode(state) + }) + return func() { probeFile.Close() } +} + +func main() { + flag.Parse() + if *port == 0 { + log.Fatalf("no port provided") + } + if *forward == "" { + log.Fatalf("no forward provided") + } + // Seed the random number generator to ensure that we are given MAC addresses that don't + // for the case of the client and server stack. + rand.Seed(time.Now().UTC().UnixNano()) + + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + log.Fatal("could not create CPU profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing CPU profile: ", err) + } + }() + if err := pprof.StartCPUProfile(f); err != nil { + log.Fatal("could not start CPU profile: ", err) + } + defer pprof.StopCPUProfile() + } + + var ( + in impl + out impl + err error + ) + if *server { + in, err = newNetstackImpl("server") + if *serverTCPProbeFile != "" { + defer in.(netstackImpl).installProbe(*serverTCPProbeFile)() + } + + } else { + in = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + if *client { + out, err = newNetstackImpl("client") + if *clientTCPProbeFile != "" { + defer out.(netstackImpl).installProbe(*clientTCPProbeFile)() + } + } else { + out = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + + // Dial forward before binding. + var next net.Conn + for { + next, err = out.dial(*forward) + if err == nil { + break + } + time.Sleep(50 * time.Millisecond) + log.Printf("connect failed retrying: %v", err) + } + + // Bind once to the server socket. + listener, err := in.listen(*port) + if err != nil { + // Should not happen, everything must be bound by this time + // this proxy is started. + log.Fatalf("unable to listen: %v", err) + } + log.Printf("client=%v, server=%v, ready.", *client, *server) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + go func() { + <-sigs + if *cpuprofile != "" { + pprof.StopCPUProfile() + } + if *memprofile != "" { + f, err := os.Create(*memprofile) + if err != nil { + log.Fatal("could not create memory profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing memory profile: ", err) + } + }() + runtime.GC() // get up-to-date statistics + if err := pprof.WriteHeapProfile(f); err != nil { + log.Fatalf("Unable to write heap profile: %v", err) + } + } + os.Exit(0) + }() + + for { + // Forward all connections. + inConn, err := listener.Accept() + if err != nil { + // This should not happen; we are listening + // successfully. Exhausted all available FDs? + log.Fatalf("accept error: %v", err) + } + log.Printf("incoming connection established.") + + // Copy both ways. + go io.Copy(inConn, next) + go io.Copy(next, inConn) + + // Print stats every second. + go func() { + t := time.NewTicker(time.Second) + defer t.Stop() + for { + <-t.C + in.printStats() + out.printStats() + } + }() + + for { + // Dial again. + next, err = out.dial(*forward) + if err == nil { + break + } + } + } +} diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD new file mode 100644 index 000000000..643806105 --- /dev/null +++ b/benchmarks/workloads/BUILD @@ -0,0 +1,35 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "workloads", + srcs = ["__init__.py"], +) + +filegroup( + name = "files", + srcs = [ + "//benchmarks/workloads/ab:files", + "//benchmarks/workloads/absl:files", + "//benchmarks/workloads/curl:files", + "//benchmarks/workloads/ffmpeg:files", + "//benchmarks/workloads/fio:files", + "//benchmarks/workloads/httpd:files", + "//benchmarks/workloads/iperf:files", + "//benchmarks/workloads/netcat:files", + "//benchmarks/workloads/nginx:files", + "//benchmarks/workloads/node:files", + "//benchmarks/workloads/node_template:files", + "//benchmarks/workloads/redis:files", + "//benchmarks/workloads/redisbenchmark:files", + "//benchmarks/workloads/ruby:files", + "//benchmarks/workloads/ruby_template:files", + "//benchmarks/workloads/sleep:files", + "//benchmarks/workloads/sysbench:files", + "//benchmarks/workloads/syscall:files", + "//benchmarks/workloads/tensorflow:files", + "//benchmarks/workloads/true:files", + ], +) diff --git a/benchmarks/workloads/__init__.py b/benchmarks/workloads/__init__.py new file mode 100644 index 000000000..e12651e76 --- /dev/null +++ b/benchmarks/workloads/__init__.py @@ -0,0 +1,14 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Workloads, parsers and test data.""" diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD new file mode 100644 index 000000000..e99a8d674 --- /dev/null +++ b/benchmarks/workloads/ab/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "ab", + srcs = ["__init__.py"], +) + +py_test( + name = "ab_test", + srcs = ["ab_test.py"], + python_version = "PY3", + deps = [ + ":ab", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/ab/Dockerfile b/benchmarks/workloads/ab/Dockerfile new file mode 100644 index 000000000..0d0b6e2eb --- /dev/null +++ b/benchmarks/workloads/ab/Dockerfile @@ -0,0 +1,15 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2-utils \ + && rm -rf /var/lib/apt/lists/* + +# Parameterized workload. +ENV requests 5000 +ENV connections 10 +ENV host localhost +ENV port 8080 +ENV path notfound +CMD ["sh", "-c", "ab -n ${requests} -c ${connections} http://${host}:${port}/${path}"] diff --git a/benchmarks/workloads/ab/__init__.py b/benchmarks/workloads/ab/__init__.py new file mode 100644 index 000000000..eedf8e083 --- /dev/null +++ b/benchmarks/workloads/ab/__init__.py @@ -0,0 +1,88 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Apachebench tool.""" + +import re + +SAMPLE_DATA = """This is ApacheBench, Version 2.3 <$Revision: 1826891 $> +Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ +Licensed to The Apache Software Foundation, http://www.apache.org/ + +Benchmarking 10.10.10.10 (be patient).....done + + +Server Software: Apache/2.4.38 +Server Hostname: 10.10.10.10 +Server Port: 80 + +Document Path: /latin10k.txt +Document Length: 210 bytes + +Concurrency Level: 1 +Time taken for tests: 0.180 seconds +Complete requests: 100 +Failed requests: 0 +Non-2xx responses: 100 +Total transferred: 38800 bytes +HTML transferred: 21000 bytes +Requests per second: 556.44 [#/sec] (mean) +Time per request: 1.797 [ms] (mean) +Time per request: 1.797 [ms] (mean, across all concurrent requests) +Transfer rate: 210.84 [Kbytes/sec] received + +Connection Times (ms) + min mean[+/-sd] median max +Connect: 0 0 0.2 0 2 +Processing: 1 2 1.0 1 8 +Waiting: 1 1 1.0 1 7 +Total: 1 2 1.2 1 10 + +Percentage of the requests served within a certain time (ms) + 50% 1 + 66% 2 + 75% 2 + 80% 2 + 90% 2 + 95% 3 + 98% 7 + 99% 10 + 100% 10 (longest request)""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def transfer_rate(data: str, **kwargs) -> float: + """Mean transfer rate in Kbytes/sec.""" + regex = r"Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received" + return float(re.compile(regex).search(data).group(1)) + + +# pylint: disable=unused-argument +def latency(data: str, **kwargs) -> float: + """Mean latency in milliseconds.""" + regex = r"Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s" + res = re.compile(regex).search(data) + return float(res.group(1)) + + +# pylint: disable=unused-argument +def requests_per_second(data: str, **kwargs) -> float: + """Requests per second.""" + regex = r"Requests per second:\s+(\d+\.?\d+?)\s+" + res = re.compile(regex).search(data) + return float(res.group(1)) diff --git a/benchmarks/workloads/ab/ab_test.py b/benchmarks/workloads/ab/ab_test.py new file mode 100644 index 000000000..4afac2996 --- /dev/null +++ b/benchmarks/workloads/ab/ab_test.py @@ -0,0 +1,42 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser test.""" + +import sys + +import pytest + +from benchmarks.workloads import ab + + +def test_transfer_rate_parser(): + """Test transfer rate parser.""" + res = ab.transfer_rate(ab.sample()) + assert res == 210.84 + + +def test_latency_parser(): + """Test latency parser.""" + res = ab.latency(ab.sample()) + assert res == 2 + + +def test_requests_per_second(): + """Test requests per second parser.""" + res = ab.requests_per_second(ab.sample()) + assert res == 556.44 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD new file mode 100644 index 000000000..bb499620e --- /dev/null +++ b/benchmarks/workloads/absl/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "absl", + srcs = ["__init__.py"], +) + +py_test( + name = "absl_test", + srcs = ["absl_test.py"], + python_version = "PY3", + deps = [ + ":absl", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/absl/Dockerfile b/benchmarks/workloads/absl/Dockerfile new file mode 100644 index 000000000..e935c5ddc --- /dev/null +++ b/benchmarks/workloads/absl/Dockerfile @@ -0,0 +1,24 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + git \ + pkg-config \ + zip \ + g++ \ + zlib1g-dev \ + unzip \ + python3 \ + && rm -rf /var/lib/apt/lists/* +RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh +RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh +RUN ./bazel-0.27.0-installer-linux-x86_64.sh + +RUN git clone https://github.com/abseil/abseil-cpp.git +WORKDIR abseil-cpp +RUN git checkout 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f +RUN bazel clean +ENV path "absl/base/..." +CMD bazel build ${path} 2>&1 diff --git a/benchmarks/workloads/absl/__init__.py b/benchmarks/workloads/absl/__init__.py new file mode 100644 index 000000000..b40e3f915 --- /dev/null +++ b/benchmarks/workloads/absl/__init__.py @@ -0,0 +1,63 @@ +# python3 +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ABSL build benchmark.""" + +import re + +SAMPLE_BAZEL_OUTPUT = """Extracting Bazel installation... +Starting local Bazel server and connecting to it... +Loading: +Loading: 0 packages loaded +Loading: 0 packages loaded + currently loading: absl/algorithm ... (11 packages) +Analyzing: 241 targets (16 packages loaded, 0 targets configured) +Analyzing: 241 targets (21 packages loaded, 617 targets configured) +Analyzing: 241 targets (27 packages loaded, 687 targets configured) +Analyzing: 241 targets (32 packages loaded, 1105 targets configured) +Analyzing: 241 targets (32 packages loaded, 1294 targets configured) +Analyzing: 241 targets (35 packages loaded, 1575 targets configured) +Analyzing: 241 targets (35 packages loaded, 1575 targets configured) +Analyzing: 241 targets (36 packages loaded, 1603 targets configured) +Analyzing: 241 targets (36 packages loaded, 1603 targets configured) +INFO: Analyzed 241 targets (37 packages loaded, 1864 targets configured). +INFO: Found 241 targets... +[0 / 5] [Prepa] BazelWorkspaceStatusAction stable-status.txt +[16 / 50] [Analy] Compiling absl/base/dynamic_annotations.cc ... (20 actions, 10 running) +[60 / 77] Compiling external/com_google_googletest/googletest/src/gtest.cc; 5s processwrapper-sandbox ... (12 actions, 11 running) +[158 / 174] Compiling absl/container/internal/raw_hash_set_test.cc; 2s processwrapper-sandbox ... (12 actions, 11 running) +[278 / 302] Compiling absl/container/internal/raw_hash_set_test.cc; 6s processwrapper-sandbox ... (12 actions, 11 running) +[384 / 406] Compiling absl/container/internal/raw_hash_set_test.cc; 10s processwrapper-sandbox ... (12 actions, 11 running) +[581 / 604] Compiling absl/container/flat_hash_set_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) +[722 / 745] Compiling absl/container/node_hash_set_test.cc; 9s processwrapper-sandbox ... (12 actions, 11 running) +[846 / 867] Compiling absl/hash/hash_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) +INFO: From Compiling absl/debugging/symbolize_test.cc: +/tmp/cclCVipU.s: Assembler messages: +/tmp/cclCVipU.s:1662: Warning: ignoring changed section attributes for .text +[999 / 1,022] Compiling absl/hash/hash_test.cc; 19s processwrapper-sandbox ... (12 actions, 11 running) +[1,082 / 1,084] Compiling absl/container/flat_hash_map_test.cc; 7s processwrapper-sandbox +INFO: Elapsed time: 81.861s, Critical Path: 23.81s +INFO: 515 processes: 515 processwrapper-sandbox. +INFO: Build completed successfully, 1084 total actions +INFO: Build completed successfully, 1084 total actions""" + + +def sample(): + return SAMPLE_BAZEL_OUTPUT + + +# pylint: disable=unused-argument +def elapsed_time(data: str, **kwargs) -> float: + """Returns the elapsed time for running an absl build.""" + return float(re.compile(r"Elapsed time: (\d*.?\d*)s").search(data).group(1)) diff --git a/benchmarks/workloads/absl/absl_test.py b/benchmarks/workloads/absl/absl_test.py new file mode 100644 index 000000000..41f216999 --- /dev/null +++ b/benchmarks/workloads/absl/absl_test.py @@ -0,0 +1,31 @@ +# python3 +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ABSL build test.""" + +import sys + +import pytest + +from benchmarks.workloads import absl + + +def test_elapsed_time(): + """Test elapsed_time.""" + res = absl.elapsed_time(absl.sample()) + assert res == 81.861 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/curl/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/curl/Dockerfile b/benchmarks/workloads/curl/Dockerfile new file mode 100644 index 000000000..336cb088a --- /dev/null +++ b/benchmarks/workloads/curl/Dockerfile @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Accept a host and port parameter. +ENV host localhost +ENV port 8080 + +# Spin until we make a successful request. +CMD ["sh", "-c", "while ! curl -v -i http://$host:$port; do true; done"] diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD new file mode 100644 index 000000000..c1f2afc40 --- /dev/null +++ b/benchmarks/workloads/ffmpeg/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "ffmpeg", + srcs = ["__init__.py"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/ffmpeg/Dockerfile b/benchmarks/workloads/ffmpeg/Dockerfile new file mode 100644 index 000000000..f2f530d7c --- /dev/null +++ b/benchmarks/workloads/ffmpeg/Dockerfile @@ -0,0 +1,10 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* +WORKDIR /media +ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4 +CMD ["ffmpeg", "-i", "video.mp4", "-c:v", "libx264", "-preset", "veryslow", "output.mp4"] diff --git a/benchmarks/workloads/ffmpeg/__init__.py b/benchmarks/workloads/ffmpeg/__init__.py new file mode 100644 index 000000000..7578a443b --- /dev/null +++ b/benchmarks/workloads/ffmpeg/__init__.py @@ -0,0 +1,20 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple ffmpeg workload.""" + + +# pylint: disable=unused-argument +def run_time(value, **kwargs): + """Returns the startup and runtime of the ffmpeg workload in seconds.""" + return value diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD new file mode 100644 index 000000000..7fc96cfa5 --- /dev/null +++ b/benchmarks/workloads/fio/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "fio", + srcs = ["__init__.py"], +) + +py_test( + name = "fio_test", + srcs = ["fio_test.py"], + python_version = "PY3", + deps = [ + ":fio", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/fio/Dockerfile b/benchmarks/workloads/fio/Dockerfile new file mode 100644 index 000000000..b3cf864eb --- /dev/null +++ b/benchmarks/workloads/fio/Dockerfile @@ -0,0 +1,23 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + fio \ + && rm -rf /var/lib/apt/lists/* + +# Parameterized test. +ENV test write +ENV ioengine sync +ENV size 5000000 +ENV iodepth 4 +ENV blocksize "1m" +ENV time "" +ENV path "/disk/file.dat" +ENV ramp_time 0 + +CMD ["sh", "-c", "fio --output-format=json --name=test --ramp_time=${ramp_time} --ioengine=${ioengine} --size=${size} \ +--filename=${path} --iodepth=${iodepth} --bs=${blocksize} --rw=${test} ${time}"] + + + diff --git a/benchmarks/workloads/fio/__init__.py b/benchmarks/workloads/fio/__init__.py new file mode 100644 index 000000000..52711e956 --- /dev/null +++ b/benchmarks/workloads/fio/__init__.py @@ -0,0 +1,369 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FIO benchmark tool.""" + +import json + +SAMPLE_DATA = """ +{ + "fio version" : "fio-3.1", + "timestamp" : 1554837456, + "timestamp_ms" : 1554837456621, + "time" : "Tue Apr 9 19:17:36 2019", + "jobs" : [ + { + "jobname" : "test", + "groupid" : 0, + "error" : 0, + "eta" : 2147483647, + "elapsed" : 1, + "job options" : { + "name" : "test", + "ioengine" : "sync", + "size" : "1073741824", + "filename" : "/disk/file.dat", + "iodepth" : "4", + "bs" : "4096", + "rw" : "write" + }, + "read" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 0, + "iops" : 0.000000, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "drop_ios" : 0, + "slat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "clat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000, + "percentile" : { + "1.000000" : 0, + "5.000000" : 0, + "10.000000" : 0, + "20.000000" : 0, + "30.000000" : 0, + "40.000000" : 0, + "50.000000" : 0, + "60.000000" : 0, + "70.000000" : 0, + "80.000000" : 0, + "90.000000" : 0, + "95.000000" : 0, + "99.000000" : 0, + "99.500000" : 0, + "99.900000" : 0, + "99.950000" : 0, + "99.990000" : 0, + "0.00" : 0, + "0.00" : 0, + "0.00" : 0 + } + }, + "lat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "write" : { + "io_bytes" : 1073741824, + "io_kbytes" : 1048576, + "bw" : 1753471, + "iops" : 438367.892977, + "runtime" : 598, + "total_ios" : 262144, + "short_ios" : 0, + "drop_ios" : 0, + "slat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "clat_ns" : { + "min" : 1693, + "max" : 754733, + "mean" : 2076.404373, + "stddev" : 1724.195529, + "percentile" : { + "1.000000" : 1736, + "5.000000" : 1752, + "10.000000" : 1768, + "20.000000" : 1784, + "30.000000" : 1800, + "40.000000" : 1800, + "50.000000" : 1816, + "60.000000" : 1816, + "70.000000" : 1848, + "80.000000" : 1928, + "90.000000" : 2512, + "95.000000" : 2992, + "99.000000" : 6176, + "99.500000" : 6304, + "99.900000" : 11328, + "99.950000" : 15168, + "99.990000" : 17792, + "0.00" : 0, + "0.00" : 0, + "0.00" : 0 + } + }, + "lat_ns" : { + "min" : 1731, + "max" : 754770, + "mean" : 2117.878979, + "stddev" : 1730.290512 + }, + "bw_min" : 1731120, + "bw_max" : 1731120, + "bw_agg" : 98.725328, + "bw_mean" : 1731120.000000, + "bw_dev" : 0.000000, + "bw_samples" : 1, + "iops_min" : 432780, + "iops_max" : 432780, + "iops_mean" : 432780.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 1 + }, + "trim" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 0, + "iops" : 0.000000, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "drop_ios" : 0, + "slat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "clat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000, + "percentile" : { + "1.000000" : 0, + "5.000000" : 0, + "10.000000" : 0, + "20.000000" : 0, + "30.000000" : 0, + "40.000000" : 0, + "50.000000" : 0, + "60.000000" : 0, + "70.000000" : 0, + "80.000000" : 0, + "90.000000" : 0, + "95.000000" : 0, + "99.000000" : 0, + "99.500000" : 0, + "99.900000" : 0, + "99.950000" : 0, + "99.990000" : 0, + "0.00" : 0, + "0.00" : 0, + "0.00" : 0 + } + }, + "lat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "usr_cpu" : 17.922948, + "sys_cpu" : 81.574539, + "ctx" : 3, + "majf" : 0, + "minf" : 10, + "iodepth_level" : { + "1" : 100.000000, + "2" : 0.000000, + "4" : 0.000000, + "8" : 0.000000, + "16" : 0.000000, + "32" : 0.000000, + ">=64" : 0.000000 + }, + "latency_ns" : { + "2" : 0.000000, + "4" : 0.000000, + "10" : 0.000000, + "20" : 0.000000, + "50" : 0.000000, + "100" : 0.000000, + "250" : 0.000000, + "500" : 0.000000, + "750" : 0.000000, + "1000" : 0.000000 + }, + "latency_us" : { + "2" : 82.737350, + "4" : 12.605286, + "10" : 4.543686, + "20" : 0.107956, + "50" : 0.010000, + "100" : 0.000000, + "250" : 0.000000, + "500" : 0.000000, + "750" : 0.000000, + "1000" : 0.010000 + }, + "latency_ms" : { + "2" : 0.000000, + "4" : 0.000000, + "10" : 0.000000, + "20" : 0.000000, + "50" : 0.000000, + "100" : 0.000000, + "250" : 0.000000, + "500" : 0.000000, + "750" : 0.000000, + "1000" : 0.000000, + "2000" : 0.000000, + ">=2000" : 0.000000 + }, + "latency_depth" : 4, + "latency_target" : 0, + "latency_percentile" : 100.000000, + "latency_window" : 0 + } + ], + "disk_util" : [ + { + "name" : "dm-1", + "read_ios" : 0, + "write_ios" : 3, + "read_merges" : 0, + "write_merges" : 0, + "read_ticks" : 0, + "write_ticks" : 0, + "in_queue" : 0, + "util" : 0.000000, + "aggr_read_ios" : 0, + "aggr_write_ios" : 3, + "aggr_read_merges" : 0, + "aggr_write_merge" : 0, + "aggr_read_ticks" : 0, + "aggr_write_ticks" : 0, + "aggr_in_queue" : 0, + "aggr_util" : 0.000000 + }, + { + "name" : "dm-0", + "read_ios" : 0, + "write_ios" : 3, + "read_merges" : 0, + "write_merges" : 0, + "read_ticks" : 0, + "write_ticks" : 0, + "in_queue" : 0, + "util" : 0.000000, + "aggr_read_ios" : 0, + "aggr_write_ios" : 3, + "aggr_read_merges" : 0, + "aggr_write_merge" : 0, + "aggr_read_ticks" : 0, + "aggr_write_ticks" : 2, + "aggr_in_queue" : 0, + "aggr_util" : 0.000000 + }, + { + "name" : "nvme0n1", + "read_ios" : 0, + "write_ios" : 3, + "read_merges" : 0, + "write_merges" : 0, + "read_ticks" : 0, + "write_ticks" : 2, + "in_queue" : 0, + "util" : 0.000000 + } + ] +} +""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def read_bandwidth(data: str, **kwargs) -> int: + """File I/O bandwidth.""" + return json.loads(data)["jobs"][0]["read"]["bw"] * 1024 + + +# pylint: disable=unused-argument +def write_bandwidth(data: str, **kwargs) -> int: + """File I/O bandwidth.""" + return json.loads(data)["jobs"][0]["write"]["bw"] * 1024 + + +# pylint: disable=unused-argument +def read_io_ops(data: str, **kwargs) -> float: + """File I/O operations per second.""" + return float(json.loads(data)["jobs"][0]["read"]["iops"]) + + +# pylint: disable=unused-argument +def write_io_ops(data: str, **kwargs) -> float: + """File I/O operations per second.""" + return float(json.loads(data)["jobs"][0]["write"]["iops"]) + + +# Change function names so we just print "bandwidth" and "io_ops". +read_bandwidth.__name__ = "bandwidth" +write_bandwidth.__name__ = "bandwidth" +read_io_ops.__name__ = "io_ops" +write_io_ops.__name__ = "io_ops" diff --git a/benchmarks/workloads/fio/fio_test.py b/benchmarks/workloads/fio/fio_test.py new file mode 100644 index 000000000..04a6eeb7e --- /dev/null +++ b/benchmarks/workloads/fio/fio_test.py @@ -0,0 +1,44 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser tests.""" + +import sys + +import pytest + +from benchmarks.workloads import fio + + +def test_read_io_ops(): + """Test read ops parser.""" + assert fio.read_io_ops(fio.sample()) == 0.0 + + +def test_write_io_ops(): + """Test write ops parser.""" + assert fio.write_io_ops(fio.sample()) == 438367.892977 + + +def test_read_bandwidth(): + """Test read bandwidth parser.""" + assert fio.read_bandwidth(fio.sample()) == 0.0 + + +def test_write_bandwith(): + """Test write bandwidth parser.""" + assert fio.write_bandwidth(fio.sample()) == 1753471 * 1024 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/httpd/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/httpd/Dockerfile b/benchmarks/workloads/httpd/Dockerfile new file mode 100644 index 000000000..5259c8f4f --- /dev/null +++ b/benchmarks/workloads/httpd/Dockerfile @@ -0,0 +1,27 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2 \ + && rm -rf /var/lib/apt/lists/* + +# Link the htdoc directory to tmp. +RUN mkdir -p /usr/local/apache2/htdocs && \ + cd /usr/local/apache2 && ln -s /tmp htdocs + +# Generate a bunch of relevant files. +RUN mkdir -p /local && \ + for size in 1 10 100 1000 1024 10240; do \ + dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ + done + +# Standard settings. +ENV APACHE_RUN_DIR /tmp +ENV APACHE_RUN_USER nobody +ENV APACHE_RUN_GROUP nogroup +ENV APACHE_LOG_DIR /tmp +ENV APACHE_PID_FILE /tmp/apache.pid + +# Copy on start-up; serve everything from /tmp (including the configuration). +CMD ["sh", "-c", "cp -a /local/* /tmp && apache2 -c \"ServerName localhost\" -c \"DocumentRoot /tmp\" -X"] diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD new file mode 100644 index 000000000..fe0acbfce --- /dev/null +++ b/benchmarks/workloads/iperf/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "iperf", + srcs = ["__init__.py"], +) + +py_test( + name = "iperf_test", + srcs = ["iperf_test.py"], + python_version = "PY3", + deps = [ + ":iperf", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/iperf/Dockerfile b/benchmarks/workloads/iperf/Dockerfile new file mode 100644 index 000000000..9704c506c --- /dev/null +++ b/benchmarks/workloads/iperf/Dockerfile @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + iperf \ + && rm -rf /var/lib/apt/lists/* + +# Accept a host parameter. +ENV host "" +ENV port 5001 + +# Start as client if the host is provided. +CMD ["sh", "-c", "test -z \"${host}\" && iperf -s || iperf -f K --realtime -c ${host} -p ${port}"] diff --git a/benchmarks/workloads/iperf/__init__.py b/benchmarks/workloads/iperf/__init__.py new file mode 100644 index 000000000..3817a7ade --- /dev/null +++ b/benchmarks/workloads/iperf/__init__.py @@ -0,0 +1,40 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""iperf.""" + +import re + +SAMPLE_DATA = """ +------------------------------------------------------------ +Client connecting to 10.138.15.215, TCP port 32779 +TCP window size: 45.0 KByte (default) +------------------------------------------------------------ +[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 +[ ID] Interval Transfer Bandwidth +[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec + +""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def bandwidth(data: str, **kwargs) -> float: + """Calculate the bandwidth.""" + regex = r"\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec" + res = re.compile(regex).search(data) + return float(res.group(1)) * 1000 diff --git a/benchmarks/workloads/iperf/iperf_test.py b/benchmarks/workloads/iperf/iperf_test.py new file mode 100644 index 000000000..6959b7e8a --- /dev/null +++ b/benchmarks/workloads/iperf/iperf_test.py @@ -0,0 +1,28 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for iperf.""" + +import sys + +import pytest + +from benchmarks.workloads import iperf + + +def test_bandwidth(): + assert iperf.bandwidth(iperf.sample()) == 45900 * 1000 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/netcat/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/netcat/Dockerfile b/benchmarks/workloads/netcat/Dockerfile new file mode 100644 index 000000000..d8548d89a --- /dev/null +++ b/benchmarks/workloads/netcat/Dockerfile @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + netcat \ + && rm -rf /var/lib/apt/lists/* + +# Accept a host and port parameter. +ENV host localhost +ENV port 8080 + +# Spin until we make a successful request. +CMD ["sh", "-c", "while ! nc -zv $host $port; do true; done"] diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/nginx/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/nginx/Dockerfile b/benchmarks/workloads/nginx/Dockerfile new file mode 100644 index 000000000..b64eb52ae --- /dev/null +++ b/benchmarks/workloads/nginx/Dockerfile @@ -0,0 +1 @@ +FROM nginx:1.15.10 diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD new file mode 100644 index 000000000..59460d02f --- /dev/null +++ b/benchmarks/workloads/node/BUILD @@ -0,0 +1,13 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "index.js", + "package.json", + ], +) diff --git a/benchmarks/workloads/node/Dockerfile b/benchmarks/workloads/node/Dockerfile new file mode 100644 index 000000000..139a38bf5 --- /dev/null +++ b/benchmarks/workloads/node/Dockerfile @@ -0,0 +1,2 @@ +FROM node:onbuild +CMD ["node", "index.js"] diff --git a/benchmarks/workloads/node/index.js b/benchmarks/workloads/node/index.js new file mode 100644 index 000000000..584158462 --- /dev/null +++ b/benchmarks/workloads/node/index.js @@ -0,0 +1,28 @@ +'use strict'; + +var start = new Date().getTime(); + +// Load dependencies to simulate an average nodejs app. +var req_0 = require('async'); +var req_1 = require('bluebird'); +var req_2 = require('firebase'); +var req_3 = require('firebase-admin'); +var req_4 = require('@google-cloud/container'); +var req_5 = require('@google-cloud/logging'); +var req_6 = require('@google-cloud/monitoring'); +var req_7 = require('@google-cloud/spanner'); +var req_8 = require('lodash'); +var req_9 = require('mailgun-js'); +var req_10 = require('request'); +var express = require('express'); +var app = express(); + +var loaded = new Date().getTime() - start; +app.get('/', function(req, res) { + res.send('Hello World!<br>Loaded in ' + loaded + 'ms'); +}); + +console.log('Loaded in ' + loaded + ' ms'); +app.listen(8080, function() { + console.log('Listening on port 8080...'); +}); diff --git a/benchmarks/workloads/node/package.json b/benchmarks/workloads/node/package.json new file mode 100644 index 000000000..c00b9b3cb --- /dev/null +++ b/benchmarks/workloads/node/package.json @@ -0,0 +1,19 @@ +{ + "name": "node", + "version": "1.0.0", + "main": "index.js", + "dependencies": { + "@google-cloud/container": "^0.3.0", + "@google-cloud/logging": "^4.2.0", + "@google-cloud/monitoring": "^0.6.0", + "@google-cloud/spanner": "^2.2.1", + "async": "^2.6.1", + "bluebird": "^3.5.3", + "express": "^4.16.4", + "firebase": "^5.7.2", + "firebase-admin": "^6.4.0", + "lodash": "^4.17.11", + "mailgun-js": "^0.22.0", + "request": "^2.88.0" + } +} diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD new file mode 100644 index 000000000..ae7f121d3 --- /dev/null +++ b/benchmarks/workloads/node_template/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "index.hbs", + "index.js", + "package.json", + "package-lock.json", + ], +) diff --git a/benchmarks/workloads/node_template/Dockerfile b/benchmarks/workloads/node_template/Dockerfile new file mode 100644 index 000000000..7eb065d54 --- /dev/null +++ b/benchmarks/workloads/node_template/Dockerfile @@ -0,0 +1,5 @@ +FROM node:onbuild + +ENV host "127.0.0.1" + +CMD ["sh", "-c", "node index.js ${host}"] diff --git a/benchmarks/workloads/node_template/index.hbs b/benchmarks/workloads/node_template/index.hbs new file mode 100644 index 000000000..03feceb75 --- /dev/null +++ b/benchmarks/workloads/node_template/index.hbs @@ -0,0 +1,8 @@ +<!DOCTYPE html> +<html> +<body> + {{#each text}} + <p>{{this}}</p> + {{/each}} +</body> +</html> diff --git a/benchmarks/workloads/node_template/index.js b/benchmarks/workloads/node_template/index.js new file mode 100644 index 000000000..04a27f356 --- /dev/null +++ b/benchmarks/workloads/node_template/index.js @@ -0,0 +1,43 @@ +const app = require('express')(); +const path = require('path'); +const redis = require('redis'); +const srs = require('secure-random-string'); + +// The hostname is the first argument. +const host_name = process.argv[2]; + +var client = redis.createClient({host: host_name, detect_buffers: true}); + +app.set('views', __dirname); +app.set('view engine', 'hbs'); + +app.get('/', (req, res) => { + var tmp = []; + /* Pull four random keys from the redis server. */ + for (i = 0; i < 4; i++) { + client.get(Math.floor(Math.random() * (100)), function(err, reply) { + tmp.push(reply.toString()); + }); + } + + res.render('index', {text: tmp}); +}); + +/** + * Securely generate a random string. + * @param {number} len + * @return {string} + */ +function randomBody(len) { + return srs({alphanumeric: true, length: len}); +} + +/** Mutates one hundred keys randomly. */ +function generateText() { + for (i = 0; i < 100; i++) { + client.set(i, randomBody(1024)); + } +} + +generateText(); +app.listen(8080); diff --git a/benchmarks/workloads/node_template/package-lock.json b/benchmarks/workloads/node_template/package-lock.json new file mode 100644 index 000000000..580e68aa5 --- /dev/null +++ b/benchmarks/workloads/node_template/package-lock.json @@ -0,0 +1,486 @@ +{ + "name": "nodedum", + "version": "1.0.0", + "lockfileVersion": 1, + "requires": true, + "dependencies": { + "accepts": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz", + "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=", + "requires": { + "mime-types": "~2.1.18", + "negotiator": "0.6.1" + } + }, + "array-flatten": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI=" + }, + "async": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz", + "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==", + "requires": { + "lodash": "^4.17.11" + } + }, + "body-parser": { + "version": "1.18.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz", + "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=", + "requires": { + "bytes": "3.0.0", + "content-type": "~1.0.4", + "debug": "2.6.9", + "depd": "~1.1.2", + "http-errors": "~1.6.3", + "iconv-lite": "0.4.23", + "on-finished": "~2.3.0", + "qs": "6.5.2", + "raw-body": "2.3.3", + "type-is": "~1.6.16" + } + }, + "bytes": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", + "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=" + }, + "commander": { + "version": "2.20.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz", + "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==", + "optional": true + }, + "content-disposition": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz", + "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ=" + }, + "content-type": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz", + "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==" + }, + "cookie": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz", + "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s=" + }, + "cookie-signature": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", + "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw=" + }, + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "requires": { + "ms": "2.0.0" + } + }, + "depd": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", + "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=" + }, + "destroy": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz", + "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=" + }, + "double-ended-queue": { + "version": "2.1.0-0", + "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz", + "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw=" + }, + "ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=" + }, + "encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=" + }, + "escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=" + }, + "etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc=" + }, + "express": { + "version": "4.16.4", + "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz", + "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==", + "requires": { + "accepts": "~1.3.5", + "array-flatten": "1.1.1", + "body-parser": "1.18.3", + "content-disposition": "0.5.2", + "content-type": "~1.0.4", + "cookie": "0.3.1", + "cookie-signature": "1.0.6", + "debug": "2.6.9", + "depd": "~1.1.2", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "finalhandler": "1.1.1", + "fresh": "0.5.2", + "merge-descriptors": "1.0.1", + "methods": "~1.1.2", + "on-finished": "~2.3.0", + "parseurl": "~1.3.2", + "path-to-regexp": "0.1.7", + "proxy-addr": "~2.0.4", + "qs": "6.5.2", + "range-parser": "~1.2.0", + "safe-buffer": "5.1.2", + "send": "0.16.2", + "serve-static": "1.13.2", + "setprototypeof": "1.1.0", + "statuses": "~1.4.0", + "type-is": "~1.6.16", + "utils-merge": "1.0.1", + "vary": "~1.1.2" + } + }, + "finalhandler": { + "version": "1.1.1", + "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz", + "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==", + "requires": { + "debug": "2.6.9", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "on-finished": "~2.3.0", + "parseurl": "~1.3.2", + "statuses": "~1.4.0", + "unpipe": "~1.0.0" + } + }, + "foreachasync": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz", + "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY=" + }, + "forwarded": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz", + "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ=" + }, + "fresh": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", + "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=" + }, + "handlebars": { + "version": "4.0.14", + "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz", + "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==", + "requires": { + "async": "^2.5.0", + "optimist": "^0.6.1", + "source-map": "^0.6.1", + "uglify-js": "^3.1.4" + } + }, + "hbs": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz", + "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==", + "requires": { + "handlebars": "4.0.14", + "walk": "2.3.9" + } + }, + "http-errors": { + "version": "1.6.3", + "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", + "requires": { + "depd": "~1.1.2", + "inherits": "2.0.3", + "setprototypeof": "1.1.0", + "statuses": ">= 1.4.0 < 2" + } + }, + "iconv-lite": { + "version": "0.4.23", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz", + "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==", + "requires": { + "safer-buffer": ">= 2.1.2 < 3" + } + }, + "inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=" + }, + "ipaddr.js": { + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz", + "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4=" + }, + "lodash": { + "version": "4.17.15", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz", + "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A==" + }, + "media-typer": { + "version": "0.3.0", + "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", + "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=" + }, + "merge-descriptors": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", + "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E=" + }, + "methods": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", + "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=" + }, + "mime": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz", + "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ==" + }, + "mime-db": { + "version": "1.37.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz", + "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg==" + }, + "mime-types": { + "version": "2.1.21", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz", + "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==", + "requires": { + "mime-db": "~1.37.0" + } + }, + "minimist": { + "version": "0.0.10", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz", + "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8=" + }, + "ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" + }, + "negotiator": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz", + "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk=" + }, + "on-finished": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz", + "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=", + "requires": { + "ee-first": "1.1.1" + } + }, + "optimist": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz", + "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=", + "requires": { + "minimist": "~0.0.1", + "wordwrap": "~0.0.2" + } + }, + "parseurl": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz", + "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M=" + }, + "path-to-regexp": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", + "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w=" + }, + "proxy-addr": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz", + "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==", + "requires": { + "forwarded": "~0.1.2", + "ipaddr.js": "1.8.0" + } + }, + "qs": { + "version": "6.5.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz", + "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA==" + }, + "range-parser": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz", + "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4=" + }, + "raw-body": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz", + "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==", + "requires": { + "bytes": "3.0.0", + "http-errors": "1.6.3", + "iconv-lite": "0.4.23", + "unpipe": "1.0.0" + } + }, + "redis": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz", + "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==", + "requires": { + "double-ended-queue": "^2.1.0-0", + "redis-commands": "^1.2.0", + "redis-parser": "^2.6.0" + }, + "dependencies": { + "redis-commands": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz", + "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw==" + }, + "redis-parser": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", + "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" + } + } + }, + "redis-commands": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz", + "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg==" + }, + "redis-parser": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", + "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" + }, + "safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" + }, + "safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" + }, + "secure-random-string": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz", + "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg==" + }, + "send": { + "version": "0.16.2", + "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz", + "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==", + "requires": { + "debug": "2.6.9", + "depd": "~1.1.2", + "destroy": "~1.0.4", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "fresh": "0.5.2", + "http-errors": "~1.6.2", + "mime": "1.4.1", + "ms": "2.0.0", + "on-finished": "~2.3.0", + "range-parser": "~1.2.0", + "statuses": "~1.4.0" + } + }, + "serve-static": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz", + "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==", + "requires": { + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "parseurl": "~1.3.2", + "send": "0.16.2" + } + }, + "setprototypeof": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz", + "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==" + }, + "source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==" + }, + "statuses": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz", + "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew==" + }, + "type-is": { + "version": "1.6.16", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz", + "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==", + "requires": { + "media-typer": "0.3.0", + "mime-types": "~2.1.18" + } + }, + "uglify-js": { + "version": "3.5.9", + "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz", + "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==", + "optional": true, + "requires": { + "commander": "~2.20.0", + "source-map": "~0.6.1" + } + }, + "unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw=" + }, + "utils-merge": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", + "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM=" + }, + "vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=" + }, + "walk": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz", + "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=", + "requires": { + "foreachasync": "^3.0.0" + } + }, + "wordwrap": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz", + "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc=" + } + } +} diff --git a/benchmarks/workloads/node_template/package.json b/benchmarks/workloads/node_template/package.json new file mode 100644 index 000000000..7dcadd523 --- /dev/null +++ b/benchmarks/workloads/node_template/package.json @@ -0,0 +1,19 @@ +{ + "name": "nodedum", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "ISC", + "dependencies": { + "express": "^4.16.4", + "hbs": "^4.0.4", + "redis": "^2.8.0", + "redis-commands": "^1.2.0", + "redis-parser": "^2.6.0", + "secure-random-string": "^1.1.0" + } +} diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/redis/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/redis/Dockerfile b/benchmarks/workloads/redis/Dockerfile new file mode 100644 index 000000000..0f17249af --- /dev/null +++ b/benchmarks/workloads/redis/Dockerfile @@ -0,0 +1 @@ +FROM redis:5.0.4 diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD new file mode 100644 index 000000000..d40e75a3a --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "redisbenchmark", + srcs = ["__init__.py"], +) + +py_test( + name = "redisbenchmark_test", + srcs = ["redisbenchmark_test.py"], + python_version = "PY3", + deps = [ + ":redisbenchmark", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/redisbenchmark/Dockerfile b/benchmarks/workloads/redisbenchmark/Dockerfile new file mode 100644 index 000000000..f94f6442e --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/Dockerfile @@ -0,0 +1,4 @@ +FROM redis:5.0.4 +ENV host localhost +ENV port 6379 +CMD ["sh", "-c", "redis-benchmark --csv -h ${host} -p ${port} ${flags}"] diff --git a/benchmarks/workloads/redisbenchmark/__init__.py b/benchmarks/workloads/redisbenchmark/__init__.py new file mode 100644 index 000000000..229cef5fa --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/__init__.py @@ -0,0 +1,85 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Redis-benchmark tool.""" + +import re + +OPERATIONS = [ + "PING_INLINE", + "PING_BULK", + "SET", + "GET", + "INCR", + "LPUSH", + "RPUSH", + "LPOP", + "RPOP", + "SADD", + "HSET", + "SPOP", + "LRANGE_100", + "LRANGE_300", + "LRANGE_500", + "LRANGE_600", + "MSET", +] + +METRICS = dict() + +SAMPLE_DATA = """ +"PING_INLINE","48661.80" +"PING_BULK","50301.81" +"SET","48923.68" +"GET","49382.71" +"INCR","49975.02" +"LPUSH","49875.31" +"RPUSH","50276.52" +"LPOP","50327.12" +"RPOP","50556.12" +"SADD","49504.95" +"HSET","49504.95" +"SPOP","50025.02" +"LPUSH (needed to benchmark LRANGE)","48875.86" +"LRANGE_100 (first 100 elements)","33955.86" +"LRANGE_300 (first 300 elements)","16550.81" +"LRANGE_500 (first 450 elements)","13653.74" +"LRANGE_600 (first 600 elements)","11219.57" +"MSET (10 keys)","44682.75" +""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# Bind a metric for each operation noted above. +for op in OPERATIONS: + + def bind(metric): + """Bind op to a new scope.""" + + # pylint: disable=unused-argument + def parse(data: str, **kwargs) -> float: + """Operation throughput in requests/sec.""" + regex = r"\"" + metric + r"( .*)?\",\"(\d*.\d*)" + res = re.compile(regex).search(data) + if res: + return float(res.group(2)) + return 0.0 + + parse.__name__ = metric + return parse + + METRICS[op] = bind(op) diff --git a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py new file mode 100644 index 000000000..419ced059 --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py @@ -0,0 +1,51 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser test.""" + +import sys + +import pytest + +from benchmarks.workloads import redisbenchmark + +RESULTS = { + "PING_INLINE": 48661.80, + "PING_BULK": 50301.81, + "SET": 48923.68, + "GET": 49382.71, + "INCR": 49975.02, + "LPUSH": 49875.31, + "RPUSH": 50276.52, + "LPOP": 50327.12, + "RPOP": 50556.12, + "SADD": 49504.95, + "HSET": 49504.95, + "SPOP": 50025.02, + "LRANGE_100": 33955.86, + "LRANGE_300": 16550.81, + "LRANGE_500": 13653.74, + "LRANGE_600": 11219.57, + "MSET": 44682.75 +} + + +def test_metrics(): + """Test all metrics.""" + for (metric, func) in redisbenchmark.METRICS.items(): + res = func(redisbenchmark.sample()) + assert float(res) == RESULTS[metric] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD new file mode 100644 index 000000000..9846c7e70 --- /dev/null +++ b/benchmarks/workloads/ruby/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "Gemfile", + "Gemfile.lock", + "config.ru", + "index.rb", + ], +) diff --git a/benchmarks/workloads/ruby/Dockerfile b/benchmarks/workloads/ruby/Dockerfile new file mode 100644 index 000000000..a9a7a7086 --- /dev/null +++ b/benchmarks/workloads/ruby/Dockerfile @@ -0,0 +1,28 @@ +# example based on https://github.com/errm/fib + +FROM ruby:2.5 + +RUN apt-get update -qq && apt-get install -y build-essential libpq-dev nodejs libsodium-dev + +# Set an environment variable where the Rails app is installed to inside of Docker image +ENV RAILS_ROOT /var/www/app_name +RUN mkdir -p $RAILS_ROOT + +# Set working directory +WORKDIR $RAILS_ROOT + +# Setting env up +ENV RAILS_ENV='production' +ENV RACK_ENV='production' + +# Adding gems +COPY Gemfile Gemfile +COPY Gemfile.lock Gemfile.lock +RUN bundle install --jobs 20 --retry 5 --without development test + +# Adding project files +COPY . . + +EXPOSE $PORT +STOPSIGNAL SIGINT +CMD ["bundle", "exec", "puma", "config.ru"] diff --git a/benchmarks/workloads/ruby/Gemfile b/benchmarks/workloads/ruby/Gemfile new file mode 100644 index 000000000..8f1bdad6e --- /dev/null +++ b/benchmarks/workloads/ruby/Gemfile @@ -0,0 +1,12 @@ +source "https://rubygems.org" +# load a bunch of dependencies to take up memory +gem "sinatra" +gem "puma" +gem "redis" +gem 'rake' +gem 'squid', '~> 1.4' +gem 'cassandra-driver' +gem 'ruby-fann' +gem 'rbnacl' +gem 'bcrypt' +gem "activemerchant"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock new file mode 100644 index 000000000..b44817bd3 --- /dev/null +++ b/benchmarks/workloads/ruby/Gemfile.lock @@ -0,0 +1,55 @@ +GEM + remote: https://rubygems.org/ + specs: + activesupport (5.2.3) + concurrent-ruby (~> 1.0, >= 1.0.2) + i18n (>= 0.7, < 2) + minitest (~> 5.1) + tzinfo (~> 1.1) + cassandra-driver (3.2.3) + ione (~> 1.2) + concurrent-ruby (1.1.5) + i18n (1.6.0) + concurrent-ruby (~> 1.0) + ione (1.2.4) + minitest (5.11.3) + mustermann (1.0.3) + pdf-core (0.7.0) + prawn (2.2.2) + pdf-core (~> 0.7.0) + ttfunk (~> 1.5) + puma (3.12.1) + rack (2.0.7) + rack-protection (2.0.5) + rack + rake (12.3.2) + redis (4.1.1) + ruby-fann (1.2.6) + sinatra (2.0.5) + mustermann (~> 1.0) + rack (~> 2.0) + rack-protection (= 2.0.5) + tilt (~> 2.0) + squid (1.4.1) + activesupport (>= 4.0) + prawn (~> 2.2) + thread_safe (0.3.6) + tilt (2.0.9) + ttfunk (1.5.1) + tzinfo (1.2.5) + thread_safe (~> 0.1) + +PLATFORMS + ruby + +DEPENDENCIES + cassandra-driver + puma + rake + redis + ruby-fann + sinatra + squid (~> 1.4) + +BUNDLED WITH + 1.17.1 diff --git a/benchmarks/workloads/ruby/config.ru b/benchmarks/workloads/ruby/config.ru new file mode 100755 index 000000000..fbd5acc82 --- /dev/null +++ b/benchmarks/workloads/ruby/config.ru @@ -0,0 +1,2 @@ +require './index' +run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/index.rb b/benchmarks/workloads/ruby/index.rb new file mode 100755 index 000000000..5fa85af93 --- /dev/null +++ b/benchmarks/workloads/ruby/index.rb @@ -0,0 +1,14 @@ +require "sinatra" +require "puma" +require "redis" +require "rake" +require "squid" +require "cassandra" +require "ruby-fann" +require "rbnacl" +require "bcrypt" +require "activemerchant" + +get "/" do + "Hello World!" +end
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD new file mode 100644 index 000000000..2b99892af --- /dev/null +++ b/benchmarks/workloads/ruby_template/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "Gemfile", + "Gemfile.lock", + "config.ru", + "index.erb", + "main.rb", + ], +) diff --git a/benchmarks/workloads/ruby_template/Dockerfile b/benchmarks/workloads/ruby_template/Dockerfile new file mode 100755 index 000000000..a06d68bf4 --- /dev/null +++ b/benchmarks/workloads/ruby_template/Dockerfile @@ -0,0 +1,38 @@ +# example based on https://github.com/errm/fib + +FROM alpine:3.9 as build + +COPY Gemfile Gemfile.lock ./ + +RUN apk add --no-cache ruby ruby-dev ruby-bundler ruby-json build-base bash \ + && bundle install --frozen -j4 -r3 --no-cache --without development \ + && apk del --no-cache ruby-bundler \ + && rm -rf /usr/lib/ruby/gems/*/cache + +FROM alpine:3.9 as prod + +COPY --from=build /usr/lib/ruby/gems /usr/lib/ruby/gems +RUN apk add --no-cache ruby ruby-json ruby-etc redis apache2-utils \ + && ruby -e "Gem::Specification.map.each do |spec| \ + Gem::Installer.for_spec( \ + spec, \ + wrappers: true, \ + force: true, \ + install_dir: spec.base_dir, \ + build_args: spec.build_args, \ + ).generate_bin \ + end" + +WORKDIR /app +COPY . /app/. + +ENV PORT=9292 \ + WEB_CONCURRENCY=20 \ + WEB_MAX_THREADS=20 \ + RACK_ENV=production + +ENV host localhost +EXPOSE $PORT +USER nobody +STOPSIGNAL SIGINT +CMD ["sh", "-c", "/usr/bin/puma", "${host}"] diff --git a/benchmarks/workloads/ruby_template/Gemfile b/benchmarks/workloads/ruby_template/Gemfile new file mode 100755 index 000000000..ac521b32c --- /dev/null +++ b/benchmarks/workloads/ruby_template/Gemfile @@ -0,0 +1,5 @@ +source "https://rubygems.org" + +gem "sinatra" +gem "puma" +gem "redis"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock new file mode 100644 index 000000000..dd8d56fb7 --- /dev/null +++ b/benchmarks/workloads/ruby_template/Gemfile.lock @@ -0,0 +1,26 @@ +GEM + remote: https://rubygems.org/ + specs: + mustermann (1.0.3) + puma (3.12.0) + rack (2.0.6) + rack-protection (2.0.5) + rack + sinatra (2.0.5) + mustermann (~> 1.0) + rack (~> 2.0) + rack-protection (= 2.0.5) + tilt (~> 2.0) + tilt (2.0.9) + redis (4.1.0) + +PLATFORMS + ruby + +DEPENDENCIES + puma + sinatra + redis + +BUNDLED WITH + 1.17.1
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/config.ru b/benchmarks/workloads/ruby_template/config.ru new file mode 100755 index 000000000..b2d135cc0 --- /dev/null +++ b/benchmarks/workloads/ruby_template/config.ru @@ -0,0 +1,2 @@ +require './main' +run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/index.erb b/benchmarks/workloads/ruby_template/index.erb new file mode 100755 index 000000000..7f7300e80 --- /dev/null +++ b/benchmarks/workloads/ruby_template/index.erb @@ -0,0 +1,8 @@ +<!DOCTYPE html> +<html> +<body> + <% text.each do |t| %> + <p><%= t %></p> + <% end %> +</body> +</html> diff --git a/benchmarks/workloads/ruby_template/main.rb b/benchmarks/workloads/ruby_template/main.rb new file mode 100755 index 000000000..35c239377 --- /dev/null +++ b/benchmarks/workloads/ruby_template/main.rb @@ -0,0 +1,27 @@ +require "sinatra" +require "securerandom" +require "redis" + +redis_host = ENV["host"] +$redis = Redis.new(host: redis_host) + +def generateText + for i in 0..99 + $redis.set(i, randomBody(1024)) + end +end + +def randomBody(length) + return SecureRandom.alphanumeric(length) +end + +generateText +template = ERB.new(File.read('./index.erb')) + +get "/" do + texts = Array.new + for i in 0..4 + texts.push($redis.get(rand(0..99))) + end + template.result_with_hash(text: texts) +end
\ No newline at end of file diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/sleep/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/sleep/Dockerfile b/benchmarks/workloads/sleep/Dockerfile new file mode 100644 index 000000000..24c72e07a --- /dev/null +++ b/benchmarks/workloads/sleep/Dockerfile @@ -0,0 +1,3 @@ +FROM alpine:latest + +CMD ["sleep", "315360000"] diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD new file mode 100644 index 000000000..35f4d460b --- /dev/null +++ b/benchmarks/workloads/sysbench/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "sysbench", + srcs = ["__init__.py"], +) + +py_test( + name = "sysbench_test", + srcs = ["sysbench_test.py"], + python_version = "PY3", + deps = [ + ":sysbench", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/sysbench/Dockerfile b/benchmarks/workloads/sysbench/Dockerfile new file mode 100644 index 000000000..8225e0e14 --- /dev/null +++ b/benchmarks/workloads/sysbench/Dockerfile @@ -0,0 +1,16 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + sysbench \ + && rm -rf /var/lib/apt/lists/* + +# Parameterize the tests. +ENV test cpu +ENV threads 1 +ENV options "" + +# run sysbench once as a warm-up and take the second result +CMD ["sh", "-c", "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null && \ +sysbench --threads=${threads} ${options} ${test} run"] diff --git a/benchmarks/workloads/sysbench/__init__.py b/benchmarks/workloads/sysbench/__init__.py new file mode 100644 index 000000000..de357b4db --- /dev/null +++ b/benchmarks/workloads/sysbench/__init__.py @@ -0,0 +1,167 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sysbench.""" + +import re + +STD_REGEX = r"events per second:\s*(\d*.?\d*)\n" +MEM_REGEX = r"Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)" +ALT_REGEX = r"execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)" +AVG_REGEX = r"avg:[^\n^\d]*(\d*\.?\d*)" + +SAMPLE_CPU_DATA = """ +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Prime numbers limit: 10000 + +Initializing worker threads... + +Threads started! + +CPU speed: + events per second: 9093.38 + +General statistics: + total time: 10.0007s + total number of events: 90949 + +Latency (ms): + min: 0.64 + avg: 0.88 + max: 24.65 + 95th percentile: 1.55 + sum: 79936.91 + +Threads fairness: + events (avg/stddev): 11368.6250/831.38 + execution time (avg/stddev): 9.9921/0.01 +""" + +SAMPLE_MEMORY_DATA = """ +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Running memory speed test with the following options: + block size: 1KiB + total size: 102400MiB + operation: write + scope: global + +Initializing worker threads... + +Threads started! + +Total operations: 47999046 (9597428.64 per second) + +46874.07 MiB transferred (9372.49 MiB/sec) + + +General statistics: + total time: 5.0001s + total number of events: 47999046 + +Latency (ms): + min: 0.00 + avg: 0.00 + max: 0.21 + 95th percentile: 0.00 + sum: 33165.91 + +Threads fairness: + events (avg/stddev): 5999880.7500/111242.52 + execution time (avg/stddev): 4.1457/0.09 +""" + +SAMPLE_MUTEX_DATA = """ +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Initializing worker threads... + +Threads started! + + +General statistics: + total time: 3.7869s + total number of events: 8 + +Latency (ms): + min: 3688.56 + avg: 3754.03 + max: 3780.94 + 95th percentile: 3773.42 + sum: 30032.28 + +Threads fairness: + events (avg/stddev): 1.0000/0.00 + execution time (avg/stddev): 3.7540/0.03 +""" + + +# pylint: disable=unused-argument +def sample(test, **kwargs): + switch = { + "cpu": SAMPLE_CPU_DATA, + "memory": SAMPLE_MEMORY_DATA, + "mutex": SAMPLE_MUTEX_DATA, + "randwr": SAMPLE_CPU_DATA + } + return switch[test] + + +# pylint: disable=unused-argument +def cpu_events_per_second(data: str, **kwargs) -> float: + """Returns events per second.""" + return float(re.compile(STD_REGEX).search(data).group(1)) + + +# pylint: disable=unused-argument +def memory_ops_per_second(data: str, **kwargs) -> float: + """Returns memory operations per second.""" + return float(re.compile(MEM_REGEX).search(data).group(1)) + + +# pylint: disable=unused-argument +def mutex_time(data: str, count: int, locks: int, threads: int, + **kwargs) -> float: + """Returns normalized mutex time (lower is better).""" + value = float(re.compile(ALT_REGEX).search(data).group(1)) + contention = float(threads) / float(locks) + scale = contention * float(count) / 100000000.0 + return value / scale + + +# pylint: disable=unused-argument +def mutex_deviation(data: str, **kwargs) -> float: + """Returns deviation for threads.""" + return float(re.compile(ALT_REGEX).search(data).group(2)) + + +# pylint: disable=unused-argument +def mutex_latency(data: str, **kwargs) -> float: + """Returns average mutex latency.""" + return float(re.compile(AVG_REGEX).search(data).group(1)) diff --git a/benchmarks/workloads/sysbench/sysbench_test.py b/benchmarks/workloads/sysbench/sysbench_test.py new file mode 100644 index 000000000..3fb541fd2 --- /dev/null +++ b/benchmarks/workloads/sysbench/sysbench_test.py @@ -0,0 +1,34 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser test.""" + +import sys + +import pytest + +from benchmarks.workloads import sysbench + + +def test_sysbench_parser(): + """Test the basic parser.""" + assert sysbench.cpu_events_per_second(sysbench.sample("cpu")) == 9093.38 + assert sysbench.memory_ops_per_second(sysbench.sample("memory")) == 9597428.64 + assert sysbench.mutex_time(sysbench.sample("mutex"), 1, 1, + 100000000.0) == 3.754 + assert sysbench.mutex_deviation(sysbench.sample("mutex")) == 0.03 + assert sysbench.mutex_latency(sysbench.sample("mutex")) == 3754.03 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD new file mode 100644 index 000000000..e1ff3059b --- /dev/null +++ b/benchmarks/workloads/syscall/BUILD @@ -0,0 +1,36 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "syscall", + srcs = ["__init__.py"], +) + +py_test( + name = "syscall_test", + srcs = ["syscall_test.py"], + python_version = "PY3", + deps = [ + ":syscall", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "syscall.c", + ], +) diff --git a/benchmarks/workloads/syscall/Dockerfile b/benchmarks/workloads/syscall/Dockerfile new file mode 100644 index 000000000..a2088d953 --- /dev/null +++ b/benchmarks/workloads/syscall/Dockerfile @@ -0,0 +1,6 @@ +FROM gcc:latest +COPY . /usr/src/syscall +WORKDIR /usr/src/syscall +RUN gcc -O2 -o syscall syscall.c +ENV count 1000000 +CMD ["sh", "-c", "./syscall ${count}"] diff --git a/benchmarks/workloads/syscall/__init__.py b/benchmarks/workloads/syscall/__init__.py new file mode 100644 index 000000000..dc9028faa --- /dev/null +++ b/benchmarks/workloads/syscall/__init__.py @@ -0,0 +1,29 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple syscall test.""" + +import re + +SAMPLE_DATA = "Called getpid syscall 1000000 times: 1117 ms, 500 ns each." + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def syscall_time_ns(data: str, **kwargs) -> int: + """Returns average system call time.""" + return float(re.compile(r"(\d+)\sns each.").search(data).group(1)) diff --git a/benchmarks/workloads/syscall/syscall.c b/benchmarks/workloads/syscall/syscall.c new file mode 100644 index 000000000..ded030397 --- /dev/null +++ b/benchmarks/workloads/syscall/syscall.c @@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define _GNU_SOURCE +#include <stdio.h> +#include <stdlib.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <time.h> +#include <unistd.h> + +// Short program that calls getpid() a number of times and outputs time +// diference from the MONOTONIC clock. +int main(int argc, char** argv) { + struct timespec start, stop; + long result; + char buf[80]; + + if (argc < 2) { + printf("Usage:./syscall NUM_TIMES_TO_CALL"); + return 1; + } + + if (clock_gettime(CLOCK_MONOTONIC, &start)) return 1; + + long loops = atoi(argv[1]); + for (long i = 0; i < loops; i++) { + syscall(SYS_gettimeofday, 0, 0); + } + + if (clock_gettime(CLOCK_MONOTONIC, &stop)) return 1; + + if ((stop.tv_nsec - start.tv_nsec) < 0) { + result = (stop.tv_sec - start.tv_sec - 1) * 1000; + result += (stop.tv_nsec - start.tv_nsec + 1000000000) / (1000 * 1000); + } else { + result = (stop.tv_sec - start.tv_sec) * 1000; + result += (stop.tv_nsec - start.tv_nsec) / (1000 * 1000); + } + + printf("Called getpid syscall %d times: %lu ms, %lu ns each.\n", loops, + result, result * 1000000 / loops); + + return 0; +} diff --git a/benchmarks/workloads/syscall/syscall_test.py b/benchmarks/workloads/syscall/syscall_test.py new file mode 100644 index 000000000..72f027de1 --- /dev/null +++ b/benchmarks/workloads/syscall/syscall_test.py @@ -0,0 +1,27 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +from benchmarks.workloads import syscall + + +def test_syscall_time_ns(): + assert syscall.syscall_time_ns(syscall.sample()) == 500 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD new file mode 100644 index 000000000..17f1f8ebb --- /dev/null +++ b/benchmarks/workloads/tensorflow/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "tensorflow", + srcs = ["__init__.py"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/benchmarks/workloads/tensorflow/Dockerfile new file mode 100644 index 000000000..262643b98 --- /dev/null +++ b/benchmarks/workloads/tensorflow/Dockerfile @@ -0,0 +1,14 @@ +FROM tensorflow/tensorflow:1.13.2 + +RUN apt-get update \ + && apt-get install -y git +RUN git clone https://github.com/aymericdamien/TensorFlow-Examples.git +RUN python -m pip install -U pip setuptools +RUN python -m pip install matplotlib + +WORKDIR /TensorFlow-Examples/examples + +ENV PYTHONPATH="$PYTHONPATH:/TensorFlow-Examples/examples" + +ENV workload "3_NeuralNetworks/convolutional_network.py" +CMD python ${workload} diff --git a/benchmarks/workloads/tensorflow/__init__.py b/benchmarks/workloads/tensorflow/__init__.py new file mode 100644 index 000000000..b5ec213f8 --- /dev/null +++ b/benchmarks/workloads/tensorflow/__init__.py @@ -0,0 +1,20 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Tensorflow example.""" + + +# pylint: disable=unused-argument +def run_time(value, **kwargs): + """Returns the startup and runtime of the Tensorflow workload in seconds.""" + return value diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/true/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/true/Dockerfile b/benchmarks/workloads/true/Dockerfile new file mode 100644 index 000000000..2e97c921e --- /dev/null +++ b/benchmarks/workloads/true/Dockerfile @@ -0,0 +1,3 @@ +FROM alpine:latest + +CMD ["true"] @@ -1,21 +1,21 @@ module gvisor.dev/gvisor -go 1.12 +go 1.13 require ( - github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 - github.com/golang/mock v1.3.1 - github.com/golang/protobuf v1.3.1 - github.com/google/btree v1.0.0 - github.com/google/go-cmp v0.2.0 - github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 - github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d - github.com/kr/pty v1.1.1 - github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 - github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 - github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e - github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 - golang.org/x/net v0.0.0-20190311183353-d8887717615a - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 + github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 + github.com/golang/mock v1.3.1 + github.com/golang/protobuf v1.3.1 + github.com/google/btree v1.0.0 + github.com/google/go-cmp v0.2.0 + github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 + github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d + github.com/kr/pty v1.1.1 + github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 + github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 + github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e + github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 + golang.org/x/net v0.0.0-20190311183353-d8887717615a + golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a ) diff --git a/kokoro/build.cfg b/kokoro/build.cfg index 6c1d262d4..c9ceda947 100644 --- a/kokoro/build.cfg +++ b/kokoro/build.cfg @@ -19,5 +19,6 @@ action { regex: "**/runsc" regex: "**/runsc.*" regex: "**/dists/**" + regex: "**/pool/**" } } diff --git a/kokoro/iptables_tests.cfg b/kokoro/iptables_tests.cfg new file mode 100644 index 000000000..7af20629a --- /dev/null +++ b/kokoro/iptables_tests.cfg @@ -0,0 +1,10 @@ +build_file: "repo/scripts/iptables_test.sh" + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + regex: "**/runsc_logs_*.tar.gz" + } +} diff --git a/kokoro/kythe/generate_xrefs.cfg b/kokoro/kythe/generate_xrefs.cfg new file mode 100644 index 000000000..03e65c54e --- /dev/null +++ b/kokoro/kythe/generate_xrefs.cfg @@ -0,0 +1,28 @@ +build_file: "gvisor/kokoro/kythe/generate_xrefs.sh" + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73898 + keyname: "kokoro-rbe-service-account" + } + } +} + +bazel_setting { + project_id: "gvisor-rbe" + local_execution: false + auth_credential: { + keystore_config_id: 73898 + keyname: "kokoro-rbe-service-account" + } + bes_backend_address: "buildeventservice.googleapis.com" + foundry_backend_address: "remotebuildexecution.googleapis.com" + upsalite_frontend_address: "https://source.cloud.google.com" +} + +action { + define_artifacts { + regex: "*.kzip" + } +} diff --git a/kokoro/kythe/generate_xrefs.sh b/kokoro/kythe/generate_xrefs.sh new file mode 100644 index 000000000..799467a34 --- /dev/null +++ b/kokoro/kythe/generate_xrefs.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -ex + +# Install the latest version of Bazel. The default on Kokoro images is out of +# date. +if command -v use_bazel.sh >/dev/null; then + use_bazel.sh latest +fi +bazel version + +python3 -V + +readonly KYTHE_VERSION='v0.0.37' +readonly WORKDIR="$(mktemp -d)" +readonly KYTHE_DIR="${WORKDIR}/kythe-${KYTHE_VERSION}" +if [[ -n "$KOKORO_GIT_COMMIT" ]]; then + readonly KZIP_FILENAME="${KOKORO_ARTIFACTS_DIR}/${KOKORO_GIT_COMMIT}.kzip" +else + readonly KZIP_FILENAME="$(git rev-parse HEAD).kzip" +fi + +wget -q -O "${WORKDIR}/kythe.tar.gz" \ + "https://github.com/kythe/kythe/releases/download/${KYTHE_VERSION}/kythe-${KYTHE_VERSION}.tar.gz" +tar --no-same-owner -xzf "${WORKDIR}/kythe.tar.gz" --directory "$WORKDIR" + +if [[ -n "$KOKORO_ARTIFACTS_DIR" ]]; then + cd "${KOKORO_ARTIFACTS_DIR}/github/gvisor" +fi +bazel \ + --bazelrc="${KYTHE_DIR}/extractors.bazelrc" \ + build \ + --override_repository kythe_release="${KYTHE_DIR}" \ + --define=kythe_corpus=gvisor.dev \ + --cxxopt=-std=c++17 \ + //... + +"${KYTHE_DIR}/tools/kzip" merge \ + --output "$KZIP_FILENAME" \ + $(find -L bazel-out/*/extra_actions/ -name '*.kzip') diff --git a/kokoro/release.cfg b/kokoro/release.cfg index b9d35bc51..5cec1790a 100644 --- a/kokoro/release.cfg +++ b/kokoro/release.cfg @@ -1 +1,15 @@ build_file: "repo/scripts/release.sh" + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73898 + keyname: "kokoro-github-access-token" + } + } +} + +env_vars { + key: "KOKORO_GITHUB_ACCESS_TOKEN" + value: "73898_kokoro-github-access-token" +} diff --git a/kokoro/runtime_tests.cfg b/kokoro/runtime_tests.cfg new file mode 100644 index 000000000..7d56d5aca --- /dev/null +++ b/kokoro/runtime_tests.cfg @@ -0,0 +1 @@ +build_file: "repo/scripts/runtime_tests.sh" diff --git a/kokoro/swgso_tests.cfg b/kokoro/swgso_tests.cfg new file mode 100644 index 000000000..101a9c607 --- /dev/null +++ b/kokoro/swgso_tests.cfg @@ -0,0 +1,9 @@ +build_file: "repo/scripts/swgso_tests.sh" + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + } +} diff --git a/kokoro/syscall_kvm_tests.cfg b/kokoro/syscall_kvm_tests.cfg new file mode 100644 index 000000000..3b99e9c13 --- /dev/null +++ b/kokoro/syscall_kvm_tests.cfg @@ -0,0 +1,9 @@ +build_file: "repo/scripts/syscall_kvm_tests.sh" + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + } +} diff --git a/kokoro/ubuntu1604/10_core.sh b/kokoro/ubuntu1604/10_core.sh index e87a6eee8..46dda6bb1 100755 --- a/kokoro/ubuntu1604/10_core.sh +++ b/kokoro/ubuntu1604/10_core.sh @@ -21,8 +21,8 @@ apt-get update && apt-get -y install make git-core build-essential linux-headers # Install a recent go toolchain. if ! [[ -d /usr/local/go ]]; then - wget https://dl.google.com/go/go1.12.linux-amd64.tar.gz - tar -xvf go1.12.linux-amd64.tar.gz + wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz + tar -xvf go1.13.5.linux-amd64.tar.gz mv go /usr/local fi diff --git a/kokoro/ubuntu1604/40_kokoro.sh b/kokoro/ubuntu1604/40_kokoro.sh index 64772d74d..3f50929d5 100755 --- a/kokoro/ubuntu1604/40_kokoro.sh +++ b/kokoro/ubuntu1604/40_kokoro.sh @@ -23,7 +23,10 @@ declare -r ssh_public_keys=( ) # Install dependencies. -apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm +apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip zip + +# junitparser is used to merge junit xml files. +pip install junitparser # We need a kbuilder user. if useradd -c "kbuilder user" -m -s /bin/bash kbuilder; then diff --git a/kokoro/ubuntu1604/README.md b/kokoro/ubuntu1604/README.md new file mode 100644 index 000000000..64f913b9a --- /dev/null +++ b/kokoro/ubuntu1604/README.md @@ -0,0 +1,34 @@ +## Image Update + +After making changes to files in the directory, you must run the following +commands to update the image Kokoro uses: + +```shell +gcloud config set project gvisor-kokoro-testing +third_party/gvisor/kokoro/ubuntu1604/build.sh +third_party/gvisor/kokoro/ubuntu1804/build.sh +``` + +Note: the command above will change your default project for `gcloud`. Run +`gcloud config set project` again to revert back to your default project. + +Note: Files in `third_party/gvisor/kokoro/ubuntu1804/` as symlinks to +`ubuntu1604`, therefore both images must be updated. + +After the script finishes, the last few lines of the output will container the +image name. If the output was lost, you can run `build.sh` again to print the +image name. + +``` +NAME PROJECT FAMILY DEPRECATED STATUS +image-6777fa4666a968c8 gvisor-kokoro-testing READY ++ cleanup ++ gcloud compute instances delete --quiet build-tlfrdv +Deleted [https://www.googleapis.com/compute/v1/projects/gvisor-kokoro-testing/zones/us-central1-f/instances/build-tlfrdv]. +``` + +To setup Kokoro to use the new image, copy the image names to their +corresponding file below: + +* //devtools/kokoro/config/gcp/gvisor/ubuntu1604.gcl +* //devtools/kokoro/config/gcp/gvisor/ubuntu1804.gcl diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 7c17109a6..9553f164d 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -23,6 +23,8 @@ go_library( "exec.go", "fcntl.go", "file.go", + "file_amd64.go", + "file_arm64.go", "fs.go", "futex.go", "inotify.go", @@ -55,6 +57,7 @@ go_library( "uio.go", "utsname.go", "wait.go", + "xattr.go", ], importpath = "gvisor.dev/gvisor/pkg/abi/linux", visibility = ["//visibility:public"], diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index f78315ebf..6663a199c 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -16,15 +16,17 @@ package linux // Commands from linux/fcntl.h. const ( - F_DUPFD = 0x0 - F_GETFD = 0x1 - F_SETFD = 0x2 - F_GETFL = 0x3 - F_SETFL = 0x4 - F_SETLK = 0x6 - F_SETLKW = 0x7 - F_SETOWN = 0x8 - F_GETOWN = 0x9 + F_DUPFD = 0 + F_GETFD = 1 + F_SETFD = 2 + F_GETFL = 3 + F_SETFL = 4 + F_SETLK = 6 + F_SETLKW = 7 + F_SETOWN = 8 + F_GETOWN = 9 + F_SETOWN_EX = 15 + F_GETOWN_EX = 16 F_DUPFD_CLOEXEC = 1024 + 6 F_SETPIPE_SZ = 1024 + 7 F_GETPIPE_SZ = 1024 + 8 @@ -32,9 +34,9 @@ const ( // Commands for F_SETLK. const ( - F_RDLCK = 0x0 - F_WRLCK = 0x1 - F_UNLCK = 0x2 + F_RDLCK = 0 + F_WRLCK = 1 + F_UNLCK = 2 ) // Flags for fcntl. @@ -42,7 +44,7 @@ const ( FD_CLOEXEC = 00000001 ) -// Lock structure for F_SETLK. +// Flock is the lock structure for F_SETLK. type Flock struct { Type int16 Whence int16 @@ -52,3 +54,16 @@ type Flock struct { Pid int32 _ [4]byte } + +// Flags for F_SETOWN_EX and F_GETOWN_EX. +const ( + F_OWNER_TID = 0 + F_OWNER_PID = 1 + F_OWNER_PGRP = 2 +) + +// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +type FOwnerEx struct { + Type int32 + PID int32 +} diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 257f67222..16791d03e 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -144,9 +144,13 @@ const ( ModeCharacterDevice = S_IFCHR ModeNamedPipe = S_IFIFO - ModeSetUID = 04000 - ModeSetGID = 02000 - ModeSticky = 01000 + S_ISUID = 04000 + S_ISGID = 02000 + S_ISVTX = 01000 + + ModeSetUID = S_ISUID + ModeSetGID = S_ISGID + ModeSticky = S_ISVTX ModeUserAll = 0700 ModeUserRead = 0400 @@ -186,25 +190,6 @@ const ( RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC ) -// Stat represents struct stat. -type Stat struct { - Dev uint64 - Ino uint64 - Nlink uint64 - Mode uint32 - UID uint32 - GID uint32 - _ int32 - Rdev uint64 - Size int64 - Blksize int64 - Blocks int64 - ATime Timespec - MTime Timespec - CTime Timespec - _ [3]int64 -} - // SizeOfStat is the size of a Stat struct. var SizeOfStat = binary.Size(Stat{}) @@ -301,6 +286,29 @@ func (m FileMode) String() string { return strings.Join(s, "|") } +// DirentType maps file types to dirent types appropriate for (struct +// dirent)::d_type. +func (m FileMode) DirentType() uint8 { + switch m.FileType() { + case ModeSocket: + return DT_SOCK + case ModeSymlink: + return DT_LNK + case ModeRegular: + return DT_REG + case ModeBlockDevice: + return DT_BLK + case ModeDirectory: + return DT_DIR + case ModeCharacterDevice: + return DT_CHR + case ModeNamedPipe: + return DT_FIFO + default: + return DT_UNKNOWN + } +} + var modeExtraBits = abi.FlagSet{ { Flag: ModeSetUID, diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go new file mode 100644 index 000000000..74c554be6 --- /dev/null +++ b/pkg/abi/linux/file_amd64.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go new file mode 100644 index 000000000..f16c07589 --- /dev/null +++ b/pkg/abi/linux/file_arm64.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Mode uint32 + Nlink uint32 + UID uint32 + GID uint32 + Rdev uint64 + _ uint64 + Size int64 + Blksize int32 + _ int32 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [2]int32 +} diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index b416e3472..2c652baa2 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -92,3 +92,10 @@ const ( SYNC_FILE_RANGE_WRITE = 2 SYNC_FILE_RANGE_WAIT_AFTER = 4 ) + +// Flag argument to renameat2(2), from include/uapi/linux/fs.h. +const ( + RENAME_NOREPLACE = (1 << 0) // Don't overwrite target. + RENAME_EXCHANGE = (1 << 1) // Exchange src and dst. + RENAME_WHITEOUT = (1 << 2) // Whiteout src. +) diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 152f6b144..3898d2314 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -325,3 +325,9 @@ const ( RTA_SPORT = 28 RTA_DPORT = 29 ) + +// Route flags, from include/uapi/linux/route.h. +const ( + RTF_GATEWAY = 0x2 + RTF_UP = 0x1 +) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d5b731390..766ee4014 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -256,6 +256,17 @@ type SockAddrInet6 struct { Scope_id uint32 } +// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +type SockAddrLink struct { + Family uint16 + Protocol uint16 + InterfaceIndex int32 + ARPHardwareType uint16 + PacketType byte + HardwareAddrLen byte + HardwareAddr [8]byte +} + // UnixPathMax is the maximum length of the path in an AF_UNIX socket. // // From uapi/linux/un.h. @@ -278,6 +289,7 @@ type SockAddr interface { func (s *SockAddrInet) implementsSockAddr() {} func (s *SockAddrInet6) implementsSockAddr() {} +func (s *SockAddrLink) implementsSockAddr() {} func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} @@ -410,6 +422,15 @@ type ControlMessageRights []int32 // ControlMessageRights. const SizeOfControlMessageRight = 4 +// SizeOfControlMessageInq is the size of a TCP_INQ control message. +const SizeOfControlMessageInq = 4 + +// SizeOfControlMessageTOS is the size of an IP_TOS control message. +const SizeOfControlMessageTOS = 1 + +// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message. +const SizeOfControlMessageTClass = 4 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go new file mode 100644 index 000000000..a3b6406fa --- /dev/null +++ b/pkg/abi/linux/xattr.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Constants for extended attributes. +const ( + XATTR_NAME_MAX = 255 + XATTR_SIZE_MAX = 65536 + + XATTR_CREATE = 1 + XATTR_REPLACE = 2 + + XATTR_USER_PREFIX = "user." + XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX) +) diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 5d61dc2ff..d37047368 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -183,6 +183,33 @@ const ( X86FeatureAVX512VBMI X86FeatureUMIP X86FeaturePKU + X86FeatureOSPKE + X86FeatureWAITPKG + X86FeatureAVX512_VBMI2 + _ // ecx bit 7 is reserved + X86FeatureGFNI + X86FeatureVAES + X86FeatureVPCLMULQDQ + X86FeatureAVX512_VNNI + X86FeatureAVX512_BITALG + X86FeatureTME + X86FeatureAVX512_VPOPCNTDQ + _ // ecx bit 15 is reserved + X86FeatureLA57 + // ecx bits 17-21 are reserved + _ + _ + _ + _ + _ + X86FeatureRDPID + // ecx bits 23-24 are reserved + _ + _ + X86FeatureCLDEMOTE + _ // ecx bit 26 is reserved + X86FeatureMOVDIRI + X86FeatureMOVDIR64B ) // Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX. @@ -353,9 +380,24 @@ var x86FeatureStrings = map[Feature]string{ X86FeatureAVX512VL: "avx512vl", // Block 3. - X86FeatureAVX512VBMI: "avx512vbmi", - X86FeatureUMIP: "umip", - X86FeaturePKU: "pku", + X86FeatureAVX512VBMI: "avx512vbmi", + X86FeatureUMIP: "umip", + X86FeaturePKU: "pku", + X86FeatureOSPKE: "ospke", + X86FeatureWAITPKG: "waitpkg", + X86FeatureAVX512_VBMI2: "avx512_vbmi2", + X86FeatureGFNI: "gfni", + X86FeatureVAES: "vaes", + X86FeatureVPCLMULQDQ: "vpclmulqdq", + X86FeatureAVX512_VNNI: "avx512_vnni", + X86FeatureAVX512_BITALG: "avx512_bitalg", + X86FeatureTME: "tme", + X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq", + X86FeatureLA57: "la57", + X86FeatureRDPID: "rdpid", + X86FeatureCLDEMOTE: "cldemote", + X86FeatureMOVDIRI: "movdiri", + X86FeatureMOVDIR64B: "movdir64b", // Block 4. X86FeatureXSAVEOPT: "xsaveopt", diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 71f2abc83..0b4b7cc44 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -25,6 +25,7 @@ go_library( proto_library( name = "eventchannel_proto", srcs = ["event.proto"], + visibility = ["//:sandbox"], ) go_proto_library( diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index c7f549428..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -8,9 +8,6 @@ go_library( srcs = ["fd.go"], importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], - deps = [ - "//pkg/unet", - ], ) go_test( diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 7691b477b..83bcfe220 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -22,8 +22,6 @@ import ( "runtime" "sync/atomic" "syscall" - - "gvisor.dev/gvisor/pkg/unet" ) // ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It @@ -187,12 +185,6 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) { return New(f), nil } -// DialUnix connects to a Unix Domain Socket and return the file descriptor. -func DialUnix(path string) (*FD, error) { - socket, err := unet.Connect(path, false) - return New(socket.FD()), err -} - // Close closes the file descriptor contained in the FD. // // Close is safe to call multiple times, but will return an error after the diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 5643d5f26..e590a71ba 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -19,7 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", - "//third_party/gvsync", + "//pkg/syncutil", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 8390915a2..e7c3a3a0b 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error { return nil case epsBlocked | epsShutdown: atomic.AddInt32(&ep.ctrl.state, -epsBlocked) - return shutdownError{} + return ShutdownError{} default: // Most likely due to ep.enterFutexWait() being called concurrently // from multiple goroutines. diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 386cee42c..3cdb576e1 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() { // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), // ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer -// Endpoint, to unblock and return errors. It does not wait for concurrent -// calls to return. Successive calls to Shutdown have no effect. +// Endpoint, to unblock and return ShutdownErrors. It does not wait for +// concurrent calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool { return atomic.LoadUint32(&ep.shutdown) != 0 } -type shutdownError struct{} +// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown() +// has been called. +type ShutdownError struct{} // Error implements error.Error. -func (shutdownError) Error() string { +func (ShutdownError) Error() string { return "flipcall connection shutdown" } diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index a37952637..27b8939fc 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -18,7 +18,7 @@ import ( "reflect" "unsafe" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte { var ioSync int64 func raceBecomeActive() { - if gvsync.RaceEnabled { - gvsync.RaceAcquire((unsafe.Pointer)(&ioSync)) + if syncutil.RaceEnabled { + syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) } } func raceBecomeInactive() { - if gvsync.RaceEnabled { - gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + if syncutil.RaceEnabled { + syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) } } diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index b127a2bbb..168c1ccff 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error { if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { switch cs := atomic.LoadUint32(ep.connState()); cs { case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } @@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error { return nil case ep.inactiveState: if ep.isShutdownLocally() { - return shutdownError{} + return ShutdownError{} } if err := ep.futexWaitConnState(ep.inactiveState); err != nil { return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) } continue case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index a6cc0617e..de9357389 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -17,7 +17,6 @@ package p9 import ( "fmt" "io" - "runtime" "sync/atomic" "syscall" @@ -45,15 +44,10 @@ func (c *Client) Attach(name string) (File, error) { // newFile returns a new client file. func (c *Client) newFile(fid FID) *clientFile { - cf := &clientFile{ + return &clientFile{ client: c, fid: fid, } - - // Make sure the file is closed. - runtime.SetFinalizer(cf, (*clientFile).Close) - - return cf } // clientFile is provided to clients. @@ -192,7 +186,6 @@ func (c *clientFile) Remove() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the remove message. if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil { @@ -214,7 +207,6 @@ func (c *clientFile) Close() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the close message. if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 907445e15..96d1f2a8e 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -116,7 +116,7 @@ type File interface { // N.B. The server must resolve any lazy paths when open is called. // After this point, read and write may be called on files with no // deletion check, so resolving in the data path is not viable. - Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // Read reads from this file. Open must be called first. // diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index ba9a55d6d..b9582c07f 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -257,7 +257,6 @@ func CanOpen(mode FileMode) bool { // handle implements handler.handle. func (t *Tlopen) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -272,15 +271,15 @@ func (t *Tlopen) handle(cs *connState) message { return newErr(syscall.EINVAL) } - // Are flags valid? - flags := t.Flags &^ OpenFlagsIgnoreMask - if flags&^OpenFlagsModeMask != 0 { - return newErr(syscall.EINVAL) - } - - // Is this an attempt to open a directory as writable? Don't accept. - if ref.mode.IsDir() && flags != ReadOnly { - return newErr(syscall.EINVAL) + if ref.mode.IsDir() { + // Directory must be opened ReadOnly. + if t.Flags&OpenFlagsModeMask != ReadOnly { + return newErr(syscall.EISDIR) + } + // Directory not truncatable. + if t.Flags&OpenTruncate != 0 { + return newErr(syscall.EISDIR) + } } var ( @@ -294,7 +293,6 @@ func (t *Tlopen) handle(cs *connState) message { return syscall.EINVAL } - // Do the open. osFile, qid, ioUnit, err = ref.file.Open(t.Flags) return err }); err != nil { @@ -311,12 +309,10 @@ func (t *Tlopen) handle(cs *connState) message { } func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return nil, syscall.EBADF @@ -390,12 +386,10 @@ func (t *Tsymlink) handle(cs *connState) message { } func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -426,19 +420,16 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { // handle implements handler.handle. func (t *Tlink) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the other FID. refTarget, ok := cs.LookupFID(t.Target) if !ok { return newErr(syscall.EBADF) @@ -467,7 +458,6 @@ func (t *Tlink) handle(cs *connState) message { // handle implements handler.handle. func (t *Trenameat) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.OldName); err != nil { return newErr(err) } @@ -475,14 +465,12 @@ func (t *Trenameat) handle(cs *connState) message { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.OldDirectory) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the other FID. refTarget, ok := cs.LookupFID(t.NewDirectory) if !ok { return newErr(syscall.EBADF) @@ -523,12 +511,10 @@ func (t *Trenameat) handle(cs *connState) message { // handle implements handler.handle. func (t *Tunlinkat) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -577,19 +563,16 @@ func (t *Tunlinkat) handle(cs *connState) message { // handle implements handler.handle. func (t *Trename) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the target. refTarget, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -641,7 +624,6 @@ func (t *Trename) handle(cs *connState) message { // handle implements handler.handle. func (t *Treadlink) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -669,7 +651,6 @@ func (t *Treadlink) handle(cs *connState) message { // handle implements handler.handle. func (t *Tread) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -708,7 +689,6 @@ func (t *Tread) handle(cs *connState) message { // handle implements handler.handle. func (t *Twrite) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -747,12 +727,10 @@ func (t *Tmknod) handle(cs *connState) message { } func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -791,12 +769,10 @@ func (t *Tmkdir) handle(cs *connState) message { } func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -827,7 +803,6 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { // handle implements handler.handle. func (t *Tgetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -856,7 +831,6 @@ func (t *Tgetattr) handle(cs *connState) message { // handle implements handler.handle. func (t *Tsetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -883,7 +857,6 @@ func (t *Tsetattr) handle(cs *connState) message { // handle implements handler.handle. func (t *Tallocate) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -917,7 +890,6 @@ func (t *Tallocate) handle(cs *connState) message { // handle implements handler.handle. func (t *Txattrwalk) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -930,7 +902,6 @@ func (t *Txattrwalk) handle(cs *connState) message { // handle implements handler.handle. func (t *Txattrcreate) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -943,7 +914,6 @@ func (t *Txattrcreate) handle(cs *connState) message { // handle implements handler.handle. func (t *Treaddir) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -977,7 +947,6 @@ func (t *Treaddir) handle(cs *connState) message { // handle implements handler.handle. func (t *Tfsync) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1001,7 +970,6 @@ func (t *Tfsync) handle(cs *connState) message { // handle implements handler.handle. func (t *Tstatfs) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1192,7 +1160,6 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI // handle implements handler.handle. func (t *Twalk) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1213,7 +1180,6 @@ func (t *Twalk) handle(cs *connState) message { // handle implements handler.handle. func (t *Twalkgetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1270,7 +1236,6 @@ func (t *Tumknod) handle(cs *connState) message { // handle implements handler.handle. func (t *Tlconnect) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1303,7 +1268,6 @@ func (t *Tchannel) handle(cs *connState) message { return newErr(err) } - // Lookup the given channel. ch := cs.lookupChannel(t.ID) if ch == nil { return newErr(syscall.ENOSYS) diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 25530adca..d3090535a 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -32,21 +32,21 @@ import ( type OpenFlags uint32 const ( - // ReadOnly is a Topen and Tcreate flag indicating read-only mode. + // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode. ReadOnly OpenFlags = 0 - // WriteOnly is a Topen and Tcreate flag indicating write-only mode. + // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode. WriteOnly OpenFlags = 1 - // ReadWrite is a Topen flag indicates read-write mode. + // ReadWrite is a Tlopen flag indicates read-write mode. ReadWrite OpenFlags = 2 // OpenFlagsModeMask is a mask of valid OpenFlags mode bits. OpenFlagsModeMask OpenFlags = 3 - // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen. - // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h. - OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000 + // OpenTruncate is a Tlopen flag indicating that the opened file should be + // truncated. + OpenTruncate OpenFlags = 01000 ) // ConnectFlags is the mode passed to Connect operations. @@ -71,25 +71,32 @@ const ( // OSFlags converts a p9.OpenFlags to an int compatible with open(2). func (o OpenFlags) OSFlags() int { - return int(o & OpenFlagsModeMask) + // "flags contains Linux open(2) flags bits" - 9P2000.L + return int(o) } // String implements fmt.Stringer. func (o OpenFlags) String() string { - switch o { + var buf strings.Builder + switch mode := o & OpenFlagsModeMask; mode { case ReadOnly: - return "ReadOnly" + buf.WriteString("ReadOnly") case WriteOnly: - return "WriteOnly" + buf.WriteString("WriteOnly") case ReadWrite: - return "ReadWrite" - case OpenFlagsModeMask: - return "OpenFlagsModeMask" - case OpenFlagsIgnoreMask: - return "OpenFlagsIgnoreMask" + buf.WriteString("ReadWrite") default: - return "UNDEFINED" + fmt.Fprintf(&buf, "%#o", mode) } + otherFlags := o &^ OpenFlagsModeMask + if otherFlags&OpenTruncate != 0 { + buf.WriteString("|OpenTruncate") + otherFlags &^= OpenTruncate + } + if otherFlags != 0 { + fmt.Fprintf(&buf, "|%#o", otherFlags) + } + return buf.String() } // Tag is a message tag. @@ -814,7 +821,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) { attr.Mode = FileMode(s.Mode) } if req.NLink { - attr.NLink = s.Nlink + attr.NLink = uint64(s.Nlink) } if req.UID { attr.UID = UID(s.Uid) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 8bbdb2488..6e758148d 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -1044,11 +1044,11 @@ func TestReaddir(t *testing.T) { if _, err := f.Readdir(0, 1); err != syscall.EINVAL { t.Errorf("readdir got %v, wanted EINVAL", err) } - if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) + if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) } - if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) + if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) } backend.EXPECT().Open(p9.ReadOnly).Times(1) if _, _, _, err := f.Open(p9.ReadOnly); err != nil { @@ -1065,75 +1065,93 @@ func TestReaddir(t *testing.T) { func TestOpen(t *testing.T) { type openTest struct { name string - mode p9.OpenFlags + flags p9.OpenFlags err error match func(p9.FileMode) bool } cases := []openTest{ { - name: "invalid", - mode: ^p9.OpenFlagsModeMask, - err: syscall.EINVAL, - match: func(p9.FileMode) bool { return true }, - }, - { name: "not-openable-read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "not-openable-write-only", - mode: p9.WriteOnly, + flags: p9.WriteOnly, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "not-openable-read-write", - mode: p9.ReadWrite, + flags: p9.ReadWrite, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "directory-read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: nil, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "directory-read-write", - mode: p9.ReadWrite, - err: syscall.EINVAL, + flags: p9.ReadWrite, + err: syscall.EISDIR, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "directory-write-only", - mode: p9.WriteOnly, - err: syscall.EINVAL, + flags: p9.WriteOnly, + err: syscall.EISDIR, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, }, { name: "write-only", - mode: p9.WriteOnly, + flags: p9.WriteOnly, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, }, { name: "read-write", - mode: p9.ReadWrite, + flags: p9.ReadWrite, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "directory-read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: syscall.EISDIR, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "write-only-truncate", + flags: p9.WriteOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "read-write-truncate", + flags: p9.ReadWrite | p9.OpenTruncate, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, }, } - // Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + // Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice // - returning a file works as expected for name := range newTypeMap(nil) { @@ -1171,25 +1189,25 @@ func TestOpen(t *testing.T) { // Attempt the given open. if tc.err != nil { // We expect an error, just test and return. - if _, _, _, err := f.Open(tc.mode); err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) + if _, _, _, err := f.Open(tc.flags); err != tc.err { + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) } return } // Run an FD test, since we expect success. fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Open(tc.mode).Return(send, p9.QID{}, uint32(0), nil).Times(1) - recv, _, _, err := f.Open(tc.mode) + backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1) + recv, _, _, err := f.Open(tc.flags) if err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) } return recv }) // If the open was successful, attempt another one. - if _, _, _, err := f.Open(tc.mode); err != syscall.EINVAL { - t.Errorf("second open with mode %v got %v, want EINVAL", tc.mode, err) + if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL { + t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err) } // Ensure that all illegal operations fail. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index e717e6161..40b8fa023 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -453,7 +453,11 @@ func (cs *connState) initializeChannels() (err error) { go func() { // S/R-SAFE: Server side. defer cs.channelWg.Done() if err := res.service(cs); err != nil { - log.Warningf("p9.channel.service: %v", err) + // Don't log flipcall.ShutdownErrors, which we expect to be + // returned during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("p9.channel.service: %v", err) + } } }() } diff --git a/pkg/p9/version.go b/pkg/p9/version.go index f1ffdd23a..36a694c58 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 8 + highestSupportedVersion uint32 = 9 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -155,3 +155,9 @@ func versionSupportsTallocate(v uint32) bool { func versionSupportsFlipcall(v uint32) bool { return v >= 8 } + +// VersionSupportsOpenTruncateFlag returns true if version v supports +// passing the OpenTruncate flag to Tlopen. +func VersionSupportsOpenTruncateFlag(v uint32) bool { + return v >= 9 +} diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 30ec8e6e2..38cea9be3 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index e340d9f98..4f4b70fef 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index c7503f2cc..fc36efa23 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -199,6 +199,10 @@ func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string { return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx) } +func ruleLabel(ruleSetIdx int, sysno uintptr, idx int, name string) string { + return fmt.Sprintf("rule_%v_%v_%v_%v", ruleSetIdx, sysno, idx, name) +} + func checkArgsLabel(sysno uintptr) string { return fmt.Sprintf("checkArgs_%v", sysno) } @@ -223,6 +227,19 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) labelled = true + case GreaterThan: + labelGood := fmt.Sprintf("gt%v", i) + high, low := uint32(a>>32), uint32(a) + // assert arg_high < high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + // arg_high > high + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + // arg_low < low + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true default: return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) } diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index 29eec8db1..84c841d7f 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -49,6 +49,9 @@ func (a AllowAny) String() (s string) { // AllowValue specifies a value that needs to be strictly matched. type AllowValue uintptr +// GreaterThan specifies a value that needs to be strictly smaller. +type GreaterThan uintptr + func (a AllowValue) String() (s string) { return fmt.Sprintf("%#x ", uintptr(a)) } diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 353686ed3..abbee7051 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -340,6 +340,54 @@ func TestBasic(t *testing.T) { }, }, }, + { + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + GreaterThan(0xf), + GreaterThan(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + specs: []spec{ + { + desc: "GreaterThan: Syscall argument allowed", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "GreaterThan: Syscall argument disallowed (equal)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "Syscall argument disallowed (smaller)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "GreaterThan2: Syscall argument allowed", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "GreaterThan2: Syscall argument disallowed (equal)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "GreaterThan2: Syscall argument disallowed (smaller)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, } { instrs, err := BuildProgram(test.ruleSets, test.defaultAction) if err != nil { diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index 48413f1fb..da6b9eaaf 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -38,7 +38,7 @@ func main() { syscall.SYS_CLONE: {}, syscall.SYS_CLOSE: {}, syscall.SYS_DUP: {}, - syscall.SYS_DUP2: {}, + syscall.SYS_DUP3: {}, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_WAIT: {}, diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 9e7db8b30..67daa6c24 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -305,7 +305,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs()) return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return c.Native(0), nil } @@ -320,6 +320,6 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { _, err := c.PtraceSetRegs(bytes.NewBuffer(buf)) return err } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return nil } diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go index dfd62cbdb..23e009ef3 100644 --- a/pkg/sentry/context/context.go +++ b/pkg/sentry/context/context.go @@ -12,10 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package context defines the sentry's Context type. +// Package context defines an internal context type. +// +// The given Context conforms to the standard Go context, but mandates +// additional methods that are specific to the kernel internals. Note however, +// that the Context described by this package carries additional constraints +// regarding concurrent access and retaining beyond the scope of a call. +// +// See the Context type for complete details. package context import ( + "context" + "time" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { type Context interface { log.Logger amutex.Sleeper + context.Context // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate @@ -72,19 +83,36 @@ type Context interface { // AddressSpace is activated. Normally activate is the same value as the // deactivate parameter passed to UninterruptibleSleepStart. UninterruptibleSleepFinish(activate bool) +} + +// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep +// methods for anonymous embedding in other types that do not implement sleeps. +type NoopSleeper struct { + amutex.NoopSleeper +} + +// UninterruptibleSleepStart does nothing. +func (NoopSleeper) UninterruptibleSleepStart(bool) {} + +// UninterruptibleSleepFinish does nothing. +func (NoopSleeper) UninterruptibleSleepFinish(bool) {} + +// Deadline returns zero values, meaning no deadline. +func (NoopSleeper) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns nil. +func (NoopSleeper) Done() <-chan struct{} { + return nil +} - // Value returns the value associated with this Context for key, or nil if - // no value is associated with key. Successive calls to Value with the same - // key returns the same result. - // - // A key identifies a specific value in a Context. Functions that wish to - // retrieve values from Context typically allocate a key in a global - // variable then use that key as the argument to Context.Value. A key can - // be any type that supports equality; packages should define keys as an - // unexported type to avoid collisions. - Value(key interface{}) interface{} +// Err returns nil. +func (NoopSleeper) Err() error { + return nil } +// logContext implements basic logging. type logContext struct { log.Logger NoopSleeper @@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} { return nil } -// NoopSleeper is a noop implementation of amutex.Sleeper and -// Context.UninterruptibleSleep* methods for anonymous embedding in other types -// that do not want to notify kernel.Task about sleeps. -type NoopSleeper struct { - amutex.NoopSleeper -} - -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} - -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} - // bgContext is the context returned by context.Background. var bgContext = &logContext{Logger: log.Log()} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 1f78d54a2..e1f2fea60 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -22,6 +22,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/urpc" ) @@ -56,6 +57,9 @@ type Profile struct { // traceFile is the current execution trace output file. traceFile *fd.FD + + // Kernel is the kernel under profile. + Kernel *kernel.Kernel } // StartCPUProfile is an RPC stub which starts recording the CPU profile in a @@ -147,6 +151,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { return err } + // Ensure all trace contexts are registered. + p.Kernel.RebuildTraceContexts() + p.traceFile = output return nil } @@ -158,9 +165,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error { defer p.mu.Unlock() if p.traceFile == nil { - return errors.New("Execution tracing not start") + return errors.New("Execution tracing not started") } + // Similarly to the case above, if tasks have not ended traces, we will + // lose information. Thus we need to rebuild the tasks in order to have + // complete information. This will not lose information if multiple + // traces are overlapping. + p.Kernel.RebuildTraceContexts() + trace.Stop() p.traceFile.Close() p.traceFile = nil diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index c35faeb4c..ced51c66c 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -268,14 +268,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error { } // Process contains information about a single process in a Sandbox. -// TODO(b/117881927): Implement TTY field. type Process struct { UID auth.KUID `json:"uid"` PID kernel.ThreadID `json:"pid"` // Parent PID - PPID kernel.ThreadID `json:"ppid"` + PPID kernel.ThreadID `json:"ppid"` + Threads []kernel.ThreadID `json:"threads"` // Processor utilization C int32 `json:"c"` + // TTY name of the process. Will be of the form "pts/N" if there is a + // TTY, or "?" if there is not. + TTY string `json:"tty"` // Start time STime string `json:"stime"` // CPU time @@ -285,18 +288,19 @@ type Process struct { } // ProcessListToTable prints a table with the following format: -// UID PID PPID C STIME TIME CMD -// 0 1 0 0 14:04 505262ns tail +// UID PID PPID C TTY STIME TIME CMD +// 0 1 0 0 pty/4 14:04 505262ns tail func ProcessListToTable(pl []*Process) string { var buf bytes.Buffer tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0) - fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD") + fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD") for _, d := range pl { - fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s", + fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s", d.UID, d.PID, d.PPID, d.C, + d.TTY, d.STime, d.Time, d.Cmd) @@ -307,7 +311,7 @@ func ProcessListToTable(pl []*Process) string { // ProcessListToJSON will return the JSON representation of ps. func ProcessListToJSON(pl []*Process) (string, error) { - b, err := json.Marshal(pl) + b, err := json.MarshalIndent(pl, "", " ") if err != nil { return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err) } @@ -334,7 +338,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ts := k.TaskSet() now := k.RealtimeClock().Now() for _, tg := range ts.Root.ThreadGroups() { - pid := tg.PIDNamespace().IDOfThreadGroup(tg) + pidns := tg.PIDNamespace() + pid := pidns.IDOfThreadGroup(tg) + // If tg has already been reaped ignore it. if pid == 0 { continue @@ -345,16 +351,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ppid := kernel.ThreadID(0) if p := tg.Leader().Parent(); p != nil { - ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup()) + ppid = pidns.IDOfThreadGroup(p.ThreadGroup()) } + threads := tg.MemberIDs(pidns) *out = append(*out, &Process{ - UID: tg.Leader().Credentials().EffectiveKUID, - PID: pid, - PPID: ppid, - STime: formatStartTime(now, tg.Leader().StartTime()), - C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), - Time: tg.CPUStats().SysTime.String(), - Cmd: tg.Leader().Name(), + UID: tg.Leader().Credentials().EffectiveKUID, + PID: pid, + PPID: ppid, + Threads: threads, + STime: formatStartTime(now, tg.Leader().StartTime()), + C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), + Time: tg.CPUStats().SysTime.String(), + Cmd: tg.Leader().Name(), + TTY: ttyName(tg.TTY()), }) } sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID }) @@ -395,3 +404,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 { } return int32(percentCPU) } + +func ttyName(tty *kernel.TTY) string { + if tty == nil { + return "?" + } + return fmt.Sprintf("pts/%d", tty.Index) +} diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go index d8ada2694..0a88459b2 100644 --- a/pkg/sentry/control/proc_test.go +++ b/pkg/sentry/control/proc_test.go @@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) { }{ { pl: []*Process{}, - expected: "UID PID PPID C STIME TIME CMD", + expected: "UID PID PPID C TTY STIME TIME CMD", }, { pl: []*Process{ @@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) { PID: 0, PPID: 0, C: 0, + TTY: "?", STime: "0", Time: "0", Cmd: "zero", @@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) { PID: 1, PPID: 1, C: 1, + TTY: "pts/4", STime: "1", Time: "1", Cmd: "one", }, }, - expected: `UID PID PPID C STIME TIME CMD -0 0 0 0 0 0 zero -1 1 1 1 1 1 one`, + expected: `UID PID PPID C TTY STIME TIME CMD +0 0 0 0 ? 0 0 zero +1 1 1 1 pts/4 1 1 one`, }, } diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 378602cc9..c035ffff7 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,9 +68,9 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/syncutil", "//pkg/syserror", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index 8e4d839e1..577445148 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go index 0fab876a9..4338ae1fa 100644 --- a/pkg/sentry/fs/flags.go +++ b/pkg/sentry/fs/flags.go @@ -64,6 +64,10 @@ type FileFlags struct { // NonSeekable indicates that file.offset isn't used. NonSeekable bool + + // Truncate indicates that the file should be truncated before opened. + // This is only applicable if the file is regular. + Truncate bool } // SettableFileFlags is a subset of FileFlags above that can be changed @@ -118,6 +122,9 @@ func (f FileFlags) ToLinux() (mask uint) { if f.LargeFile { mask |= linux.O_LARGEFILE } + if f.Truncate { + mask |= linux.O_TRUNC + } switch { case f.Read && f.Write: diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index c2fbb4be9..bb8312849 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -28,8 +28,14 @@ func (f *fileOperations) afterLoad() { // Manually load the open handles. var err error + + // The file may have been opened with Truncate, but we don't + // want to re-open it with Truncate or we will lose data. + flags := f.flags + flags.Truncate = false + // TODO(b/38173783): Context is not plumbed to save/restore. - f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps) + f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps) if err != nil { return fmt.Errorf("failed to re-open handle: %v", err) } diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index 39c8ec33d..b86c49b39 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -64,7 +64,7 @@ func (h *handles) DecRef() { }) } -func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*handles, error) { +func newHandles(ctx context.Context, client *p9.Client, file contextFile, flags fs.FileFlags) (*handles, error) { _, newFile, err := file.walk(ctx, nil) if err != nil { return nil, err @@ -81,6 +81,9 @@ func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*han default: panic("impossible fs.FileFlags") } + if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(client.Version()) { + p9flags |= p9.OpenTruncate + } hostFile, _, _, err := newFile.open(ctx, p9flags) if err != nil { diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 99910388f..91263ebdc 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -180,7 +180,7 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles) // given flags. func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) { if !i.canShareHandles() { - return newHandles(ctx, i.file, flags) + return newHandles(ctx, i.s.client, i.file, flags) } i.handlesMu.Lock() @@ -201,19 +201,25 @@ func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cac // whether previously open read handle was recreated. Host mappings must be // invalidated if so. func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) { - // Do we already have usable shared handles? - if flags.Write { + // Check if we are able to use cached handles. + if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(i.s.client.Version()) { + // If we are truncating (and the gofer supports it), then we + // always need a new handle. Don't return one from the cache. + } else if flags.Write { if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) { + // File is opened for writing, and we have cached write + // handles that we can use. i.writeHandles.IncRef() return i.writeHandles, false, nil } } else if i.readHandles != nil { + // File is opened for reading and we have cached handles. i.readHandles.IncRef() return i.readHandles, false, nil } - // No; get new handles and cache them for future sharing. - h, err := newHandles(ctx, i.file, flags) + // Get new handles and cache them for future sharing. + h, err := newHandles(ctx, i.s.client, i.file, flags) if err != nil { return nil, false, err } @@ -239,7 +245,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle if !flags.Read { // Writer can't be used for read, must create a new handle. var err error - h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true}) + h, err = newHandles(ctx, i.s.client, i.file, fs.FileFlags{Read: true}) if err != nil { return err } @@ -268,7 +274,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle // operations on the old will see the new data. Then, make the new handle take // ownereship of the old FD and mark the old readHandle to not close the FD // when done. - if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil { + if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil { return err } diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 8c17603f8..c09f3b71c 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -234,6 +234,8 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, if err != nil { return nil, err } + // We're not going to use newFile after return. + defer newFile.close(ctx) // Stabilize the endpoint map while creation is in progress. unlock := i.session().endpoints.lock() @@ -254,7 +256,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, // Get the attributes of the file to create inode key. qid, mask, attr, err := getattr(ctx, newFile) if err != nil { - newFile.close(ctx) return nil, err } @@ -270,7 +271,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, // cloned and re-opened multiple times after creation. _, unopened, err := i.fileState.file.walk(ctx, []string{name}) if err != nil { - newFile.close(ctx) return nil, err } diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 0da608548..4e358a46a 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -143,9 +143,9 @@ type session struct { // socket files. This allows unix domain sockets to be used with paths that // belong to a gofer. // - // TODO(b/77154739): there are few possible races with someone stat'ing the - // file and another deleting it concurrently, where the file will not be - // reported as socket file. + // TODO(gvisor.dev/issue/1200): there are few possible races with someone + // stat'ing the file and another deleting it concurrently, where the file + // will not be reported as socket file. endpoints *endpointMaps `state:"wait"` } diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 1cbed07ae..23daeb528 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -21,6 +21,8 @@ go_library( "socket_unsafe.go", "tty.go", "util.go", + "util_amd64_unsafe.go", + "util_arm64_unsafe.go", "util_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index bad61a9a1..e37e687c6 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -155,7 +155,7 @@ func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr { AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec), ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec), StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec), - Links: s.Nlink, + Links: uint64(s.Nlink), } } diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go new file mode 100644 index 000000000..66da6e9f5 --- /dev/null +++ b/pkg/sentry/fs/host/util_amd64_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package host + +import ( + "syscall" + "unsafe" +) + +func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { + var stat syscall.Stat_t + namePtr, err := syscall.BytePtrFromString(name) + if err != nil { + return stat, err + } + _, _, errno := syscall.Syscall6( + syscall.SYS_NEWFSTATAT, + uintptr(fd), + uintptr(unsafe.Pointer(namePtr)), + uintptr(unsafe.Pointer(&stat)), + uintptr(flags), + 0, 0) + if errno != 0 { + return stat, errno + } + return stat, nil +} diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go new file mode 100644 index 000000000..e8cb94aeb --- /dev/null +++ b/pkg/sentry/fs/host/util_arm64_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package host + +import ( + "syscall" + "unsafe" +) + +func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { + var stat syscall.Stat_t + namePtr, err := syscall.BytePtrFromString(name) + if err != nil { + return stat, err + } + _, _, errno := syscall.Syscall6( + syscall.SYS_FSTATAT, + uintptr(fd), + uintptr(unsafe.Pointer(namePtr)), + uintptr(unsafe.Pointer(&stat)), + uintptr(flags), + 0, 0) + if errno != 0 { + return stat, errno + } + return stat, nil +} diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go index 2b76f1065..3ab36b088 100644 --- a/pkg/sentry/fs/host/util_unsafe.go +++ b/pkg/sentry/fs/host/util_unsafe.go @@ -116,22 +116,3 @@ func setTimestamps(fd int, ts fs.TimeSpec) error { } return nil } - -func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { - var stat syscall.Stat_t - namePtr, err := syscall.BytePtrFromString(name) - if err != nil { - return stat, err - } - _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, - uintptr(fd), - uintptr(unsafe.Pointer(namePtr)), - uintptr(unsafe.Pointer(&stat)), - uintptr(flags), - 0, 0) - if errno != 0 { - return stat, errno - } - return stat, nil -} diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index f4ddfa406..91e2fde2f 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -270,6 +270,14 @@ func (i *Inode) Getxattr(name string) (string, error) { return i.InodeOperations.Getxattr(i, name) } +// Setxattr calls i.InodeOperations.Setxattr with i as the Inode. +func (i *Inode) Setxattr(name, value string) error { + if i.overlay != nil { + return overlaySetxattr(i.overlay, name, value) + } + return i.InodeOperations.Setxattr(i, name, value) +} + // Listxattr calls i.InodeOperations.Listxattr with i as the Inode. func (i *Inode) Listxattr() (map[string]struct{}, error) { if i.overlay != nil { @@ -344,6 +352,10 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error // Truncate calls i.InodeOperations.Truncate with i as the Inode. func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error { + if IsDir(i.StableAttr) { + return syserror.EISDIR + } + if i.overlay != nil { return overlayTruncate(ctx, i.overlay, d, size) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 5a388dad1..13d11e001 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -436,7 +436,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena } func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { - if err := copyUp(ctx, parent); err != nil { + if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err } @@ -462,7 +462,9 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri inode.DecRef() return nil, err } - return NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil + // Use the parent's MountSource, since that corresponds to the overlay, + // and not the upper filesystem. + return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil } func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint { @@ -550,6 +552,11 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) { return s, err } +// TODO(b/146028302): Support setxattr for overlayfs. +func overlaySetxattr(o *overlayEntry, name, value string) error { + return syserror.EOPNOTSUPP +} + func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) { o.copyMu.RLock() defer o.copyMu.RUnlock() diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 1d3ff39e0..25573e986 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syncutil" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/third_party/gvsync" ) // The virtual filesystem implements an overlay configuration. For a high-level @@ -199,7 +199,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu gvsync.DowngradableRWMutex `state:"nosave"` + dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 497475db5..75cbb0622 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -52,6 +52,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/tcpip/header", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index f70239449..402919924 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "io" + "reflect" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -33,6 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // newNet creates a new proc net entry. @@ -40,7 +43,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo var contents map[string]*fs.Inode if s := p.k.NetworkStack(); s != nil { contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a @@ -57,7 +61,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo // (ClockGetres returns 1ns resolution). "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")), - "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), @@ -195,6 +199,193 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netSnmp implements seqfile.SeqSource for /proc/net/snmp. +// +// +stateify savable +type netSnmp struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netSnmp) NeedsUpdate(generation int64) bool { + return true +} + +type snmpLine struct { + prefix string + header string +} + +var snmp = []snmpLine{ + { + prefix: "Ip", + header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates", + }, + { + prefix: "Icmp", + header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps", + }, + { + prefix: "IcmpMsg", + }, + { + prefix: "Tcp", + header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors", + }, + { + prefix: "Udp", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, + { + prefix: "UdpLite", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, +} + +func toSlice(a interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(a)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + +func sprintSlice(s []uint64) string { + if len(s) == 0 { + return "" + } + r := fmt.Sprint(s) + return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's +// net/core/net-procfs.c:dev_seq_show. +func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + contents := make([]string, 0, len(snmp)*2) + types := []interface{}{ + &inet.StatSNMPIP{}, + &inet.StatSNMPICMP{}, + nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats. + &inet.StatSNMPTCP{}, + &inet.StatSNMPUDP{}, + &inet.StatSNMPUDPLite{}, + } + for i, stat := range types { + line := snmp[i] + if stat == nil { + contents = append( + contents, + fmt.Sprintf("%s:\n", line.prefix), + fmt.Sprintf("%s:\n", line.prefix), + ) + continue + } + if err := n.s.Statistics(stat, line.prefix); err != nil { + if err == syserror.EOPNOTSUPP { + log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } else { + log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } + } + var values string + if line.prefix == "Tcp" { + tcp := stat.(*inet.StatSNMPTCP) + // "Tcp" needs special processing because MaxConn is signed. RFC 2012. + values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + } else { + values = sprintSlice(toSlice(stat)) + } + contents = append( + contents, + fmt.Sprintf("%s: %s\n", line.prefix, line.header), + fmt.Sprintf("%s: %s\n", line.prefix, values), + ) + } + + data := make([]seqfile.SeqData, 0, len(snmp)*2) + for _, l := range contents { + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)}) + } + + return data, 0 +} + +// netRoute implements seqfile.SeqSource for /proc/net/route. +// +// +stateify savable +type netRoute struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netRoute) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. +func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + interfaces := n.s.Interfaces() + contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"} + for _, rt := range n.s.RouteTable() { + // /proc/net/route only includes ipv4 routes. + if rt.Family != linux.AF_INET { + continue + } + + // /proc/net/route does not include broadcast or multicast routes. + if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST { + continue + } + + iface, ok := interfaces[rt.OutputInterface] + if !ok || iface.Name == "lo" { + continue + } + + var ( + gw uint32 + prefix uint32 + flags = linux.RTF_UP + ) + if len(rt.GatewayAddr) == header.IPv4AddressSize { + flags |= linux.RTF_GATEWAY + gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + } + if len(rt.DstAddr) == header.IPv4AddressSize { + prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + } + l := fmt.Sprintf( + "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", + iface.Name, + prefix, + gw, + flags, + 0, // RefCnt. + 0, // Use. + 0, // Metric. + (uint32(1)<<rt.DstLen)-1, + 0, // MTU. + 0, // Window. + 0, // RTT. + ) + contents = append(contents, l) + } + + var data []seqfile.SeqData + for _, l := range contents { + l = fmt.Sprintf("%-127s\n", l) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)}) + } + + return data, 0 +} + // netUnix implements seqfile.SeqSource for /proc/net/unix. // // +stateify savable diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index f3b63dfc2..bd93f83fa 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -64,7 +64,7 @@ var _ fs.InodeOperations = (*tcpMemInode)(nil) func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode { tm := &tcpMemInode{ - SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), s: s, dir: dir, } @@ -77,6 +77,11 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir return fs.NewInode(ctx, tm, msrc, sattr) } +// Truncate implements fs.InodeOperations.Truncate. +func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true @@ -168,14 +173,15 @@ func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error { // +stateify savable type tcpSack struct { + fsutil.SimpleFileInode + stack inet.Stack `state:"wait"` enabled *bool - fsutil.SimpleFileInode } func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { ts := &tcpSack{ - SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), stack: s, } sattr := fs.StableAttr{ @@ -187,6 +193,11 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f return fs.NewInode(ctx, ts, msrc, sattr) } +// Truncate implements fs.InodeOperations.Truncate. +func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 87184ec67..9bf4b4527 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -67,29 +67,28 @@ type taskDir struct { var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. -func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode { +func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - // FIXME(b/123511468): create the correct io file for threads. - "io": newIO(t, msrc), + "auxv": newAuxvec(t, msrc), + "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + "comm": newComm(t, msrc), + "environ": newExecArgInode(t, msrc, environExecArg), + "exe": newExe(t, msrc), + "fd": newFdDir(t, msrc), + "fdinfo": newFdInfoDir(t, msrc), + "gid_map": newGIDMap(t, msrc), + "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "ns": newNamespaceDir(t, msrc), "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, showSubtasks, p.pidns), + "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), "statm": newStatm(t, msrc), "status": newStatus(t, msrc, p.pidns), "uid_map": newUIDMap(t, msrc), } - if showSubtasks { + if isThreadGroup { contents["task"] = p.newSubtasks(t, msrc) } if len(p.cgroupControllers) > 0 { @@ -605,6 +604,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(&buf, "Mems_allowed:\t1\n") + fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n") return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0 } @@ -619,8 +622,11 @@ type ioData struct { ioUsage } -func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) +func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { + if isThreadGroup { + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) + } + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -639,7 +645,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se io.Accumulate(i.IOUsage()) var buf bytes.Buffer - fmt.Fprintf(&buf, "char: %d\n", io.CharsRead) + fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead) fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten) fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls) fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 19b7557d5..6b07f6bf2 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -76,6 +76,11 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn func (mi *masterInodeOperations) Release(ctx context.Context) { } +// Truncate implements fs.InodeOperations.Truncate. +func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. // // It allocates a new terminal. diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index 944c4ada1..2a51e6bab 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -72,6 +72,11 @@ func (si *slaveInodeOperations) Release(ctx context.Context) { si.t.DecRef() } +// Truncate implements fs.InodeOperations.Truncate. +func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. // // This may race with destruction of the terminal. If the terminal is gone, it diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index ff8138820..917f90cc0 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -53,8 +53,8 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal d: d, n: n, ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{}, - slaveKTTY: &kernel.TTY{}, + masterKTTY: &kernel.TTY{Index: n}, + slaveKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 7ccff8b0d..880b7bcd3 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -38,6 +38,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/fd", + "//pkg/fspath", "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/context", diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 10a8083a0..177ce2cb9 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -50,7 +50,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -81,7 +81,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf ctx := contexttest.Context(b) creds := auth.CredentialsFromContext(ctx) - if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil { + if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } return func() { diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index cea89bcd9..a2d8c3ad6 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -154,7 +154,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds toRead = len(dst) } - n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) + n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) if n < toRead { return n, syserror.EIO } @@ -174,7 +174,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { var childPhyBlk uint32 - err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) + err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 213aa3919..181727ef7 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -87,12 +87,14 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) regFile := regularFile{ inode: inode{ + fs: &filesystem{ + dev: bytes.NewReader(mockDisk), + }, diskInode: &disklayout.InodeNew{ InodeOld: disklayout.InodeOld{ SizeLo: getMockBMFileFize(), }, }, - dev: bytes.NewReader(mockDisk), blkSize: uint64(mockBMBlkSize), }, } diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 054fb42b6..a080cb189 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -41,16 +41,18 @@ func newDentry(in *inode) *dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef() { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef() bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { - d.inode.decRef(vfsfs.Impl().(*filesystem)) +func (d *dentry) DecRef() { + // FIXME(b/134676337): filesystem.mu may not be locked as required by + // inode.decRef(). + d.inode.decRef() } diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index f10accafc..4b7d17dc6 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -40,14 +40,14 @@ var _ vfs.FilesystemType = (*FilesystemType)(nil) // Currently there are two ways of mounting an ext(2/3/4) fs: // 1. Specify a mount with our internal special MountType in the OCI spec. // 2. Expose the device to the container and mount it from application layer. -func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) { +func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) { if opts.InternalData == nil { // User mount call. // TODO(b/134676337): Open the device specified by `source` and return that. panic("unimplemented") } - // NewFilesystem call originated from within the sentry. + // GetFilesystem call originated from within the sentry. devFd, ok := opts.InternalData.(int) if !ok { return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device") @@ -91,8 +91,8 @@ func isCompatible(sb disklayout.SuperBlock) bool { return true } -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { // TODO(b/134676337): Ensure that the user is mounting readonly. If not, // EACCESS should be returned according to mount(2). Filesystem independent // flags (like readonly) are currently not available in pkg/sentry/vfs. @@ -103,7 +103,7 @@ func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials } fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) fs.sb, err = readSuperBlock(dev) if err != nil { return nil, nil, err diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 1aa2bd6a4..e9f756732 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -66,7 +66,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -147,55 +147,54 @@ func TestSeek(t *testing.T) { t.Fatalf("vfsfs.OpenAt failed: %v", err) } - if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { + if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { t.Errorf("expected seek position 0, got %d and error %v", n, err) } - stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{}) + stat, err := fd.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err) } // We should be able to seek beyond the end of file. size := int64(stat.Size) - if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { + if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } - if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { + if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err) } // Make sure negative offsets work with SEEK_CUR. - if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } // Make sure SEEK_END works with regular files. - switch fd.Impl().(type) { - case *regularFileFD: + if _, ok := fd.Impl().(*regularFileFD); ok { // Seek back to 0. - if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { + if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", 0, n, err) } // Seek forward beyond EOF. - if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } } @@ -456,7 +455,7 @@ func TestRead(t *testing.T) { want := make([]byte, 1) for { n, err := f.Read(want) - fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) + fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) if diff := cmp.Diff(got, want); diff != "" { t.Errorf("file data mismatch (-want +got):\n%s", diff) @@ -464,7 +463,7 @@ func TestRead(t *testing.T) { // Make sure there is no more file data left after getting EOF. if n == 0 || err == io.EOF { - if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { + if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image) } @@ -509,27 +508,27 @@ func TestIterDirents(t *testing.T) { } wantDirents := []vfs.Dirent{ - vfs.Dirent{ + { Name: ".", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "..", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "lost+found", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "file.txt", Type: linux.DT_REG, }, - vfs.Dirent{ + { Name: "bigfile.txt", Type: linux.DT_REG, }, - vfs.Dirent{ + { Name: "symlink.txt", Type: linux.DT_LNK, }, @@ -574,7 +573,7 @@ func TestIterDirents(t *testing.T) { } cb := &iterDirentsCb{} - if err = fd.Impl().IterDirents(ctx, cb); err != nil { + if err = fd.IterDirents(ctx, cb); err != nil { t.Fatalf("dir fd.IterDirents() failed: %v", err) } diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 38b68a2d3..3d3ebaca6 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -99,7 +99,7 @@ func (f *extentFile) buildExtTree() error { func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) { var header disklayout.ExtentHeader off := entry.PhysicalBlock() * f.regFile.inode.blkSize - err := readFromDisk(f.regFile.inode.dev, int64(off), &header) + err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header) if err != nil { return nil, err } @@ -115,7 +115,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla curEntry = &disklayout.ExtentIdx{} } - err := readFromDisk(f.regFile.inode.dev, int64(off), curEntry) + err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry) if err != nil { return nil, err } @@ -229,7 +229,7 @@ func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byt toRead = len(dst) } - n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], int64(readStart)) + n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart)) if n < toRead { return n, syserror.EIO } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index 42d0a484b..a2382daa3 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -180,13 +180,15 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] mockExtentFile := &extentFile{ regFile: regularFile{ inode: inode{ + fs: &filesystem{ + dev: bytes.NewReader(mockDisk), + }, diskInode: &disklayout.InodeNew{ InodeOld: disklayout.InodeOld{ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root), }, }, blkSize: mockExtentBlkSize, - dev: bytes.NewReader(mockDisk), }, }, } diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index 4d18b28cb..841274daf 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -26,33 +26,14 @@ import ( type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - - // flags is the same as vfs.OpenOptions.Flags which are passed to - // vfs.FilesystemImpl.OpenAt. - // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2), - // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set. - // Only close(2), fstat(2), fstatfs(2) should work. - flags uint32 } func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) + return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } func (fd *fileDescription) inode() *inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode -} - -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil + return fd.vfsfd.Dentry().Impl().(*dentry).inode } // Stat implements vfs.FileDescriptionImpl.Stat. diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 2d15e8aaf..d7e87979a 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -20,6 +20,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -441,3 +442,46 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return nil, err + } + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return "", err + } + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index e6c847a71..b2cc826c7 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -16,7 +16,6 @@ package ext import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -42,13 +41,13 @@ type inode struct { // refs is a reference count. refs is accessed using atomic memory operations. refs int64 + // fs is the containing filesystem. + fs *filesystem + // inodeNum is the inode number of this inode on disk. This is used to // identify inodes within the ext filesystem. inodeNum uint32 - // dev represents the underlying device. Same as filesystem.dev. - dev io.ReaderAt - // blkSize is the fs data block size. Same as filesystem.sb.BlockSize(). blkSize uint64 @@ -81,10 +80,10 @@ func (in *inode) tryIncRef() bool { // decRef decrements the inode ref count and releases the inode resources if // the ref count hits 0. // -// Precondition: Must have locked fs.mu. -func (in *inode) decRef(fs *filesystem) { +// Precondition: Must have locked filesystem.mu. +func (in *inode) decRef() { if refs := atomic.AddInt64(&in.refs, -1); refs == 0 { - delete(fs.inodeCache, in.inodeNum) + delete(in.fs.inodeCache, in.inodeNum) } else if refs < 0 { panic("ext.inode.decRef() called without holding a reference") } @@ -117,8 +116,8 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) { // Build the inode based on its type. inode := inode{ + fs: fs, inodeNum: inodeNum, - dev: fs.dev, blkSize: blkSize, diskInode: diskInode, } @@ -154,11 +153,13 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v if err := in.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } + mnt := rp.Mount() switch in.impl.(type) { case *regularFile: var fd regularFileFD - fd.flags = flags - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *directory: // Can't open directories writably. This check is not necessary for a read @@ -167,8 +168,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.EISDIR } var fd directoryFD - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) - fd.flags = flags + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *symlink: if flags&linux.O_PATH == 0 { @@ -176,8 +178,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.ELOOP } var fd symlinkFD - fd.flags = flags - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD new file mode 100644 index 000000000..52596c090 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -0,0 +1,60 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "slot_list", + out = "slot_list.go", + package = "kernfs", + prefix = "slot", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*slot", + "Linker": "*slot", + }, +) + +go_library( + name = "kernfs", + srcs = [ + "dynamic_bytes_file.go", + "fd_impl_util.go", + "filesystem.go", + "inode_impl_util.go", + "kernfs.go", + "slot_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/fspath", + "//pkg/log", + "//pkg/refs", + "//pkg/sentry/context", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) + +go_test( + name = "kernfs_test", + size = "small", + srcs = ["kernfs_test.go"], + deps = [ + ":kernfs", + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go new file mode 100644 index 000000000..51102ce48 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -0,0 +1,117 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// DynamicBytesFile implements kernfs.Inode and represents a read-only +// file whose contents are backed by a vfs.DynamicBytesSource. +// +// Must be initialized with Init before first use. +type DynamicBytesFile struct { + InodeAttrs + InodeNoopRefCount + InodeNotDirectory + InodeNotSymlink + + data vfs.DynamicBytesSource +} + +// Init intializes a dynamic bytes file. +func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) { + f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444) + f.data = data +} + +// Open implements Inode.Open. +func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &DynamicBytesFD{} + fd.Init(rp.Mount(), vfsd, f.data, flags) + return &fd.vfsfd, nil +} + +// SetStat implements Inode.SetStat. +func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error { + // DynamicBytesFiles are immutable. + return syserror.EPERM +} + +// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a +// DynamicBytesFile. +// +// Must be initialized with Init before first use. +type DynamicBytesFD struct { + vfs.FileDescriptionDefaultImpl + vfs.DynamicBytesFileDescriptionImpl + + vfsfd vfs.FileDescription + inode Inode +} + +// Init initializes a DynamicBytesFD. +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) { + m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + fd.inode = d.Impl().(*Dentry).inode + fd.SetDataSource(data) + fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}) +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence) +} + +// Read implmenets vfs.FileDescriptionImpl.Read. +func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts) +} + +// PRead implmenets vfs.FileDescriptionImpl.PRead. +func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return fd.FileDescriptionDefaultImpl.Write(ctx, src, opts) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.FileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *DynamicBytesFD) Release() {} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(fs), nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { + // DynamicBytesFiles are immutable. + return syserror.EPERM +} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go new file mode 100644 index 000000000..bd402330f --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -0,0 +1,193 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory +// inode that uses OrderChildren to track child nodes. GenericDirectoryFD is not +// compatible with dynamic directories. +// +// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling +// IterDirents callback. The IterDirents callback therefore cannot hash or +// unhash children, or recursively call IterDirents on the same underlying +// inode. +// +// Must be initialize with Init before first use. +type GenericDirectoryFD struct { + vfs.FileDescriptionDefaultImpl + vfs.DirectoryFileDescriptionDefaultImpl + + vfsfd vfs.FileDescription + children *OrderedChildren + off int64 +} + +// Init initializes a GenericDirectoryFD. +func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) { + m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + fd.children = children + fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}) +} + +// VFSFileDescription returns a pointer to the vfs.FileDescription representing +// this object. +func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription { + return &fd.vfsfd +} + +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts) +} + +// Read implmenets vfs.FileDescriptionImpl.Read. +func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts) +} + +// PRead implmenets vfs.FileDescriptionImpl.PRead. +func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) +} + +// Release implements vfs.FileDecriptionImpl.Release. +func (fd *GenericDirectoryFD) Release() {} + +func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem() +} + +func (fd *GenericDirectoryFD) inode() Inode { + return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode +} + +// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds +// o.mu when calling cb. +func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + vfsFS := fd.filesystem() + fs := vfsFS.Impl().(*Filesystem) + vfsd := fd.vfsfd.VirtualDentry().Dentry() + + fs.mu.Lock() + defer fs.mu.Unlock() + + // Handle ".". + if fd.off == 0 { + stat := fd.inode().Stat(vfsFS) + dirent := vfs.Dirent{ + Name: ".", + Type: linux.DT_DIR, + Ino: stat.Ino, + NextOff: 1, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + // Handle "..". + if fd.off == 1 { + parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode + stat := parentInode.Stat(vfsFS) + dirent := vfs.Dirent{ + Name: "..", + Type: linux.FileMode(stat.Mode).DirentType(), + Ino: stat.Ino, + NextOff: 2, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + // Handle static children. + fd.children.mu.RLock() + defer fd.children.mu.RUnlock() + // fd.off accounts for "." and "..", but fd.children do not track + // these. + childIdx := fd.off - 2 + for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { + inode := it.Dentry.Impl().(*Dentry).inode + stat := inode.Stat(vfsFS) + dirent := vfs.Dirent{ + Name: it.Name, + Type: linux.FileMode(stat.Mode).DirentType(), + Ino: stat.Ino, + NextOff: fd.off + 1, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + return nil +} + +// Seek implements vfs.FileDecriptionImpl.Seek. +func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fs := fd.filesystem().Impl().(*Filesystem) + fs.mu.Lock() + defer fs.mu.Unlock() + + switch whence { + case linux.SEEK_SET: + // Use offset as given. + case linux.SEEK_CUR: + offset += fd.off + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.filesystem() + inode := fd.inode() + return inode.Stat(fs), nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + fs := fd.filesystem() + inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return inode.SetStat(fs, opts) +} diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go new file mode 100644 index 000000000..3cbbe4b20 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -0,0 +1,743 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements vfs.FilesystemImpl for kernfs. + +package kernfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// stepExistingLocked resolves rp.Component() in parent directory vfsd. +// +// stepExistingLocked is loosely analogous to fs/namei.c:walk_component(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// +// Postcondition: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) (*vfs.Dentry, error) { + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, syserror.ENOTDIR + } + // Directory searchable? + if err := d.inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } +afterSymlink: + d.dirMu.Lock() + nextVFSD, err := rp.ResolveComponent(vfsd) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + if nextVFSD != nil { + // Cached dentry exists, revalidate. + next := nextVFSD.Impl().(*Dentry) + if !next.inode.Valid(ctx) { + d.dirMu.Lock() + rp.VirtualFilesystem().ForceDeleteDentry(nextVFSD) + d.dirMu.Unlock() + fs.deferDecRef(nextVFSD) // Reference from Lookup. + nextVFSD = nil + } + } + if nextVFSD == nil { + // Dentry isn't cached; it either doesn't exist or failed + // revalidation. Attempt to resolve it via Lookup. + name := rp.Component() + var err error + nextVFSD, err = d.inode.Lookup(ctx, name) + // Reference on nextVFSD dropped by a corresponding Valid. + if err != nil { + return nil, err + } + d.InsertChild(name, nextVFSD) + } + next := nextVFSD.Impl().(*Dentry) + + // Resolve any symlink at current path component. + if rp.ShouldFollowSymlink() && d.isSymlink() { + // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks". + target, err := next.inode.Readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + goto afterSymlink + + } + rp.Advance() + return nextVFSD, nil +} + +// walkExistingLocked resolves rp to an existing file. +// +// walkExistingLocked is loosely analogous to Linux's +// fs/namei.c:path_lookupat(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { + vfsd := rp.Start() + for !rp.Done() { + var err error + vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) + if err != nil { + return nil, nil, err + } + } + d := vfsd.Impl().(*Dentry) + if rp.MustBeDir() && !d.isDir() { + return nil, nil, syserror.ENOTDIR + } + return vfsd, d.inode, nil +} + +// walkParentDirLocked resolves all but the last path component of rp to an +// existing directory. It does not check that the returned directory is +// searchable by the provider of rp. +// +// walkParentDirLocked is loosely analogous to Linux's +// fs/namei.c:path_parentat(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { + vfsd := rp.Start() + for !rp.Final() { + var err error + vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) + if err != nil { + return nil, nil, err + } + } + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, nil, syserror.ENOTDIR + } + return vfsd, d.inode, nil +} + +// checkCreateLocked checks that a file named rp.Component() may be created in +// directory parentVFSD, then returns rp.Component(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. parentInode +// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true. +func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return "", err + } + pc := rp.Component() + if pc == "." || pc == ".." { + return "", syserror.EEXIST + } + childVFSD, err := rp.ResolveChild(parentVFSD, pc) + if err != nil { + return "", err + } + if childVFSD != nil { + return "", syserror.EEXIST + } + if parentVFSD.IsDisowned() { + return "", syserror.ENOENT + } + return pc, nil +} + +// checkDeleteLocked checks that the file represented by vfsd may be deleted. +// +// Preconditions: Filesystem.mu must be locked for at least reading. +func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error { + parentVFSD := vfsd.Parent() + if parentVFSD == nil { + return syserror.EBUSY + } + if parentVFSD.IsDisowned() { + return syserror.ENOENT + } + if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + return nil +} + +// checkRenameLocked checks that a rename operation may be performed on the +// target dentry across the given set of parent directories. The target dentry +// may be nil. +// +// Precondition: isDir(dstInode) == true. +func checkRenameLocked(creds *auth.Credentials, src, dstDir *vfs.Dentry, dstInode Inode) error { + srcDir := src.Parent() + if srcDir == nil { + return syserror.EBUSY + } + if srcDir.IsDisowned() { + return syserror.ENOENT + } + if dstDir.IsDisowned() { + return syserror.ENOENT + } + // Check for creation permissions on dst dir. + if err := dstInode.CheckPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + + return nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *Filesystem) Release() { +} + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *Filesystem) Sync(ctx context.Context) error { + // All filesystem state is in-memory. + return nil +} + +// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. +func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return nil, err + } + + if opts.CheckSearchable { + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, syserror.ENOTDIR + } + if err := inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + } + vfsd.IncRef() // Ownership transferred to caller. + return vfsd, nil +} + +// LinkAt implements vfs.FilesystemImpl.LinkAt. +func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if rp.Mount() != vd.Mount() { + return syserror.EXDEV + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + + d := vd.Dentry().Impl().(*Dentry) + if d.isDir() { + return syserror.EPERM + } + + child, err := parentInode.NewLink(ctx, pc, d.inode) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + child, err := parentInode.NewDir(ctx, pc, opts) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. +func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + new, err := parentInode.NewNode(ctx, pc, opts) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, new) + return nil +} + +// OpenAt implements vfs.FilesystemImpl.OpenAt. +func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // Filter out flags that are not supported by kernfs. O_DIRECTORY and + // O_NOFOLLOW have no effect here (they're handled by VFS by setting + // appropriate bits in rp), but are returned by + // FileDescriptionImpl.StatusFlags(). + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + ats := vfs.AccessTypesForOpenFlags(opts.Flags) + + // Do not create new file. + if opts.Flags&linux.O_CREAT == 0 { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return nil, err + } + if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return inode.Open(rp, vfsd, opts.Flags) + } + + // May create new file. + mustCreate := opts.Flags&linux.O_EXCL != 0 + vfsd := rp.Start() + inode := vfsd.Impl().(*Dentry).inode + fs.mu.Lock() + defer fs.mu.Unlock() + if rp.Done() { + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + if mustCreate { + return nil, syserror.EEXIST + } + if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return inode.Open(rp, vfsd, opts.Flags) + } +afterTrailingSymlink: + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return nil, err + } + // Check for search permission in the parent directory. + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + // Reject attempts to open directories with O_CREAT. + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + pc := rp.Component() + if pc == "." || pc == ".." { + return nil, syserror.EISDIR + } + // Determine whether or not we need to create a file. + childVFSD, err := rp.ResolveChild(parentVFSD, pc) + if err != nil { + return nil, err + } + if childVFSD == nil { + // Already checked for searchability above; now check for writability. + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + defer rp.Mount().EndWrite() + // Create and open the child. + child, err := parentInode.NewFile(ctx, pc, opts) + if err != nil { + return nil, err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return child.Impl().(*Dentry).inode.Open(rp, child, opts.Flags) + } + // Open existing file or follow symlink. + if mustCreate { + return nil, syserror.EEXIST + } + childDentry := childVFSD.Impl().(*Dentry) + childInode := childDentry.inode + if rp.ShouldFollowSymlink() { + if childDentry.isSymlink() { + target, err := childInode.Readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + // rp.Final() may no longer be true since we now need to resolve the + // symlink target. + goto afterTrailingSymlink + } + } + if err := childInode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return childInode.Open(rp, childVFSD, opts.Flags) +} + +// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. +func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { + fs.mu.RLock() + d, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return "", err + } + if !d.Impl().(*Dentry).isSymlink() { + return "", syserror.EINVAL + } + return inode.Readlink(ctx) +} + +// RenameAt implements vfs.FilesystemImpl.RenameAt. +func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { + noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 + exchange := opts.Flags&linux.RENAME_EXCHANGE != 0 + whiteout := opts.Flags&linux.RENAME_WHITEOUT != 0 + if exchange && (noReplace || whiteout) { + // Can't specify RENAME_NOREPLACE or RENAME_WHITEOUT with RENAME_EXCHANGE. + return syserror.EINVAL + } + if exchange || whiteout { + // Exchange and Whiteout flags are not supported on kernfs. + return syserror.EINVAL + } + + fs.mu.Lock() + defer fs.mu.Lock() + + mnt := rp.Mount() + if mnt != vd.Mount() { + return syserror.EXDEV + } + + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + + dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + + srcVFSD := vd.Dentry() + srcDirVFSD := srcVFSD.Parent() + + // Can we remove the src dentry? + if err := checkDeleteLocked(rp, srcVFSD); err != nil { + return err + } + + // Can we create the dst dentry? + var dstVFSD *vfs.Dentry + pc, err := checkCreateLocked(rp, dstDirVFSD, dstDirInode) + switch err { + case nil: + // Ok, continue with rename as replacement. + case syserror.EEXIST: + if noReplace { + // Won't overwrite existing node since RENAME_NOREPLACE was requested. + return syserror.EEXIST + } + dstVFSD, err = rp.ResolveChild(dstDirVFSD, pc) + if err != nil { + panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD)) + } + default: + return err + } + + mntns := vfs.MountNamespaceFromContext(ctx) + virtfs := rp.VirtualFilesystem() + + srcDirDentry := srcDirVFSD.Impl().(*Dentry) + dstDirDentry := dstDirVFSD.Impl().(*Dentry) + + // We can't deadlock here due to lock ordering because we're protected from + // concurrent renames by fs.mu held for writing. + srcDirDentry.dirMu.Lock() + defer srcDirDentry.dirMu.Unlock() + dstDirDentry.dirMu.Lock() + defer dstDirDentry.dirMu.Unlock() + + if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil { + return err + } + srcDirInode := srcDirDentry.inode + replaced, err := srcDirInode.Rename(ctx, srcVFSD.Name(), pc, srcVFSD, dstDirVFSD) + if err != nil { + virtfs.AbortRenameDentry(srcVFSD, dstVFSD) + return err + } + virtfs.CommitRenameReplaceDentry(srcVFSD, dstDirVFSD, pc, replaced) + return nil +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + if err := checkDeleteLocked(rp, vfsd); err != nil { + return err + } + if !vfsd.Impl().(*Dentry).isDir() { + return syserror.ENOTDIR + } + if inode.HasChildren() { + return syserror.ENOTEMPTY + } + virtfs := rp.VirtualFilesystem() + parentDentry := vfsd.Parent().Impl().(*Dentry) + parentDentry.dirMu.Lock() + defer parentDentry.dirMu.Unlock() + if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { + return err + } + if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { + virtfs.AbortDeleteDentry(vfsd) + return err + } + virtfs.CommitDeleteDentry(vfsd) + return nil +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + fs.mu.RLock() + _, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + if opts.Stat.Mask == 0 { + return nil + } + return inode.SetStat(fs.VFSFilesystem(), opts) +} + +// StatAt implements vfs.FilesystemImpl.StatAt. +func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { + fs.mu.RLock() + _, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return linux.Statx{}, err + } + return inode.Stat(fs.VFSFilesystem()), nil +} + +// StatFSAt implements vfs.FilesystemImpl.StatFSAt. +func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return linux.Statfs{}, err + } + // TODO: actually implement statfs + return linux.Statfs{}, syserror.ENOSYS +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. +func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + child, err := parentInode.NewSymlink(ctx, pc, target) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. +func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + vfsd, _, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + if err := checkDeleteLocked(rp, vfsd); err != nil { + return err + } + if vfsd.Impl().(*Dentry).isDir() { + return syserror.EISDIR + } + virtfs := rp.VirtualFilesystem() + parentDentry := vfsd.Parent().Impl().(*Dentry) + parentDentry.dirMu.Lock() + defer parentDentry.dirMu.Unlock() + if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { + return err + } + if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { + virtfs.AbortDeleteDentry(vfsd) + return err + } + virtfs.CommitDeleteDentry(vfsd) + return nil +} + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return nil, err + } + // kernfs currently does not support extended attributes. + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return "", err + } + // kernfs currently does not support extended attributes. + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go new file mode 100644 index 000000000..7b45b702a --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -0,0 +1,492 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "fmt" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// InodeNoopRefCount partially implements the Inode interface, specifically the +// inodeRefs sub interface. InodeNoopRefCount implements a simple reference +// count for inodes, performing no extra actions when references are obtained or +// released. This is suitable for simple file inodes that don't reference any +// resources. +type InodeNoopRefCount struct { +} + +// IncRef implements Inode.IncRef. +func (n *InodeNoopRefCount) IncRef() { +} + +// DecRef implements Inode.DecRef. +func (n *InodeNoopRefCount) DecRef() { +} + +// TryIncRef implements Inode.TryIncRef. +func (n *InodeNoopRefCount) TryIncRef() bool { + return true +} + +// Destroy implements Inode.Destroy. +func (n *InodeNoopRefCount) Destroy() { +} + +// InodeDirectoryNoNewChildren partially implements the Inode interface. +// InodeDirectoryNoNewChildren represents a directory inode which does not +// support creation of new children. +type InodeDirectoryNoNewChildren struct{} + +// NewFile implements Inode.NewFile. +func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewDir implements Inode.NewDir. +func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewLink implements Inode.NewLink. +func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewSymlink implements Inode.NewSymlink. +func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewNode implements Inode.NewNode. +func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// InodeNotDirectory partially implements the Inode interface, specifically the +// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not +// represent directories can embed this to provide no-op implementations for +// directory-related functions. +type InodeNotDirectory struct { +} + +// HasChildren implements Inode.HasChildren. +func (*InodeNotDirectory) HasChildren() bool { + return false +} + +// NewFile implements Inode.NewFile. +func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { + panic("NewFile called on non-directory inode") +} + +// NewDir implements Inode.NewDir. +func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { + panic("NewDir called on non-directory inode") +} + +// NewLink implements Inode.NewLinkink. +func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { + panic("NewLink called on non-directory inode") +} + +// NewSymlink implements Inode.NewSymlink. +func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + panic("NewSymlink called on non-directory inode") +} + +// NewNode implements Inode.NewNode. +func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + panic("NewNode called on non-directory inode") +} + +// Unlink implements Inode.Unlink. +func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { + panic("Unlink called on non-directory inode") +} + +// RmDir implements Inode.RmDir. +func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { + panic("RmDir called on non-directory inode") +} + +// Rename implements Inode.Rename. +func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { + panic("Rename called on non-directory inode") +} + +// Lookup implements Inode.Lookup. +func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + panic("Lookup called on non-directory inode") +} + +// Valid implements Inode.Valid. +func (*InodeNotDirectory) Valid(context.Context) bool { + return true +} + +// InodeNoDynamicLookup partially implements the Inode interface, specifically +// the inodeDynamicLookup sub interface. Directory inodes that do not support +// dymanic entries (i.e. entries that are not "hashed" into the +// vfs.Dentry.children) can embed this to provide no-op implementations for +// functions related to dynamic entries. +type InodeNoDynamicLookup struct{} + +// Lookup implements Inode.Lookup. +func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + return nil, syserror.ENOENT +} + +// Valid implements Inode.Valid. +func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool { + return true +} + +// InodeNotSymlink partially implements the Inode interface, specifically the +// inodeSymlink sub interface. All inodes that are not symlinks may embed this +// to return the appropriate errors from symlink-related functions. +type InodeNotSymlink struct{} + +// Readlink implements Inode.Readlink. +func (*InodeNotSymlink) Readlink(context.Context) (string, error) { + return "", syserror.EINVAL +} + +// InodeAttrs partially implements the Inode interface, specifically the +// inodeMetadata sub interface. InodeAttrs provides functionality related to +// inode attributes. +// +// Must be initialized by Init prior to first use. +type InodeAttrs struct { + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 +} + +// Init initializes this InodeAttrs. +func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMode) { + if mode.FileType() == 0 { + panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) + } + + nlink := uint32(1) + if mode.FileType() == linux.ModeDirectory { + nlink = 2 + } + atomic.StoreUint64(&a.ino, ino) + atomic.StoreUint32(&a.mode, uint32(mode)) + atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) + atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) + atomic.StoreUint32(&a.nlink, nlink) +} + +// Mode implements Inode.Mode. +func (a *InodeAttrs) Mode() linux.FileMode { + return linux.FileMode(atomic.LoadUint32(&a.mode)) +} + +// Stat partially implements Inode.Stat. Note that this function doesn't provide +// all the stat fields, and the embedder should consider extending the result +// with filesystem-specific fields. +func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx { + var stat linux.Statx + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Ino = atomic.LoadUint64(&a.ino) + stat.Mode = uint16(a.Mode()) + stat.UID = atomic.LoadUint32(&a.uid) + stat.GID = atomic.LoadUint32(&a.gid) + stat.Nlink = atomic.LoadUint32(&a.nlink) + + // TODO: Implement other stat fields like timestamps. + + return stat +} + +// SetStat implements Inode.SetStat. +func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { + stat := opts.Stat + if stat.Mask&linux.STATX_MODE != 0 { + for { + old := atomic.LoadUint32(&a.mode) + new := old | uint32(stat.Mode & ^uint16(linux.S_IFMT)) + if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped { + break + } + } + } + + if stat.Mask&linux.STATX_UID != 0 { + atomic.StoreUint32(&a.uid, stat.UID) + } + if stat.Mask&linux.STATX_GID != 0 { + atomic.StoreUint32(&a.gid, stat.GID) + } + + // Note that not all fields are modifiable. For example, the file type and + // inode numbers are immutable after node creation. + + // TODO: Implement other stat fields like timestamps. + + return nil +} + +// CheckPermissions implements Inode.CheckPermissions. +func (a *InodeAttrs) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + mode := a.Mode() + return vfs.GenericCheckPermissions( + creds, + ats, + mode.FileType() == linux.ModeDirectory, + uint16(mode), + auth.KUID(atomic.LoadUint32(&a.uid)), + auth.KGID(atomic.LoadUint32(&a.gid)), + ) +} + +// IncLinks implements Inode.IncLinks. +func (a *InodeAttrs) IncLinks(n uint32) { + if atomic.AddUint32(&a.nlink, n) <= n { + panic("InodeLink.IncLinks called with no existing links") + } +} + +// DecLinks implements Inode.DecLinks. +func (a *InodeAttrs) DecLinks() { + if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) { + // Negative overflow + panic("Inode.DecLinks called at 0 links") + } +} + +type slot struct { + Name string + Dentry *vfs.Dentry + slotEntry +} + +// OrderedChildrenOptions contains initialization options for OrderedChildren. +type OrderedChildrenOptions struct { + // Writable indicates whether vfs.FilesystemImpl methods implemented by + // OrderedChildren may modify the tracked children. This applies to + // operations related to rename, unlink and rmdir. If an OrderedChildren is + // not writable, these operations all fail with EPERM. + Writable bool +} + +// OrderedChildren partially implements the Inode interface. OrderedChildren can +// be embedded in directory inodes to keep track of the children in the +// directory, and can then be used to implement a generic directory FD -- see +// GenericDirectoryFD. OrderedChildren is not compatible with dynamic +// directories. +// +// Must be initialize with Init before first use. +type OrderedChildren struct { + refs.AtomicRefCount + + // Can children be modified by user syscalls? It set to false, interface + // methods that would modify the children return EPERM. Immutable. + writable bool + + mu sync.RWMutex + order slotList + set map[string]*slot +} + +// Init initializes an OrderedChildren. +func (o *OrderedChildren) Init(opts OrderedChildrenOptions) { + o.writable = opts.Writable + o.set = make(map[string]*slot) +} + +// DecRef implements Inode.DecRef. +func (o *OrderedChildren) DecRef() { + o.AtomicRefCount.DecRefWithDestructor(o.Destroy) +} + +// Destroy cleans up resources referenced by this OrderedChildren. +func (o *OrderedChildren) Destroy() { + o.mu.Lock() + defer o.mu.Unlock() + o.order.Reset() + o.set = nil +} + +// Populate inserts children into this OrderedChildren, and d's dentry +// cache. Populate returns the number of directories inserted, which the caller +// may use to update the link count for the parent directory. +// +// Precondition: d.Impl() must be a kernfs Dentry. d must represent a directory +// inode. children must not contain any conflicting entries already in o. +func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 { + var links uint32 + for name, child := range children { + if child.isDir() { + links++ + } + if err := o.Insert(name, child.VFSDentry()); err != nil { + panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d)) + } + d.InsertChild(name, child.VFSDentry()) + } + return links +} + +// HasChildren implements Inode.HasChildren. +func (o *OrderedChildren) HasChildren() bool { + o.mu.RLock() + defer o.mu.RUnlock() + return len(o.set) > 0 +} + +// Insert inserts child into o. This ignores the writability of o, as this is +// not part of the vfs.FilesystemImpl interface, and is a lower-level operation. +func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error { + o.mu.Lock() + defer o.mu.Unlock() + if _, ok := o.set[name]; ok { + return syserror.EEXIST + } + s := &slot{ + Name: name, + Dentry: child, + } + o.order.PushBack(s) + o.set[name] = s + return nil +} + +// Precondition: caller must hold o.mu for writing. +func (o *OrderedChildren) removeLocked(name string) { + if s, ok := o.set[name]; ok { + delete(o.set, name) + o.order.Remove(s) + } +} + +// Precondition: caller must hold o.mu for writing. +func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry { + if s, ok := o.set[name]; ok { + // Existing slot with given name, simply replace the dentry. + var old *vfs.Dentry + old, s.Dentry = s.Dentry, new + return old + } + + // No existing slot with given name, create and hash new slot. + s := &slot{ + Name: name, + Dentry: new, + } + o.order.PushBack(s) + o.set[name] = s + return nil +} + +// Precondition: caller must hold o.mu for reading or writing. +func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error { + s, ok := o.set[name] + if !ok { + return syserror.ENOENT + } + if s.Dentry != child { + panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child)) + } + return nil +} + +// Unlink implements Inode.Unlink. +func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { + if !o.writable { + return syserror.EPERM + } + o.mu.Lock() + defer o.mu.Unlock() + if err := o.checkExistingLocked(name, child); err != nil { + return err + } + o.removeLocked(name) + return nil +} + +// Rmdir implements Inode.Rmdir. +func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { + // We're not responsible for checking that child is a directory, that it's + // empty, or updating any link counts; so this is the same as unlink. + return o.Unlink(ctx, name, child) +} + +type renameAcrossDifferentImplementationsError struct{} + +func (renameAcrossDifferentImplementationsError) Error() string { + return "rename across inodes with different implementations" +} + +// Rename implements Inode.Rename. +// +// Precondition: Rename may only be called across two directory inodes with +// identical implementations of Rename. Practically, this means filesystems that +// implement Rename by embedding OrderedChildren for any directory +// implementation must use OrderedChildren for all directory implementations +// that will support Rename. +// +// Postcondition: reference on any replaced dentry transferred to caller. +func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) { + dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren) + if !ok { + return nil, renameAcrossDifferentImplementationsError{} + } + if !o.writable || !dst.writable { + return nil, syserror.EPERM + } + // Note: There's a potential deadlock below if concurrent calls to Rename + // refer to the same src and dst directories in reverse. We avoid any + // ordering issues because the caller is required to serialize concurrent + // calls to Rename in accordance with the interface declaration. + o.mu.Lock() + defer o.mu.Unlock() + if dst != o { + dst.mu.Lock() + defer dst.mu.Unlock() + } + if err := o.checkExistingLocked(oldname, child); err != nil { + return nil, err + } + replaced := dst.replaceChildLocked(newname, child) + return replaced, nil +} + +// nthLocked returns an iterator to the nth child tracked by this object. The +// iterator is valid until the caller releases o.mu. Returns nil if the +// requested index falls out of bounds. +// +// Preconditon: Caller must hold o.mu for reading. +func (o *OrderedChildren) nthLocked(i int64) *slot { + for it := o.order.Front(); it != nil && i >= 0; it = it.Next() { + if i == 0 { + return it + } + i-- + } + return nil +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go new file mode 100644 index 000000000..bb01c3d01 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -0,0 +1,405 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package kernfs provides the tools to implement inode-based filesystems. +// Kernfs has two main features: +// +// 1. The Inode interface, which maps VFS2's path-based filesystem operations to +// specific filesystem nodes. Kernfs uses the Inode interface to provide a +// blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as +// the synchronization mechanism for all filesystem operations by holding a +// filesystem-wide lock across all operations. +// +// 2. Various utility types which provide generic implementations for various +// parts of the Inode and vfs.FileDescription interfaces. Client filesystems +// based on kernfs can embed the appropriate set of these to avoid having to +// reimplement common filesystem operations. See inode_impl_util.go and +// fd_impl_util.go. +// +// Reference Model: +// +// Kernfs dentries represents named pointers to inodes. Dentries and inode have +// independent lifetimes and reference counts. A child dentry unconditionally +// holds a reference on its parent directory's dentry. A dentry also holds a +// reference on the inode it points to. Multiple dentries can point to the same +// inode (for example, in the case of hardlinks). File descriptors hold a +// reference to the dentry they're opened on. +// +// Dentries are guaranteed to exist while holding Filesystem.mu for +// reading. Dropping dentries require holding Filesystem.mu for writing. To +// queue dentries for destruction from a read critical section, see +// Filesystem.deferDecRef. +// +// Lock ordering: +// +// kernfs.Filesystem.mu +// kernfs.Dentry.dirMu +// vfs.VirtualFilesystem.mountMu +// vfs.Dentry.mu +// kernfs.Filesystem.droppedDentriesMu +// (inode implementation locks, if any) +package kernfs + +import ( + "fmt" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// FilesystemType implements vfs.FilesystemType. +type FilesystemType struct{} + +// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory +// filesystem. Concrete implementations are expected to embed this in their own +// Filesystem type. +type Filesystem struct { + vfsfs vfs.Filesystem + + droppedDentriesMu sync.Mutex + + // droppedDentries is a list of dentries waiting to be DecRef()ed. This is + // used to defer dentry destruction until mu can be acquired for + // writing. Protected by droppedDentriesMu. + droppedDentries []*vfs.Dentry + + // mu synchronizes the lifetime of Dentries on this filesystem. Holding it + // for reading guarantees continued existence of any resolved dentries, but + // the dentry tree may be modified. + // + // Kernfs dentries can only be DecRef()ed while holding mu for writing. For + // example: + // + // fs.mu.Lock() + // defer fs.mu.Unlock() + // ... + // dentry1.DecRef() + // defer dentry2.DecRef() // Ok, will run before Unlock. + // + // If discarding dentries in a read context, use Filesystem.deferDecRef. For + // example: + // + // fs.mu.RLock() + // fs.mu.processDeferredDecRefs() + // defer fs.mu.RUnlock() + // ... + // fs.deferDecRef(dentry) + mu sync.RWMutex + + // nextInoMinusOne is used to to allocate inode numbers on this + // filesystem. Must be accessed by atomic operations. + nextInoMinusOne uint64 +} + +// deferDecRef defers dropping a dentry ref until the next call to +// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu. +// +// Precondition: d must not already be pending destruction. +func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { + fs.droppedDentriesMu.Lock() + fs.droppedDentries = append(fs.droppedDentries, d) + fs.droppedDentriesMu.Unlock() +} + +// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the +// droppedDentries list. See comment on Filesystem.mu. +func (fs *Filesystem) processDeferredDecRefs() { + fs.mu.Lock() + fs.processDeferredDecRefsLocked() + fs.mu.Unlock() +} + +// Precondition: fs.mu must be held for writing. +func (fs *Filesystem) processDeferredDecRefsLocked() { + fs.droppedDentriesMu.Lock() + for _, d := range fs.droppedDentries { + d.DecRef() + } + fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse. + fs.droppedDentriesMu.Unlock() +} + +// Init initializes a kernfs filesystem. This should be called from during +// vfs.FilesystemType.NewFilesystem for the concrete filesystem embedding +// kernfs. +func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) { + fs.vfsfs.Init(vfsObj, fs) +} + +// VFSFilesystem returns the generic vfs filesystem object. +func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem { + return &fs.vfsfs +} + +// NextIno allocates a new inode number on this filesystem. +func (fs *Filesystem) NextIno() uint64 { + return atomic.AddUint64(&fs.nextInoMinusOne, 1) +} + +// These consts are used in the Dentry.flags field. +const ( + // Dentry points to a directory inode. + dflagsIsDir = 1 << iota + + // Dentry points to a symlink inode. + dflagsIsSymlink +) + +// Dentry implements vfs.DentryImpl. +// +// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a +// named reference to an inode. A dentry generally lives as long as it's part of +// a mounted filesystem tree. Kernfs doesn't cache dentries once all references +// to them are removed. Dentries hold a single reference to the inode they point +// to, and child dentries hold a reference on their parent. +// +// Must be initialized by Init prior to first use. +type Dentry struct { + refs.AtomicRefCount + + vfsd vfs.Dentry + inode Inode + + refs uint64 + + // flags caches useful information about the dentry from the inode. See the + // dflags* consts above. Must be accessed by atomic ops. + flags uint32 + + // dirMu protects vfsd.children for directory dentries. + dirMu sync.Mutex +} + +// Init initializes this dentry. +// +// Precondition: Caller must hold a reference on inode. +// +// Postcondition: Caller's reference on inode is transferred to the dentry. +func (d *Dentry) Init(inode Inode) { + d.vfsd.Init(d) + d.inode = inode + ftype := inode.Mode().FileType() + if ftype == linux.ModeDirectory { + d.flags |= dflagsIsDir + } + if ftype == linux.ModeSymlink { + d.flags |= dflagsIsSymlink + } +} + +// VFSDentry returns the generic vfs dentry for this kernfs dentry. +func (d *Dentry) VFSDentry() *vfs.Dentry { + return &d.vfsd +} + +// isDir checks whether the dentry points to a directory inode. +func (d *Dentry) isDir() bool { + return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0 +} + +// isSymlink checks whether the dentry points to a symlink inode. +func (d *Dentry) isSymlink() bool { + return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef() { + d.AtomicRefCount.DecRefWithDestructor(d.destroy) +} + +// Precondition: Dentry must be removed from VFS' dentry cache. +func (d *Dentry) destroy() { + d.inode.DecRef() // IncRef from Init. + d.inode = nil + if parent := d.vfsd.Parent(); parent != nil { + parent.DecRef() // IncRef from Dentry.InsertChild. + } +} + +// InsertChild inserts child into the vfs dentry cache with the given name under +// this dentry. This does not update the directory inode, so calling this on +// it's own isn't sufficient to insert a child into a directory. InsertChild +// updates the link count on d if required. +// +// Precondition: d must represent a directory inode. +func (d *Dentry) InsertChild(name string, child *vfs.Dentry) { + if !d.isDir() { + panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) + } + vfsDentry := d.VFSDentry() + vfsDentry.IncRef() // DecRef in child's Dentry.destroy. + d.dirMu.Lock() + vfsDentry.InsertChild(child, name) + d.dirMu.Unlock() +} + +// The Inode interface maps filesystem-level operations that operate on paths to +// equivalent operations on specific filesystem nodes. +// +// The interface methods are groups into logical categories as sub interfaces +// below. Generally, an implementation for each sub interface can be provided by +// embedding an appropriate type from inode_impl_utils.go. The sub interfaces +// are purely organizational. Methods declared directly in the main interface +// have no generic implementations, and should be explicitly provided by the +// client filesystem. +// +// Generally, implementations are not responsible for tasks that are common to +// all filesystems. These include: +// +// - Checking that dentries passed to methods are of the appropriate file type. +// - Checking permissions. +// - Updating link and reference counts. +// +// Specific responsibilities of implementations are documented below. +type Inode interface { + // Methods related to reference counting. A generic implementation is + // provided by InodeNoopRefCount. These methods are generally called by the + // equivalent Dentry methods. + inodeRefs + + // Methods related to node metadata. A generic implementation is provided by + // InodeAttrs. + inodeMetadata + + // Method for inodes that represent symlink. InodeNotSymlink provides a + // blanket implementation for all non-symlink inodes. + inodeSymlink + + // Method for inodes that represent directories. InodeNotDirectory provides + // a blanket implementation for all non-directory inodes. + inodeDirectory + + // Method for inodes that represent dynamic directories and their + // children. InodeNoDynamicLookup provides a blanket implementation for all + // non-dynamic-directory inodes. + inodeDynamicLookup + + // Open creates a file description for the filesystem object represented by + // this inode. The returned file description should hold a reference on the + // inode for its lifetime. + // + // Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry. + Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) +} + +type inodeRefs interface { + IncRef() + DecRef() + TryIncRef() bool + // Destroy is called when the inode reaches zero references. Destroy release + // all resources (references) on objects referenced by the inode, including + // any child dentries. + Destroy() +} + +type inodeMetadata interface { + // CheckPermissions checks that creds may access this inode for the + // requested access type, per the the rules of + // fs/namei.c:generic_permission(). + CheckPermissions(creds *auth.Credentials, atx vfs.AccessTypes) error + + // Mode returns the (struct stat)::st_mode value for this inode. This is + // separated from Stat for performance. + Mode() linux.FileMode + + // Stat returns the metadata for this inode. This corresponds to + // vfs.FilesystemImpl.StatAt. + Stat(fs *vfs.Filesystem) linux.Statx + + // SetStat updates the metadata for this inode. This corresponds to + // vfs.FilesystemImpl.SetStatAt. + SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error +} + +// Precondition: All methods in this interface may only be called on directory +// inodes. +type inodeDirectory interface { + // The New{File,Dir,Node,Symlink} methods below should return a new inode + // hashed into this inode. + // + // These inode constructors are inode-level operations rather than + // filesystem-level operations to allow client filesystems to mix different + // implementations based on the new node's location in the + // filesystem. + + // HasChildren returns true if the directory inode has any children. + HasChildren() bool + + // NewFile creates a new regular file inode. + NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) + + // NewDir creates a new directory inode. + NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) + + // NewLink creates a new hardlink to a specified inode in this + // directory. Implementations should create a new kernfs Dentry pointing to + // target, and update target's link count. + NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) + + // NewSymlink creates a new symbolic link inode. + NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) + + // NewNode creates a new filesystem node for a mknod syscall. + NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) + + // Unlink removes a child dentry from this directory inode. + Unlink(ctx context.Context, name string, child *vfs.Dentry) error + + // RmDir removes an empty child directory from this directory + // inode. Implementations must update the parent directory's link count, + // if required. Implementations are not responsible for checking that child + // is a directory, checking for an empty directory. + RmDir(ctx context.Context, name string, child *vfs.Dentry) error + + // Rename is called on the source directory containing an inode being + // renamed. child should point to the resolved child in the source + // directory. If Rename replaces a dentry in the destination directory, it + // should return the replaced dentry or nil otherwise. + // + // Precondition: Caller must serialize concurrent calls to Rename. + Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error) +} + +type inodeDynamicLookup interface { + // Lookup should return an appropriate dentry if name should resolve to a + // child of this dynamic directory inode. This gives the directory an + // opportunity on every lookup to resolve additional entries that aren't + // hashed into the directory. This is only called when the inode is a + // directory. If the inode is not a directory, or if the directory only + // contains a static set of children, the implementer can unconditionally + // return an appropriate error (ENOTDIR and ENOENT respectively). + // + // The child returned by Lookup will be hashed into the VFS dentry tree. Its + // lifetime can be controlled by the filesystem implementation with an + // appropriate implementation of Valid. + // + // Lookup returns the child with an extra reference and the caller owns this + // reference. + Lookup(ctx context.Context, name string) (*vfs.Dentry, error) + + // Valid should return true if this inode is still valid, or needs to + // be resolved again by a call to Lookup. + Valid(ctx context.Context) bool +} + +type inodeSymlink interface { + // Readlink resolves the target of a symbolic link. If an inode is not a + // symlink, the implementation should return EINVAL. + Readlink(ctx context.Context) (string, error) +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go new file mode 100644 index 000000000..f78bb7b04 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -0,0 +1,423 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs_test + +import ( + "bytes" + "fmt" + "io" + "runtime" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const defaultMode linux.FileMode = 01777 +const staticFileContent = "This is sample content for a static test file." + +// RootDentryFn is a generator function for creating the root dentry of a test +// filesystem. See newTestSystem. +type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry + +// TestSystem represents the context for a single test. +type TestSystem struct { + t *testing.T + ctx context.Context + creds *auth.Credentials + vfs *vfs.VirtualFilesystem + mns *vfs.MountNamespace + root vfs.VirtualDentry +} + +// newTestSystem sets up a minimal environment for running a test, including an +// instance of a test filesystem. Tests can control the contents of the +// filesystem by providing an appropriate rootFn, which should return a +// pre-populated root dentry. +func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + v := vfs.New() + v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}) + mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("Failed to create testfs root mount: %v", err) + } + + s := &TestSystem{ + t: t, + ctx: ctx, + creds: creds, + vfs: v, + mns: mns, + root: mns.Root(), + } + runtime.SetFinalizer(s, func(s *TestSystem) { s.root.DecRef() }) + return s +} + +// PathOpAtRoot constructs a vfs.PathOperation for a path from the +// root of the test filesystem. +// +// Precondition: path should be relative path. +func (s *TestSystem) PathOpAtRoot(path string) vfs.PathOperation { + return vfs.PathOperation{ + Root: s.root, + Start: s.root, + Pathname: path, + } +} + +// GetDentryOrDie attempts to resolve a dentry referred to by the +// provided path operation. If unsuccessful, the test fails. +func (s *TestSystem) GetDentryOrDie(pop vfs.PathOperation) vfs.VirtualDentry { + vd, err := s.vfs.GetDentryAt(s.ctx, s.creds, &pop, &vfs.GetDentryOptions{}) + if err != nil { + s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err) + } + return vd +} + +func (s *TestSystem) ReadToEnd(fd *vfs.FileDescription) (string, error) { + buf := make([]byte, usermem.PageSize) + bufIOSeq := usermem.BytesIOSequence(buf) + opts := vfs.ReadOptions{} + + var content bytes.Buffer + for { + n, err := fd.Impl().Read(s.ctx, bufIOSeq, opts) + if n == 0 || err != nil { + if err == io.EOF { + err = nil + } + return content.String(), err + } + content.Write(buf[:n]) + } +} + +type fsType struct { + rootFn RootDentryFn +} + +type filesystem struct { + kernfs.Filesystem +} + +type file struct { + kernfs.DynamicBytesFile + content string +} + +func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry { + f := &file{} + f.content = content + f.DynamicBytesFile.Init(creds, fs.NextIno(), f) + + d := &kernfs.Dentry{} + d.Init(f) + return d +} + +func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%s", f.content) + return nil +} + +type attrs struct { + kernfs.InodeAttrs +} + +func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error { + return syserror.EPERM +} + +type readonlyDir struct { + attrs + kernfs.InodeNotSymlink + kernfs.InodeNoDynamicLookup + kernfs.InodeDirectoryNoNewChildren + + kernfs.OrderedChildren + dentry kernfs.Dentry +} + +func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { + dir := &readonlyDir{} + dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + dir.dentry.Init(dir) + + dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) + + return &dir.dentry +} + +func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +type dir struct { + attrs + kernfs.InodeNotSymlink + kernfs.InodeNoDynamicLookup + + fs *filesystem + dentry kernfs.Dentry + kernfs.OrderedChildren +} + +func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { + dir := &dir{} + dir.fs = fs + dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) + dir.dentry.Init(dir) + + dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) + + return &dir.dentry +} + +func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { + creds := auth.CredentialsFromContext(ctx) + dir := d.fs.newDir(creds, opts.Mode, nil) + dirVFSD := dir.VFSDentry() + if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { + dir.DecRef() + return nil, err + } + d.IncLinks(1) + return dirVFSD, nil +} + +func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { + creds := auth.CredentialsFromContext(ctx) + f := d.fs.newFile(creds, "") + fVFSD := f.VFSDentry() + if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { + f.DecRef() + return nil, err + } + return fVFSD, nil +} + +func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + fs := &filesystem{} + fs.Init(vfsObj) + root := fst.rootFn(creds, fs) + return fs.VFSFilesystem(), root.VFSDentry(), nil +} + +// -------------------- Remainder of the file are test cases -------------------- + +func TestBasic(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "file1": fs.newFile(creds, staticFileContent), + }) + }) + sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef() +} + +func TestMkdirGetDentry(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir1": fs.newDir(creds, 0755, nil), + }) + }) + + pop := sys.PathOpAtRoot("dir1/a new directory") + if err := sys.vfs.MkdirAt(sys.ctx, sys.creds, &pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { + t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) + } + sys.GetDentryOrDie(pop).DecRef() +} + +func TestReadStaticFile(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "file1": fs.newFile(creds, staticFileContent), + }) + }) + + pop := sys.PathOpAtRoot("file1") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + content, err := sys.ReadToEnd(fd) + if err != nil { + sys.t.Fatalf("Read failed: %v", err) + } + if diff := cmp.Diff(staticFileContent, content); diff != "" { + sys.t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) + } +} + +func TestCreateNewFileInStaticDir(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir1": fs.newDir(creds, 0755, nil), + }) + }) + + pop := sys.PathOpAtRoot("dir1/newfile") + opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode} + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, opts) + if err != nil { + sys.t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err) + } + + // Close the file. The file should persist. + fd.DecRef() + + fd, err = sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) + } + fd.DecRef() +} + +// direntCollector provides an implementation for vfs.IterDirentsCallback for +// testing. It simply iterates to the end of a given directory FD and collects +// all dirents emitted by the callback. +type direntCollector struct { + mu sync.Mutex + dirents map[string]vfs.Dirent +} + +// Handle implements vfs.IterDirentsCallback.Handle. +func (d *direntCollector) Handle(dirent vfs.Dirent) bool { + d.mu.Lock() + if d.dirents == nil { + d.dirents = make(map[string]vfs.Dirent) + } + d.dirents[dirent.Name] = dirent + d.mu.Unlock() + return true +} + +// count returns the number of dirents currently in the collector. +func (d *direntCollector) count() int { + d.mu.Lock() + defer d.mu.Unlock() + return len(d.dirents) +} + +// contains checks whether the collector has a dirent with the given name and +// type. +func (d *direntCollector) contains(name string, typ uint8) error { + d.mu.Lock() + defer d.mu.Unlock() + dirent, ok := d.dirents[name] + if !ok { + return fmt.Errorf("No dirent named %q found", name) + } + if dirent.Type != typ { + return fmt.Errorf("Dirent named %q found, but was expecting type %d, got: %+v", name, typ, dirent) + } + return nil +} + +func TestDirFDReadWrite(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, nil) + }) + + pop := sys.PathOpAtRoot("/") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + // Read/Write should fail for directory FDs. + if _, err := fd.Read(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { + sys.t.Fatalf("Read for directory FD failed with unexpected error: %v", err) + } + if _, err := fd.Write(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EISDIR { + sys.t.Fatalf("Wrire for directory FD failed with unexpected error: %v", err) + } +} + +func TestDirFDIterDirents(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + // Fill root with nodes backed by various inode implementations. + "dir1": fs.newReadonlyDir(creds, 0755, nil), + "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir3": fs.newDir(creds, 0755, nil), + }), + "file1": fs.newFile(creds, staticFileContent), + }) + }) + + pop := sys.PathOpAtRoot("/") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + collector := &direntCollector{} + if err := fd.IterDirents(sys.ctx, collector); err != nil { + sys.t.Fatalf("IterDirent failed: %v", err) + } + + // Root directory should contain ".", ".." and 3 children: + if collector.count() != 5 { + sys.t.Fatalf("IterDirent returned too many dirents") + } + for _, dirName := range []string{".", "..", "dir1", "dir2"} { + if err := collector.contains(dirName, linux.DT_DIR); err != nil { + sys.t.Fatalf("IterDirent had unexpected results: %v", err) + } + } + if err := collector.contains("file1", linux.DT_REG); err != nil { + sys.t.Fatalf("IterDirent had unexpected results: %v", err) + } + +} diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index 7e364c5fd..0cc751eb8 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -24,14 +24,19 @@ go_library( "directory.go", "filesystem.go", "memfs.go", + "named_pipe.go", "regular_file.go", "symlink.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs", deps = [ "//pkg/abi/linux", + "//pkg/amutex", + "//pkg/fspath", + "//pkg/sentry/arch", "//pkg/sentry/context", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/pipe", "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", @@ -45,6 +50,7 @@ go_test( deps = [ ":memfs", "//pkg/abi/linux", + "//pkg/refs", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", @@ -54,3 +60,19 @@ go_test( "//pkg/syserror", ], ) + +go_test( + name = "memfs_test", + size = "small", + srcs = ["pipe_test.go"], + embed = [":memfs"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/memfs/benchmark_test.go index a94b17db6..4a7a94a52 100644 --- a/pkg/sentry/fsimpl/memfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/memfs/benchmark_test.go @@ -21,6 +21,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -160,6 +161,8 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { b.Fatalf("stat(%q) failed: %v", filePath, err) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -173,10 +176,11 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } + defer mntns.DecRef(vfsObj) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') @@ -186,7 +190,6 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { defer root.DecRef() vd := root vd.IncRef() - defer vd.DecRef() for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ @@ -219,6 +222,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) + vd.DecRef() + vd = vfs.VirtualDentry{} if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } @@ -243,6 +248,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { b.Fatalf("got wrong permissions (%0o)", stat.Mode) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -343,6 +350,8 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { b.Fatalf("stat(%q) failed: %v", filePath, err) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -356,10 +365,11 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } + defer mntns.DecRef(vfsObj) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') @@ -384,7 +394,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { } defer mountPoint.DecRef() // Create and mount the submount. - if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.NewFilesystemOptions{}); err != nil { + if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) @@ -395,7 +405,6 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount root: %v", err) } - defer vd.DecRef() for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ @@ -435,6 +444,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) + vd.DecRef() if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } @@ -459,6 +469,14 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { b.Fatalf("got wrong permissions (%0o)", stat.Mode) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } + +func init() { + // Turn off reference leak checking for a fair comparison between vfs1 and + // vfs2. + refs.SetLeakMode(refs.NoLeakChecking) +} diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go index f79e2d9c8..af4389459 100644 --- a/pkg/sentry/fsimpl/memfs/filesystem.go +++ b/pkg/sentry/fsimpl/memfs/filesystem.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -159,7 +160,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op return nil, err } } - inode.incRef() // vfsd.IncRef(&fs.vfsfs) + inode.incRef() return vfsd, nil } @@ -233,7 +234,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if err != nil { return err } - _, err = checkCreateLocked(rp, parentVFSD, parentInode) + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) if err != nil { return err } @@ -241,17 +242,49 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - // TODO: actually implement mknod - return syserror.EPERM + + switch opts.Mode.FileType() { + case 0: + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + fallthrough + case linux.ModeRegular: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeNamedPipe: + child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode)) + parentVFSD.InsertChild(&child.vfsd, pc) + parentInode.impl.(*directory).childList.PushBack(child) + return nil + + case linux.ModeSocket: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeCharacterDevice: + fallthrough + case linux.ModeBlockDevice: + // TODO(b/72101894): We don't support creating block or character + // devices at the moment. + // + // When we start supporting block and character devices, we'll + // need to check for CAP_MKNOD here. + return syserror.EPERM + + default: + // "EINVAL - mode requested creation of something other than a + // regular file, device special file, FIFO or socket." - mknod(2) + return syserror.EINVAL + } } // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { // Filter out flags that are not supported by memfs. O_DIRECTORY and // O_NOFOLLOW have no effect here (they're handled by VFS by setting - // appropriate bits in rp), but are returned by - // FileDescriptionImpl.StatusFlags(). - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + // appropriate bits in rp), but are visible in FD status flags. O_NONBLOCK + // is supported only by pipes. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() @@ -260,7 +293,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return nil, err } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } mustCreate := opts.Flags&linux.O_EXCL != 0 @@ -275,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } afterTrailingSymlink: // Walk to the parent directory of the last path component. @@ -320,7 +353,7 @@ afterTrailingSymlink: child := fs.newDentry(childInode) vfsd.InsertChild(&child.vfsd, pc) inode.impl.(*directory).childList.PushBack(child) - return childInode.open(rp, &child.vfsd, opts.Flags, true) + return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true) } // Open existing file or follow symlink. if mustCreate { @@ -336,29 +369,31 @@ afterTrailingSymlink: // symlink target. goto afterTrailingSymlink } - return childInode.open(rp, childVFSD, opts.Flags, false) + return childInode.open(ctx, rp, childVFSD, opts.Flags, false) } -func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { +func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(flags) if !afterCreate { if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil { return nil, err } } + mnt := rp.Mount() switch impl := i.impl.(type) { case *regularFile: var fd regularFileFD - fd.flags = flags fd.readable = vfs.MayReadFileWithOpenFlags(flags) fd.writable = vfs.MayWriteFileWithOpenFlags(flags) if fd.writable { - if err := rp.Mount().CheckBeginWrite(); err != nil { + if err := mnt.CheckBeginWrite(); err != nil { return nil, err } - // Mount.EndWrite() is called by regularFileFD.Release(). + // mnt.EndWrite() is called by regularFileFD.Release(). } - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) if flags&linux.O_TRUNC != 0 { impl.mu.Lock() impl.data = impl.data[:0] @@ -372,12 +407,15 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte return nil, syserror.EISDIR } var fd directoryFD - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) - fd.flags = flags + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *symlink: // Can't open symlinks without O_PATH (which is unimplemented). return nil, syserror.ELOOP + case *namedPipe: + return newNamedPipeFD(ctx, impl, rp, vfsd, flags) default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } @@ -542,3 +580,58 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error inode.decLinksLocked() return nil } + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return nil, err + } + // TODO(b/127675828): support extended attributes + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return "", err + } + // TODO(b/127675828): support extended attributes + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index b78471c0f..9d509f6e4 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -52,10 +52,10 @@ type filesystem struct { nextInoMinusOne uint64 // accessed using atomic memory operations } -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { var fs filesystem - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) root := fs.newDentry(fs.newDirectory(creds, 01777)) return &fs.vfsfs, &root.vfsd, nil } @@ -99,17 +99,17 @@ func (fs *filesystem) newDentry(inode *inode) *dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef() { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef() bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { +func (d *dentry) DecRef() { d.inode.decRef() } @@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS stat.Size = uint64(len(impl.target)) stat.Blocks = allocatedBlocksForSize(stat.Size) + case *namedPipe: + stat.Mode |= linux.S_IFIFO default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } @@ -259,28 +261,14 @@ func (i *inode) direntType() uint8 { type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - - flags uint32 // status flags; immutable } func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) + return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } func (fd *fileDescription) inode() *inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode -} - -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil + return fd.vfsfd.Dentry().Impl().(*dentry).inode } // Stat implements vfs.FileDescriptionImpl.Stat. diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go new file mode 100644 index 000000000..d5060850e --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/named_pipe.go @@ -0,0 +1,62 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +type namedPipe struct { + inode inode + + pipe *pipe.VFSPipe +} + +// Preconditions: +// * fs.mu must be locked. +// * rp.Mount().CheckBeginWrite() has been called successfully. +func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode { + file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)} + file.inode.init(file, fs, creds, mode) + file.inode.nlink = 1 // Only the parent has a link. + return &file.inode +} + +// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented +// entirely via struct embedding. +type namedPipeFD struct { + fileDescription + + *pipe.VFSPipeFD +} + +func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + var err error + var fd namedPipeFD + fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags) + if err != nil { + return nil, err + } + mnt := rp.Mount() + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) + return &fd.vfsfd, nil +} diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go new file mode 100644 index 000000000..5bf527c80 --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/pipe_test.go @@ -0,0 +1,233 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memfs + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const fileName = "mypipe" + +func TestSeparateFDs(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the read side. This is done in a concurrently because opening + // One end the pipe blocks until the other end is opened. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + rfdchan := make(chan *vfs.FileDescription) + go func() { + openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY} + rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + rfdchan <- rfd + }() + + // Open the write side. + openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY} + wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer wfd.DecRef() + + rfd, ok := <-rfdchan + if !ok { + t.Fatalf("failed to open pipe for reading %q", fileName) + } + defer rfd.DecRef() + + const msg = "vamos azul" + checkEmpty(ctx, t, rfd) + checkWrite(ctx, t, wfd, msg) + checkRead(ctx, t, rfd, msg) +} + +func TestNonblockingRead(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the read side as nonblocking. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK} + rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) + } + defer rfd.DecRef() + + // Open the write side. + openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} + wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer wfd.DecRef() + + const msg = "geh blau" + checkEmpty(ctx, t, rfd) + checkWrite(ctx, t, wfd, msg) + checkRead(ctx, t, rfd, msg) +} + +func TestNonblockingWriteError(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the write side as nonblocking, which should return ENXIO. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} + _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != syserror.ENXIO { + t.Fatalf("expected ENXIO, but got error: %v", err) + } +} + +func TestSingleFD(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the pipe as readable and writable. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_RDWR} + fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer fd.DecRef() + + const msg = "forza blu" + checkEmpty(ctx, t, fd) + checkWrite(ctx, t, fd, msg) + checkRead(ctx, t, fd, msg) +} + +// setup creates a VFS with a pipe in the root directory at path fileName. The +// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal +// upon failure. +func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + + // Create VFS. + vfsObj := vfs.New() + vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("failed to create tmpfs root mount: %v", err) + } + + // Create the pipe. + root := mntns.Root() + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644} + if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil { + t.Fatalf("failed to create file %q: %v", fileName, err) + } + + // Sanity check: the file pipe exists and has the correct mode. + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + }, &vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat(%q) failed: %v", fileName, err) + } + if stat.Mode&^linux.S_IFMT != 0644 { + t.Errorf("got wrong permissions (%0o)", stat.Mode) + } + if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe { + t.Errorf("got wrong file type (%0o)", stat.Mode) + } + + return ctx, creds, vfsObj, root +} + +// checkEmpty calls t.Fatal if the pipe in fd is not empty. +func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + readData := make([]byte, 1) + dst := usermem.BytesIOSequence(readData) + bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) + if err != syserror.ErrWouldBlock { + t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) + } + if bytesRead != 0 { + t.Fatalf("expected to read 0 bytes, but got %d", bytesRead) + } +} + +// checkWrite calls t.Fatal if it fails to write all of msg to fd. +func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { + writeData := []byte(msg) + src := usermem.BytesIOSequence(writeData) + bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{}) + if err != nil { + t.Fatalf("error writing to pipe %q: %v", fileName, err) + } + if bytesWritten != int64(len(writeData)) { + t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten) + } +} + +// checkRead calls t.Fatal if it fails to read msg from fd. +func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { + readData := make([]byte, len(msg)) + dst := usermem.BytesIOSequence(readData) + bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) + if err != nil { + t.Fatalf("error reading from pipe %q: %v", fileName, err) + } + if bytesRead != int64(len(msg)) { + t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead) + } + if !bytes.Equal(readData, []byte(msg)) { + t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData)) + } +} diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go index c36c4aff5..0e016bca5 100644 --- a/pkg/sentry/fsimpl/proc/filesystems.go +++ b/pkg/sentry/fsimpl/proc/filesystems.go @@ -19,7 +19,7 @@ package proc // +stateify savable type filesystemsData struct{} -// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for +// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for // filesystemsData. We would need to retrive filesystem names from // vfs.VirtualFilesystem. Also needs vfs replacement for // fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev. diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go index e81b1e910..8683cf677 100644 --- a/pkg/sentry/fsimpl/proc/mounts.go +++ b/pkg/sentry/fsimpl/proc/mounts.go @@ -16,7 +16,7 @@ package proc import "gvisor.dev/gvisor/pkg/sentry/kernel" -// TODO(b/138862512): Implement mountInfoFile and mountsFile. +// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile. // mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo. // diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index c46e05c3a..0d87be52b 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -229,6 +229,10 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(buf, "Mems_allowed:\t1\n") + fmt.Fprintf(buf, "Mems_allowed_list:\t0\n") return nil } diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index d5284f0d9..8d60ad4ad 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip/stack", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80f227dbe..a7dfb78a7 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,16 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // RegisteredEndpoints returns all endpoints which are currently registered. + RegisteredEndpoints() []stack.TransportEndpoint + + // CleanupEndpoints returns endpoints currently in the cleanup state. + CleanupEndpoints() []stack.TransportEndpoint + + // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful + // for restoring a stack after a save. + RestoreCleanupEndpoints([]stack.TransportEndpoint) } // Interface contains information about a network interface. @@ -153,3 +165,23 @@ type Route struct { // GatewayAddr is the route gateway address (RTA_GATEWAY). GatewayAddr []byte } + +// Below SNMP metrics are from Linux/usr/include/linux/snmp.h. + +// StatSNMPIP describes Ip line of /proc/net/snmp. +type StatSNMPIP [19]uint64 + +// StatSNMPICMP describes Icmp line of /proc/net/snmp. +type StatSNMPICMP [27]uint64 + +// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp. +type StatSNMPICMPMSG [512]uint64 + +// StatSNMPTCP describes Tcp line of /proc/net/snmp. +type StatSNMPTCP [15]uint64 + +// StatSNMPUDP describes Udp line of /proc/net/snmp. +type StatSNMPUDP [8]uint64 + +// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. +type StatSNMPUDPLite [8]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..dcfcbd97e 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,8 @@ package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route { } // Resume implements Stack.Resume. -func (s *TestStack) Resume() { +func (s *TestStack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { + return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { + return nil +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e041c51b3..2706927ff 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -35,7 +35,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//third_party/gvsync:generic_seqatomic", + template = "//pkg/syncutil:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -209,12 +209,12 @@ go_library( "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", + "//pkg/syncutil", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 51de4568a..04c244447 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//third_party/gvsync:generic_atomicptr", + template = "//pkg/syncutil:generic_atomicptr", types = { "Value": "Credentials", }, diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index e3f5b0d83..3c9dceaba 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,6 +15,8 @@ package kernel import ( + "time" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task { return nil } +// Deadline implements context.Context.Deadline. +func (*Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done implements context.Context.Done. +func (*Task) Done() <-chan struct{} { + return nil +} + +// Err implements context.Context.Err. +func (*Task) Err() error { + return nil +} + // AsyncContext returns a context.Context that may be used by goroutines that // do work on behalf of t and therefore share its contextual values, but are // not t's task goroutine (e.g. asynchronous I/O). @@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool { return ctx.t.IsLogging(level) } +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return ctx.t.Deadline() +} + +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return ctx.t.Done() +} + +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return ctx.t.Err() +} + // Value implements context.Context.Value. func (ctx taskAsyncContext) Value(key interface{}) interface{} { return ctx.t.Value(key) diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index cc3f43a45..11f613a11 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -81,6 +81,9 @@ type FDTable struct { // mu protects below. mu sync.Mutex `state:"nosave"` + // next is start position to find fd. + next int32 + // used contains the number of non-nil entries. It must be accessed // atomically. It may be read atomically without holding mu (but not // written). @@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.mu.Lock() defer f.mu.Unlock() + // From f.next to find available fd. + if fd < f.next { + fd = f.next + } + // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { @@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return nil, syscall.EMFILE } + if fd == f.next { + // Update next search start position. + f.next = fds[len(fds)-1] + 1 + } + return fds, nil } @@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File { f.mu.Lock() defer f.mu.Unlock() + // Update current available position. + if fd < f.next { + f.next = fd + } + orig, _, _ := f.get(fd) if orig != nil { orig.IncRef() // Reference for caller. @@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) { f.forEach(func(fd int32, file *fs.File, flags FDFlags) { if cond(file, flags) { f.set(fd, nil, FDFlags{}) // Clear from table. + // Update current available position. + if fd < f.next { + f.next = fd + } } }) } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2413788e7..2bcb6216a 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) { if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) } + + i := int32(2) + fdTable.Remove(i) + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { + t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) + } + }) +} + +func TestFDTableOverLimit(t *testing.T) { + runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { + if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error") + } + + if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error") + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil { + t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) + } else { + for _, fd := range fds { + fdTable.Remove(fd) + } + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 { + t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) + } + + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { + t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) + } else if len(fds) != 1 || fds[0] != 0 { + t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) + } }) } diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 34286c7a8..75ec31761 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//third_party/gvsync:generic_atomicptr", + template = "//pkg/syncutil:generic_atomicptr", types = { "Value": "bucket", }, diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 3cda03891..bd3fb4c03 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(w, k.FeatureSet(), nil); err != nil { + if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) @@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Save(w, k, &stats); err != nil { + if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { return err } log.Infof("Kernel save stats: %s", &stats) @@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(w); err != nil { + if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(r, &features, nil); err != nil { + if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Load(r, k, &stats); err != nil { + if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { return err } log.Infof("Kernel load stats: %s", &stats) @@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(r); err != nil { + if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -804,8 +804,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: mounts, + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: true, + Filename: args.Filename, + File: args.File, + CloseOnExec: false, + Argv: args.Argv, + Envv: args.Envv, + Features: k.featureSet, + } - tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) + tc, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } @@ -828,9 +841,11 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, AbstractSocketNamespace: args.AbstractSocketNamespace, ContainerID: args.ContainerID, } - if _, err := k.tasks.NewTask(config); err != nil { + t, err := k.tasks.NewTask(config) + if err != nil { return nil, 0, err } + t.traceExecEvent(tc) // Simulate exec for tracing. // Success. tgid := k.tasks.Root.IDOfThreadGroup(tg) @@ -1105,6 +1120,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { return lastErr } +// RebuildTraceContexts rebuilds the trace context for all tasks. +// +// Unfortunately, if these are built while tracing is not enabled, then we will +// not have meaningful trace data. Rebuilding here ensures that we can do so +// after tracing has been enabled. +func (k *Kernel) RebuildTraceContexts() { + k.extMu.Lock() + defer k.extMu.Unlock() + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + + for t, tid := range k.tasks.Root.tids { + t.rebuildTraceContext(tid) + } +} + // FeatureSet returns the FeatureSet. func (k *Kernel) FeatureSet() *cpuid.FeatureSet { return k.featureSet @@ -1309,6 +1340,7 @@ func (k *Kernel) ListSockets() []*SocketEntry { return socks } +// supervisorContext is a privileged context. type supervisorContext struct { context.NoopSleeper log.Logger diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index cde647139..9d34f6d4d 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -24,8 +24,10 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_util.go", "reader.go", "reader_writer.go", + "vfs.go", "writer.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", @@ -40,6 +42,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index a2dc72204..4a19ab7ce 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi switch { case flags.Read && !flags.Write: // O_RDONLY. r := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) + newHandleLocked(&i.rWakeup) if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { - if !i.waitFor(&i.wWakeup, ctx) { + if !waitFor(&i.mu, &i.wWakeup, ctx) { r.DecRef() return nil, syserror.ErrInterrupted } @@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Write && !flags.Read: // O_WRONLY. w := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.wWakeup) if i.p.isNamed && !i.p.HasReaders() { // On a nonblocking, write-only open, the open fails with ENXIO if the @@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi return nil, syserror.ENXIO } - if !i.waitFor(&i.rWakeup, ctx) { + if !waitFor(&i.mu, &i.rWakeup, ctx) { w.DecRef() return nil, syserror.ErrInterrupted } @@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Read && flags.Write: // O_RDWR. // Pipes opened for read-write always succeeds without blocking. rw := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.rWakeup) + newHandleLocked(&i.wWakeup) return rw, nil default: @@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi } } -// waitFor blocks until the underlying pipe has at least one reader/writer is -// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this -// function will block for either readers or writers, depending on where -// 'wakeupChan' points. -// -// f.mu must be held by the caller. waitFor returns with f.mu held, but it will -// drop f.mu before blocking for any reader/writers. -func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { - // Ideally this function would simply use a condition variable. However, the - // wait needs to be interruptible via 'sleeper', so we must sychronize via a - // channel. The synchronization below relies on the fact that closing a - // channel unblocks all receives on the channel. - - // Does an appropriate wakeup channel already exist? If not, create a new - // one. This is all done under f.mu to avoid races. - if *wakeupChan == nil { - *wakeupChan = make(chan struct{}) - } - - // Grab a local reference to the wakeup channel since it may disappear as - // soon as we drop f.mu. - wakeup := *wakeupChan - - // Drop the lock and prepare to sleep. - i.mu.Unlock() - cancel := sleeper.SleepStart() - - // Wait for either a new reader/write to be signalled via 'wakeup', or - // for the sleep to be cancelled. - select { - case <-wakeup: - sleeper.SleepFinish(true) - case <-cancel: - sleeper.SleepFinish(false) - } - - // Take the lock and check if we were woken. If we were woken and - // interrupted, the former takes priority. - i.mu.Lock() - select { - case <-wakeup: - return true - default: - return false - } -} - -// newHandleLocked signals a new pipe reader or writer depending on where -// 'wakeupChan' points. This unblocks any corresponding reader or writer -// waiting for the other end of the channel to be opened, see Fifo.waitFor. -// -// i.mu must be held. -func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) { - if *wakeupChan != nil { - close(*wakeupChan) - *wakeupChan = nil - } -} - func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error { return syserror.EPIPE } diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 8e4e8e82e..1a1b38f83 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -111,11 +111,27 @@ func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { if atomicIOBytes > sizeBytes { atomicIOBytes = sizeBytes } - return &Pipe{ - isNamed: isNamed, - max: sizeBytes, - atomicIOBytes: atomicIOBytes, + var p Pipe + initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + return &p +} + +func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { + if sizeBytes < MinimumPipeSize { + sizeBytes = MinimumPipeSize + } + if sizeBytes > MaximumPipeSize { + sizeBytes = MaximumPipeSize + } + if atomicIOBytes <= 0 { + atomicIOBytes = 1 + } + if atomicIOBytes > sizeBytes { + atomicIOBytes = sizeBytes } + pipe.isNamed = isNamed + pipe.max = sizeBytes + pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go new file mode 100644 index 000000000..ef9641e6a --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -0,0 +1,213 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "io" + "math" + "sync" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// This file contains Pipe file functionality that is tied to neither VFS nor +// the old fs architecture. + +// Release cleans up the pipe's state. +func (p *Pipe) Release() { + p.rClose() + p.wClose() + + // Wake up readers and writers. + p.Notify(waiter.EventIn | waiter.EventOut) +} + +// Read reads from the Pipe into dst. +func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) { + n, err := p.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo writes to w from the Pipe. +func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return p.dup(ctx, ops) + } + n, err := p.read(ctx, ops) + if n > 0 { + p.Notify(waiter.EventOut) + } + return n, err +} + +// Write writes to the Pipe from src. +func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) { + n, err := p.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom reads from r to the Pipe. +func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) { + n, err := p.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventIn) + } + return n, err +} + +// Readiness returns the ready events in the underlying pipe. +func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask { + return p.rwReadiness() & mask +} + +// Ioctl implements ioctls on the Pipe. +func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + // Switch on ioctl request. + switch int(args[1].Int()) { + case linux.FIONREAD: + v := p.queued() + if v > math.MaxInt32 { + v = math.MaxInt32 // Silently truncate. + } + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + default: + return 0, syscall.ENOTTY + } +} + +// waitFor blocks until the underlying pipe has at least one reader/writer is +// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this +// function will block for either readers or writers, depending on where +// 'wakeupChan' points. +// +// mu must be held by the caller. waitFor returns with mu held, but it will +// drop mu before blocking for any reader/writers. +func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { + // Ideally this function would simply use a condition variable. However, the + // wait needs to be interruptible via 'sleeper', so we must sychronize via a + // channel. The synchronization below relies on the fact that closing a + // channel unblocks all receives on the channel. + + // Does an appropriate wakeup channel already exist? If not, create a new + // one. This is all done under f.mu to avoid races. + if *wakeupChan == nil { + *wakeupChan = make(chan struct{}) + } + + // Grab a local reference to the wakeup channel since it may disappear as + // soon as we drop f.mu. + wakeup := *wakeupChan + + // Drop the lock and prepare to sleep. + mu.Unlock() + cancel := sleeper.SleepStart() + + // Wait for either a new reader/write to be signalled via 'wakeup', or + // for the sleep to be cancelled. + select { + case <-wakeup: + sleeper.SleepFinish(true) + case <-cancel: + sleeper.SleepFinish(false) + } + + // Take the lock and check if we were woken. If we were woken and + // interrupted, the former takes priority. + mu.Lock() + select { + case <-wakeup: + return true + default: + return false + } +} + +// newHandleLocked signals a new pipe reader or writer depending on where +// 'wakeupChan' points. This unblocks any corresponding reader or writer +// waiting for the other end of the channel to be opened, see Fifo.waitFor. +// +// Precondition: the mutex protecting wakeupChan must be held. +func newHandleLocked(wakeupChan *chan struct{}) { + if *wakeupChan != nil { + close(*wakeupChan) + *wakeupChan = nil + } +} diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index 7c307f013..b4d29fc77 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -16,16 +16,12 @@ package pipe import ( "io" - "math" - "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) // ReaderWriter satisfies the FileOperations interface and services both @@ -45,124 +41,27 @@ type ReaderWriter struct { *Pipe } -// Release implements fs.FileOperations.Release. -func (rw *ReaderWriter) Release() { - rw.Pipe.rClose() - rw.Pipe.wClose() - - // Wake up readers and writers. - rw.Pipe.Notify(waiter.EventIn | waiter.EventOut) -} - // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, readOps{ - left: func() int64 { - return dst.NumBytes() - }, - limit: func(l int64) { - dst = dst.TakeFirst64(l) - }, - read: func(buf *buffer) (int64, error) { - n, err := dst.CopyOutFrom(ctx, buf) - dst = dst.DropFirst64(n) - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventOut) - } - return n, err + return rw.Pipe.Read(ctx, dst) } // WriteTo implements fs.FileOperations.WriteTo. func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { - ops := readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(buf *buffer) (int64, error) { - n, err := buf.ReadToWriter(w, count, dup) - count -= n - return n, err - }, - } - if dup { - // There is no notification for dup operations. - return rw.Pipe.dup(ctx, ops) - } - n, err := rw.Pipe.read(ctx, ops) - if n > 0 { - rw.Pipe.Notify(waiter.EventOut) - } - return n, err + return rw.Pipe.WriteTo(ctx, w, count, dup) } // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, writeOps{ - left: func() int64 { - return src.NumBytes() - }, - limit: func(l int64) { - src = src.TakeFirst64(l) - }, - write: func(buf *buffer) (int64, error) { - n, err := src.CopyInTo(ctx, buf) - src = src.DropFirst64(n) - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err + return rw.Pipe.Write(ctx, src) } // ReadFrom implements fs.FileOperations.WriteTo. func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - n, err := rw.Pipe.write(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(buf *buffer) (int64, error) { - n, err := buf.WriteFromReader(r, count) - count -= n - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err -} - -// Readiness returns the ready events in the underlying pipe. -func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask { - return rw.Pipe.rwReadiness() & mask + return rw.Pipe.ReadFrom(ctx, r, count) } // Ioctl implements fs.FileOperations.Ioctl. func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // Switch on ioctl request. - switch int(args[1].Int()) { - case linux.FIONREAD: - v := rw.queued() - if v > math.MaxInt32 { - v = math.MaxInt32 // Silently truncate. - } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err - default: - return 0, syscall.ENOTTY - } + return rw.Pipe.Ioctl(ctx, io, args) } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go new file mode 100644 index 000000000..6416e0dd8 --- /dev/null +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -0,0 +1,220 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// This file contains types enabling the pipe package to be used with the vfs +// package. + +// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should +// not be copied. +type VFSPipe struct { + // mu protects the fields below. + mu sync.Mutex `state:"nosave"` + + // pipe is the underlying pipe. + pipe Pipe + + // Channels for synchronizing the creation of new readers and writers + // of this fifo. See waitFor and newHandleLocked. + // + // These are not saved/restored because all waiters are unblocked on + // save, and either automatically restart (via ERESTARTSYS) or return + // EINTR on resume. On restarts via ERESTARTSYS, the appropriate + // channel will be recreated. + rWakeup chan struct{} `state:"nosave"` + wWakeup chan struct{} `state:"nosave"` +} + +// NewVFSPipe returns an initialized VFSPipe. +func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe { + var vp VFSPipe + initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes) + return &vp +} + +// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics +// during open: +// +// "Normally, opening the FIFO blocks until the other end is opened also. A +// process can open a FIFO in nonblocking mode. In this case, opening for +// read-only will succeed even if no-one has opened on the write side yet, +// opening for write-only will fail with ENXIO (no such device or address) +// unless the other end has already been opened. Under Linux, opening a FIFO +// for read and write will succeed both in blocking and nonblocking mode. POSIX +// leaves this behavior undefined. This can be used to open a FIFO for writing +// while there are no readers available." - fifo(7) +func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { + vp.mu.Lock() + defer vp.mu.Unlock() + + readable := vfs.MayReadFileWithOpenFlags(flags) + writable := vfs.MayWriteFileWithOpenFlags(flags) + if !readable && !writable { + return nil, syserror.EINVAL + } + + vfd, err := vp.open(rp, vfsd, vfsfd, flags) + if err != nil { + return nil, err + } + + switch { + case readable && writable: + // Pipes opened for read-write always succeed without blocking. + newHandleLocked(&vp.rWakeup) + newHandleLocked(&vp.wWakeup) + + case readable: + newHandleLocked(&vp.rWakeup) + // If this pipe is being opened as nonblocking and there's no + // writer, we have to wait for a writer to open the other end. + if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { + return nil, syserror.EINTR + } + + case writable: + newHandleLocked(&vp.wWakeup) + + if !vp.pipe.HasReaders() { + // Nonblocking, write-only opens fail with ENXIO when + // the read side isn't open yet. + if flags&linux.O_NONBLOCK != 0 { + return nil, syserror.ENXIO + } + // Wait for a reader to open the other end. + if !waitFor(&vp.mu, &vp.rWakeup, ctx) { + return nil, syserror.EINTR + } + } + + default: + panic("invalid pipe flags: must be readable, writable, or both") + } + + return vfd, nil +} + +// Preconditions: vp.mu must be held. +func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { + var fd VFSPipeFD + fd.flags = flags + fd.readable = vfs.MayReadFileWithOpenFlags(flags) + fd.writable = vfs.MayWriteFileWithOpenFlags(flags) + fd.vfsfd = vfsfd + fd.pipe = &vp.pipe + if fd.writable { + // The corresponding Mount.EndWrite() is in VFSPipe.Release(). + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + } + + switch { + case fd.readable && fd.writable: + vp.pipe.rOpen() + vp.pipe.wOpen() + case fd.readable: + vp.pipe.rOpen() + case fd.writable: + vp.pipe.wOpen() + default: + panic("invalid pipe flags: must be readable, writable, or both") + } + + return &fd, nil +} + +// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is +// expected that filesystesm will use this in a struct implementing +// vfs.FileDescriptionImpl. +type VFSPipeFD struct { + pipe *Pipe + flags uint32 + readable bool + writable bool + vfsfd *vfs.FileDescription +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *VFSPipeFD) Release() { + var event waiter.EventMask + if fd.readable { + fd.pipe.rClose() + event |= waiter.EventIn + } + if fd.writable { + fd.pipe.wClose() + event |= waiter.EventOut + } + if event == 0 { + panic("invalid pipe flags: must be readable, writable, or both") + } + + if fd.writable { + fd.vfsfd.VirtualDentry().Mount().EndWrite() + } + + fd.pipe.Notify(event) +} + +// OnClose implements vfs.FileDescriptionImpl.OnClose. +func (fd *VFSPipeFD) OnClose(_ context.Context) error { + return nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { + if !fd.readable { + return 0, syserror.EINVAL + } + + return fd.pipe.Read(ctx, dst) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { + if !fd.writable { + return 0, syserror.EINVAL + } + + return fd.pipe.Write(ctx, src) +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.pipe.Ioctl(ctx, uio, args) +} diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go index 0acdf769d..61e412911 100644 --- a/pkg/sentry/kernel/ptrace_arm64.go +++ b/pkg/sentry/kernel/ptrace_arm64.go @@ -17,7 +17,6 @@ package kernel import ( - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 93fe68a3e..de9617e9d 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred return syserror.ERANGE } - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = val sem.pid = pid s.changeTime = ktime.NowFromContext(ctx) @@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti for i, val := range vals { sem := &s.sems[i] - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = int16(val) sem.pid = pid sem.wakeWaiters() @@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch } // All operations succeeded, apply them. - // TODO(b/29354920): handle undo operations. + // TODO(gvisor.dev/issue/137): handle undo operations. for i, v := range tmpVals { s.sems[i].value = v s.sems[i].wakeWaiters() diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 50b69d154..9f7e19b4d 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "signalfd", srcs = ["signalfd.go"], diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 220fa73a2..2fdee0282 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -339,6 +339,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn { return nil } +// LookupName looks up a syscall name. +func (s *SyscallTable) LookupName(sysno uintptr) string { + if sc, ok := s.Table[sysno]; ok { + return sc.Name + } + return fmt.Sprintf("sys_%d", sysno) // Unlikely. +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index c82ef5486..ab0c6c4aa 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -15,6 +15,8 @@ package kernel import ( + gocontext "context" + "runtime/trace" "sync" "sync/atomic" @@ -35,8 +37,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syncutil" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/third_party/gvsync" ) // Task represents a thread of execution in the untrusted app. It @@ -83,7 +85,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. - goschedSeq gvsync.SeqCount `state:"nosave"` + goschedSeq syncutil.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called @@ -390,7 +392,14 @@ type Task struct { // logPrefix is a string containing the task's thread ID in the root PID // namespace, and is prepended to log messages emitted by Task.Infof etc. - logPrefix atomic.Value `state:".(string)"` + logPrefix atomic.Value `state:"nosave"` + + // traceContext and traceTask are both used for tracing, and are + // updated along with the logPrefix in updateInfoLocked. + // + // These are exclusive to the task goroutine. + traceContext gocontext.Context `state:"nosave"` + traceTask *trace.Task `state:"nosave"` // creds is the task's credentials. // @@ -528,14 +537,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) { t.ptraceTracer.Store(tracer) } -func (t *Task) saveLogPrefix() string { - return t.logPrefix.Load().(string) -} - -func (t *Task) loadLogPrefix(prefix string) { - t.logPrefix.Store(prefix) -} - func (t *Task) saveSyscallFilters() []bpf.Program { if f := t.syscallFilters.Load(); f != nil { return f.([]bpf.Program) @@ -549,6 +550,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) { // afterLoad is invoked by stateify. func (t *Task) afterLoad() { + t.updateInfoLocked() t.interruptChan = make(chan struct{}, 1) t.gosched.State = TaskGoroutineNonexistent if t.stop != nil { @@ -709,9 +711,9 @@ func (t *Task) FDTable() *FDTable { return t.fdTable } -// GetFile is a convenience wrapper t.FDTable().GetFile. +// GetFile is a convenience wrapper for t.FDTable().Get. // -// Precondition: same as FDTable. +// Precondition: same as FDTable.Get. func (t *Task) GetFile(fd int32) *fs.File { f, _ := t.fdTable.Get(fd) return f diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index dd69939f9..4a4a69ee2 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -16,6 +16,7 @@ package kernel import ( "runtime" + "runtime/trace" "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { runtime.Gosched() } + region := trace.StartRegion(t.traceContext, blockRegion) select { case <-C: + region.End() t.SleepFinish(true) + // Woken by event. return nil case <-interrupt: + region.End() t.SleepFinish(false) // Return the indicated error on interrupt. return syserror.ErrInterrupted case <-timerChan: - // We've timed out. + region.End() t.SleepFinish(true) + // We've timed out. return syserror.ETIMEDOUT } } diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 0916fd658..3eadfedb4 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -299,6 +299,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { // nt that it must receive before its task goroutine starts running. tid := nt.k.tasks.Root.IDOfTask(nt) defer nt.Start(tid) + t.traceCloneEvent(tid) // "If fork/clone and execve are allowed by @prog, any child processes will // be constrained to the same filters and system call ABI as the parent." - diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 8639d379f..bb5560acf 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -18,10 +18,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -132,30 +130,21 @@ func (t *Task) Stack() *arch.Stack { return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} } -// LoadTaskImage loads filename into a new TaskContext. +// LoadTaskImage loads a specified file into a new TaskContext. // -// It takes several arguments: -// * mounts: MountNamespace to lookup filename in -// * root: Root to lookup filename under -// * wd: Working directory to lookup filename under -// * maxTraversals: maximum number of symlinks to follow -// * filename: path to binary to load -// * file: an open fs.File object of the binary to load. If set, -// file will be loaded and not filename. -// * argv: Binary argv -// * envv: Binary envv -// * fs: Binary FeatureSet -func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving filename. - if file != nil { - filename = file.MappedName(ctx) +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.MappedName(ctx) } // Prepare a new user address space to load into. m := mm.NewMemoryManager(k, k) defer m.DecUsers(ctx) + args.MemoryManager = m - os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso) + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) if err != nil { return nil, err } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 17a089b90..90a6190f1 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -129,6 +129,7 @@ type runSyscallAfterExecStop struct { } func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { + t.traceExecEvent(r.tc) t.tg.pidns.owner.mu.Lock() t.tg.execing = nil if t.killed() { @@ -253,7 +254,7 @@ func (t *Task) promoteLocked() { t.tg.leader = t t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t]) - t.updateLogPrefixLocked() + t.updateInfoLocked() // Reap the original leader. If it has a tracer, detach it instead of // waiting for it to acknowledge the original leader's death. oldLeader.exitParentNotified = true diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 535f03e50..435761e5a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState { type runExitMain struct{} func (*runExitMain) execute(t *Task) taskRunState { + t.traceExitEvent() lastExiter := t.exitThreadGroup() // If the task has a cleartid, and the thread group wasn't killed by a diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index a29e9b9eb..0fb3661de 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -16,6 +16,7 @@ package kernel import ( "fmt" + "runtime/trace" "sort" "gvisor.dev/gvisor/pkg/log" @@ -127,11 +128,88 @@ func (t *Task) debugDumpStack() { } } -// updateLogPrefix updates the task's cached log prefix to reflect its -// current thread ID. +// trace definitions. +// +// Note that all region names are prefixed by ':' in order to ensure that they +// are lexically ordered before all system calls, which use the naked system +// call name (e.g. "read") for maximum clarity. +const ( + traceCategory = "task" + runRegion = ":run" + blockRegion = ":block" + cpuidRegion = ":cpuid" + faultRegion = ":fault" +) + +// updateInfoLocked updates the task's cached log prefix and tracing +// information to reflect its current thread ID. // // Preconditions: The task's owning TaskSet.mu must be locked. -func (t *Task) updateLogPrefixLocked() { +func (t *Task) updateInfoLocked() { // Use the task's TID in the root PID namespace for logging. - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t])) + tid := t.tg.pidns.owner.Root.tids[t] + t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.rebuildTraceContext(tid) +} + +// rebuildTraceContext rebuilds the trace context. +// +// Precondition: the passed tid must be the tid in the root namespace. +func (t *Task) rebuildTraceContext(tid ThreadID) { + // Re-initialize the trace context. + if t.traceTask != nil { + t.traceTask.End() + } + + // Note that we define the "task type" to be the dynamic TID. This does + // not align perfectly with the documentation for "tasks" in the + // tracing package. Tasks may be assumed to be bounded by analysis + // tools. However, if we just use a generic "task" type here, then the + // "user-defined tasks" page on the tracing dashboard becomes nearly + // unusable, as it loads all traces from all tasks. + // + // We can assume that the number of tasks in the system is not + // arbitrarily large (in general it won't be, especially for cases + // where we're collecting a brief profile), so using the TID is a + // reasonable compromise in this case. + t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid)) +} + +// traceCloneEvent is called when a new task is spawned. +// +// ntid must be the new task's ThreadID in the root namespace. +func (t *Task) traceCloneEvent(ntid ThreadID) { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid) +} + +// traceExitEvent is called when a task exits. +func (t *Task) traceExitEvent() { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status()) +} + +// traceExecEvent is called when a task calls exec. +func (t *Task) traceExecEvent(tc *TaskContext) { + if !trace.IsEnabled() { + return + } + d := tc.MemoryManager.Executable() + if d == nil { + trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") + return + } + defer d.DecRef() + root := t.fsContext.RootDirectory() + if root == nil { + trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>") + return + } + defer root.DecRef() + n, _ := d.FullName(root) + trace.Logf(t.traceContext, traceCategory, "exec: %s", n) } diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index c92266c59..d97f8c189 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -17,6 +17,7 @@ package kernel import ( "bytes" "runtime" + "runtime/trace" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -205,9 +206,11 @@ func (*runApp) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.RUnlock() } + region := trace.StartRegion(t.traceContext, runRegion) t.accountTaskGoroutineEnter(TaskGoroutineRunningApp) info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU) t.accountTaskGoroutineLeave(TaskGoroutineRunningApp) + region.End() if clearSinglestep { t.Arch().ClearSingleStep() @@ -225,6 +228,7 @@ func (*runApp) execute(t *Task) taskRunState { case platform.ErrContextSignalCPUID: // Is this a CPUID instruction? + region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) @@ -232,10 +236,12 @@ func (*runApp) execute(t *Task) taskRunState { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) + region.End() // Resume execution. return (*runApp)(nil) } + region.End() // Not an actual CPUID, but required copy-in. // The instruction at the given RIP was not a CPUID, and we // fallthrough to the default signal deliver behavior below. @@ -251,8 +257,10 @@ func (*runApp) execute(t *Task) taskRunState { // an application-generated signal and we should continue execution // normally. if at.Any() { + region := trace.StartRegion(t.traceContext, faultRegion) addr := usermem.Addr(info.Addr()) err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack())) + region.End() if err == nil { // The fault was handled appropriately. // We can resume running the application. @@ -260,6 +268,12 @@ func (*runApp) execute(t *Task) taskRunState { } // Is this a vsyscall that we need emulate? + // + // Note that we don't track vsyscalls as part of a + // specific trace region. This is because regions don't + // stack, and the actual system call will count as a + // region. We should be able to easily identify + // vsyscalls by having a <fault><syscall> pair. if at.Execute { if sysno, ok := t.tc.st.LookupEmulate(addr); ok { return t.doVsyscall(addr, sysno) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index ae6fc4025..3522a4ae5 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -154,10 +154,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { // Below this point, newTask is expected not to fail (there is no rollback // of assignTIDsLocked or any of the following). - // Logging on t's behalf will panic if t.logPrefix hasn't been initialized. - // This is the earliest point at which we can do so (since t now has thread - // IDs). - t.updateLogPrefixLocked() + // Logging on t's behalf will panic if t.logPrefix hasn't been + // initialized. This is the earliest point at which we can do so + // (since t now has thread IDs). + t.updateInfoLocked() if cfg.InheritParent != nil { t.parent = cfg.InheritParent.parent diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index b543d536a..3180f5560 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -17,6 +17,7 @@ package kernel import ( "fmt" "os" + "runtime/trace" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u ctrl = ctrlStopAndReinvokeSyscall } else { fn := s.Lookup(sysno) + var region *trace.Region // Only non-nil if tracing == true. + if trace.IsEnabled() { + region = trace.StartRegion(t.traceContext, s.LookupName(sysno)) + } if fn != nil { // Call our syscall implementation. rval, ctrl, err = fn(t, args) @@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u // Use the missing function if not found. rval, err = t.SyscallTable().Missing(t, sysno, args) } + if region != nil { + region.End() + } } if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) { diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index 34f84487a..048de26dc 100644 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go @@ -21,8 +21,19 @@ import "sync" // // +stateify savable type TTY struct { + // Index is the terminal index. It is immutable. + Index uint32 + mu sync.Mutex `state:"nosave"` // tg is protected by mu. tg *ThreadGroup } + +// TTY returns the thread group's controlling terminal. If nil, there is no +// controlling terminal. +func (tg *ThreadGroup) TTY() *TTY { + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + return tg.tty +} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 2d9251e92..6299a3e2f 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -408,6 +408,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el start = vaddr } if vaddr < end { + // NOTE(b/37474556): Linux allows out-of-order + // segments, in violation of the spec. ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end) return loadedELF{}, syserror.ENOEXEC } @@ -624,15 +626,15 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, in return loadParsedELF(ctx, m, f, info, 0) } -// loadELF loads f into the Task address space. +// loadELF loads args.File into the Task address space. // // If loadELF returns ErrSwitchFile it should be called again with the returned // path and argv. // // Preconditions: -// * f is an ELF file -func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) { - bin, ac, err := loadInitialELF(ctx, m, fs, f) +// * args.File is an ELF file +func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { + bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { ctx.Infof("Error loading binary: %v", err) return loadedELF{}, nil, err @@ -640,7 +642,14 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace var interp loadedELF if bin.interpreter != "" { - d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter) + // Even if we do not allow the final link of the script to be + // resolved, the interpreter should still be resolved if it is + // a symlink. + args.ResolveFinal = true + // Refresh the traversal limit. + *args.RemainingTraversals = linux.MaxSymlinkTraversals + args.Filename = bin.interpreter + d, i, err := openPath(ctx, args) if err != nil { ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err @@ -649,7 +658,7 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace // We don't need the Dirent. d.DecRef() - interp, err = loadInterpreterELF(ctx, m, i, bin) + interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin) if err != nil { ctx.Infof("Error loading interpreter: %v", err) return loadedELF{}, nil, err diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 089d1635b..b03eeb005 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package loader loads a binary into a MemoryManager. +// Package loader loads an executable file into a MemoryManager. package loader import ( @@ -20,6 +20,7 @@ import ( "fmt" "io" "path" + "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" @@ -35,6 +36,54 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LoadArgs holds specifications for an executable file to be loaded. +type LoadArgs struct { + // MemoryManager is the memory manager to load the executable into. + MemoryManager *mm.MemoryManager + + // Mounts is the mount namespace in which to look up Filename. + Mounts *fs.MountNamespace + + // Root is the root directory under which to look up Filename. + Root *fs.Dirent + + // WorkingDirectory is the working directory under which to look up + // Filename. + WorkingDirectory *fs.Dirent + + // RemainingTraversals is the maximum number of symlinks to follow to + // resolve Filename. This counter is passed by reference to keep it + // updated throughout the call stack. + RemainingTraversals *uint + + // ResolveFinal indicates whether the final link of Filename should be + // resolved, if it is a symlink. + ResolveFinal bool + + // Filename is the path for the executable. + Filename string + + // File is an open fs.File object of the executable. If File is not + // nil, then File will be loaded and Filename will be ignored. + File *fs.File + + // CloseOnExec indicates that the executable (or one of its parent + // directories) was opened with O_CLOEXEC. If the executable is an + // interpreter script, then cause an ENOENT error to occur, since the + // script would otherwise be inaccessible to the interpreter. + CloseOnExec bool + + // Argv is the vector of arguments to pass to the executable. + Argv []string + + // Envv is the vector of environment variables to pass to the + // executable. + Envv []string + + // Features specifies the CPU feature set for the executable. + Features *cpuid.FeatureSet +} + // readFull behaves like io.ReadFull for an *fs.File. func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { var total int64 @@ -51,80 +100,82 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in return total, nil } -// openPath opens name for loading. +// openPath opens args.Filename and checks that it is valid for loading. // -// openPath returns the fs.Dirent and an *fs.File for name, which is not +// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not // installed in the Task FDTable. The caller takes ownership of both. // -// name must be a readable, executable, regular file. -func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) { - if name == "" { +// args.Filename must be a readable, executable, regular file. +func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) { + if args.Filename == "" { ctx.Infof("cannot open empty name") return nil, nil, syserror.ENOENT } - d, err := mm.FindInode(ctx, root, wd, name, maxTraversals) + var d *fs.Dirent + var err error + if args.ResolveFinal { + d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } else { + d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } if err != nil { return nil, nil, err } - - // Open file will take a reference to Dirent, so destroy this one. + // Defer a DecRef for the sake of failure cases. defer d.DecRef() - return openFile(ctx, nil, d, name) -} + if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) { + return nil, nil, syserror.ELOOP + } -// openFile performs checks on a file to be executed. If provided a *fs.File, -// openFile takes that file's Dirent and performs checks on it. If provided a -// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's -// Inode and performs checks on that. -// -// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership -// of both. -// -// "dirent" and "file" must not both be nil and point to a readable, executable, regular file. -func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) { - // file and dirent must not be nil. - if dirent == nil && file == nil { - ctx.Infof("dirent and file cannot both be nil.") - return nil, nil, syserror.ENOENT + if err := checkPermission(ctx, d); err != nil { + return nil, nil, err } - if file != nil { - dirent = file.Dirent + // If they claim it's a directory, then make sure. + // + // N.B. we reject directories below, but we must first reject + // non-directories passed as directories. + if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) { + return nil, nil, syserror.ENOTDIR } - // Perform permissions checks on the file. - if err := checkFile(ctx, dirent, name); err != nil { + if err := checkIsRegularFile(ctx, d, args.Filename); err != nil { return nil, nil, err } - if file == nil { - var ferr error - if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil { - return nil, nil, ferr - } - } else { - // GetFile takes a reference to the created file, so make one in the case - // that the file reference already existed. - file.IncRef() + f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) + if err != nil { + return nil, nil, err + } + // Defer a DecRef for the sake of failure cases. + defer f.DecRef() + + if err := checkPread(ctx, f, args.Filename); err != nil { + return nil, nil, err + } + + d.IncRef() + f.IncRef() + return d, f, err +} + +// checkFile performs checks on a file to be executed. +func checkFile(ctx context.Context, f *fs.File, filename string) error { + if err := checkPermission(ctx, f.Dirent); err != nil { + return err } - // We must be able to read at arbitrary offsets. - if !file.Flags().Pread { - file.DecRef() - ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags()) - return nil, nil, syserror.EACCES + if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil { + return err } - // Grab reference for caller. - dirent.IncRef() - return dirent, file, nil + return checkPread(ctx, f, filename) } -// checkFile performs file permissions checks for binaries called in openPath -// and openFile -func checkFile(ctx context.Context, d *fs.Dirent, name string) error { +// checkPermission checks whether the file is readable and executable. +func checkPermission(ctx context.Context, d *fs.Dirent) error { perms := fs.PermMask{ // TODO(gvisor.dev/issue/160): Linux requires only execute // permission, not read. However, our backing filesystems may @@ -135,26 +186,26 @@ func checkFile(ctx context.Context, d *fs.Dirent, name string) error { Read: true, Execute: true, } - if err := d.Inode.CheckPermission(ctx, perms); err != nil { - return err - } + return d.Inode.CheckPermission(ctx, perms) +} - // If they claim it's a directory, then make sure. - // - // N.B. we reject directories below, but we must first reject - // non-directories passed as directories. - if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR +// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc. +func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error { + attr := d.Inode.StableAttr + if !fs.IsRegular(attr) { + ctx.Infof("%s is not regular: %v", filename, attr) + return syserror.EACCES } + return nil +} - // No exec-ing directories, pipes, etc! - if !fs.IsRegular(d.Inode.StableAttr) { - ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr) +// checkPread checks whether we can read the file at arbitrary offsets. +func checkPread(ctx context.Context, f *fs.File, filename string) error { + if !f.Flags().Pread { + ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags()) return syserror.EACCES } - return nil - } // allocStack allocates and maps a stack in to any available part of the address space. @@ -173,45 +224,49 @@ const ( maxLoaderAttempts = 6 ) -// loadBinary loads a binary that is pointed to by "file". If nil, the path -// "filename" is resolved and loaded. +// loadExecutable loads an executable that is pointed to by args.File. If nil, +// the path args.Filename is resolved and loaded. If the executable is an +// interpreter script rather than an ELF, the binary of the corresponding +// interpreter will be loaded. // // It returns: // * loadedELF, description of the loaded binary // * arch.Context matching the binary arch // * fs.Dirent of the binary file -// * Possibly updated argv -func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) { +// * Possibly updated args.Argv +func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) { for i := 0; i < maxLoaderAttempts; i++ { var ( d *fs.Dirent - f *fs.File err error ) - if passedFile == nil { - d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename) - + if args.File == nil { + d, args.File, err = openPath(ctx, args) + // We will return d in the successful case, but defer a DecRef for the + // sake of intermediate loops and failure cases. + if d != nil { + defer d.DecRef() + } + if args.File != nil { + defer args.File.DecRef() + } } else { - d, f, err = openFile(ctx, passedFile, nil, "") - // Set to nil in case we loop on a Interpreter Script. - passedFile = nil + d = args.File.Dirent + d.IncRef() + defer d.DecRef() + err = checkFile(ctx, args.File, args.Filename) } - if err != nil { - ctx.Infof("Error opening %s: %v", filename, err) + ctx.Infof("Error opening %s: %v", args.Filename, err) return loadedELF{}, nil, nil, nil, err } - defer f.DecRef() - // We will return d in the successful case, but defer a DecRef - // for intermediate loops and failure cases. - defer d.DecRef() // Check the header. Is this an ELF or interpreter script? var hdr [4]uint8 // N.B. We assume that reading from a regular file cannot block. - _, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0) - // Allow unexpected EOF, as a valid executable could be only three - // bytes (e.g., #!a). + _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0) + // Allow unexpected EOF, as a valid executable could be only three bytes + // (e.g., #!a). if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { err = syserror.ENOEXEC @@ -221,33 +276,38 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp switch { case bytes.Equal(hdr[:], []byte(elfMagic)): - loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f) + loaded, ac, err := loadELF(ctx, args) if err != nil { ctx.Infof("Error loading ELF: %v", err) return loadedELF{}, nil, nil, nil, err } // An ELF is always terminal. Hold on to d. d.IncRef() - return loaded, ac, d, argv, err + return loaded, ac, d, args.Argv, err case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)): - newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv) + if args.CloseOnExec { + return loadedELF{}, nil, nil, nil, syserror.ENOENT + } + args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv) if err != nil { ctx.Infof("Error loading interpreter script: %v", err) return loadedELF{}, nil, nil, nil, err } - filename = newpath - argv = newargv + // Refresh the traversal limit for the interpreter. + *args.RemainingTraversals = linux.MaxSymlinkTraversals default: ctx.Infof("Unknown magic: %v", hdr) return loadedELF{}, nil, nil, nil, syserror.ENOEXEC } + // Set to nil in case we loop on a Interpreter Script. + args.File = nil } return loadedELF{}, nil, nil, nil, syserror.ELOOP } -// Load loads "file" into a MemoryManager. If file is nil, the path "filename" -// is resolved and loaded instead. +// Load loads args.File into a MemoryManager. If args.File is nil, the path +// args.Filename is resolved and loaded instead. // // If Load returns ErrSwitchFile it should be called again with the returned // path and argv. @@ -255,37 +315,37 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp // Preconditions: // * The Task MemoryManager is empty. // * Load is called on the Task goroutine. -func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { - // Load the binary itself. - loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv) +func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { + // Load the executable itself. + loaded, ac, d, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer d.DecRef() // Load the VDSO. - vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded) + vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the - // binary. Userspace can assume that the remainer of the page after + // executable. Userspace can assume that the remainer of the page after // loaded.end is available for its use. e, ok := loaded.end.RoundUp() if !ok { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC) } - m.BrkSetup(ctx, e) + args.MemoryManager.BrkSetup(ctx, e) // Allocate our stack. - stack, err := allocStack(ctx, m, ac) + stack, err := allocStack(ctx, args.MemoryManager, ac) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux()) } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(filename) + execfn, err := stack.Push(args.Filename) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } @@ -319,11 +379,12 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r }...) auxv = append(auxv, extraAuxv...) - sl, err := stack.Load(argv, envv, auxv) + sl, err := stack.Load(newArgv, args.Envv, auxv) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux()) } + m := args.MemoryManager m.SetArgvStart(sl.ArgvStart) m.SetArgvEnd(sl.ArgvEnd) m.SetEnvvStart(sl.EnvvStart) @@ -334,7 +395,7 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) - name := path.Base(filename) + name := path.Base(args.Filename) if len(name) > linux.TASK_COMM_LEN-1 { name = name[:linux.TASK_COMM_LEN-1] } diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index ada28aea3..df8a81907 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -268,6 +268,8 @@ func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, er // some applications may not be able to handle multiple [vdso] // hints. vdso: mm.NewSpecialMappable("", mfp, vdso), + os: info.os, + arch: info.arch, phdrs: info.phdrs, }, nil } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 3ef84245b..112794e9c 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -41,7 +41,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", - "//pkg/refs", "//pkg/sentry/context", "//pkg/sentry/platform", "//pkg/sentry/usermem", diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 03b99aaea..16a722a13 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -18,7 +18,6 @@ package memmap import ( "fmt" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" @@ -235,8 +234,11 @@ type InvalidateOpts struct { // coincidental; fs.File implements MappingIdentity, and some // fs.InodeOperations implement Mappable.) type MappingIdentity interface { - // MappingIdentity is reference-counted. - refs.RefCounter + // IncRef increments the MappingIdentity's reference count. + IncRef() + + // DecRef decrements the MappingIdentity's reference count. + DecRef() // MappedName returns the application-visible name shown in // /proc/[pid]/maps. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index a804b8b5c..839931f67 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -118,9 +118,9 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/syncutil", "//pkg/syserror", "//pkg/tcpip/buffer", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index f350e0109..58a5c186d 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -44,7 +44,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // MemoryManager implements a virtual address space. @@ -82,7 +82,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu gvsync.DowngradableRWMutex `state:"nosave"` + mappingMu syncutil.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +125,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. - activeMu gvsync.DowngradableRWMutex `state:"nosave"` + activeMu syncutil.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index 1effc7735..aafce1d00 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -16,6 +16,7 @@ package pgalloc import ( "bytes" + "context" "fmt" "io" "runtime" @@ -29,7 +30,7 @@ import ( ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // Save metadata. - if err := state.Save(w, &f.fileSize, nil); err != nil { + if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { return err } - if err := state.Save(w, &f.usage, nil); err != nil { + if err := state.Save(ctx, w, &f.usage, nil); err != nil { return err } @@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { // Load metadata. - if err := state.Load(r, &f.fileSize, nil); err != nil { + if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(r, &f.usage, nil); err != nil { + if err := state.Load(ctx, r, &f.usage, nil); err != nil { return err } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 31fa48ec5..f3afd98da 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -12,19 +12,30 @@ go_library( "bluepill_amd64.go", "bluepill_amd64.s", "bluepill_amd64_unsafe.go", + "bluepill_arm64.go", + "bluepill_arm64.s", + "bluepill_arm64_unsafe.go", "bluepill_fault.go", "bluepill_unsafe.go", "context.go", - "filters.go", + "filters_amd64.go", + "filters_arm64.go", "kvm.go", "kvm_amd64.go", "kvm_amd64_unsafe.go", + "kvm_arm64.go", + "kvm_arm64_unsafe.go", "kvm_const.go", + "kvm_const_arm64.go", "machine.go", "machine_amd64.go", "machine_amd64_unsafe.go", + "machine_arm64.go", + "machine_arm64_unsafe.go", "machine_unsafe.go", "physical_map.go", + "physical_map_amd64.go", + "physical_map_arm64.go", "virtual_map.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index acd41f73d..ea8b9632e 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // not have physical mappings, the KVM module may inject // spurious exceptions when emulation fails (i.e. it tries to // emulate because the RIP is pointed at those pages). - as.machine.mapPhysical(physical, length) + as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) // Install the page table mappings. Note that the ordering is // important; if the pagetable mappings were installed before diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go index 80942e9c9..3f35414bb 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/allocator.go @@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { // //go:nosplit func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { - virtualStart, physicalStart, _, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) } diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 043de51b3..30dbb74d6 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -20,6 +20,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" ) @@ -36,6 +37,18 @@ func sighandler() func dieTrampoline() var ( + // bounceSignal is the signal used for bouncing KVM. + // + // We use SIGCHLD because it is not masked by the runtime, and + // it will be ignored properly by other parts of the kernel. + bounceSignal = syscall.SIGCHLD + + // bounceSignalMask has only bounceSignal set. + bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) + + // bounce is the interrupt vector used to return to the kernel. + bounce = uint32(ring0.VirtualizationException) + // savedHandler is a pointer to the previous handler. // // This is called by bluepillHandler. @@ -45,6 +58,13 @@ var ( dieTrampolineAddr uintptr ) +// redpill invokes a syscall with -1. +// +//go:nosplit +func redpill() { + syscall.RawSyscall(^uintptr(0), 0, 0, 0) +} + // dieHandler is called by dieTrampoline. // //go:nosplit @@ -73,8 +93,8 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { - panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err)) + if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } // Extract the address for the trampoline. diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 421c88220..133c2203d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -24,26 +24,10 @@ import ( ) var ( - // bounceSignal is the signal used for bouncing KVM. - // - // We use SIGCHLD because it is not masked by the runtime, and - // it will be ignored properly by other parts of the kernel. - bounceSignal = syscall.SIGCHLD - - // bounceSignalMask has only bounceSignal set. - bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) - - // bounce is the interrupt vector used to return to the kernel. - bounce = uint32(ring0.VirtualizationException) + // The action for bluepillSignal is changed by sigaction(). + bluepillSignal = syscall.SIGSEGV ) -// redpill on amd64 invokes a syscall with -1. -// -//go:nosplit -func redpill() { - syscall.RawSyscall(^uintptr(0), 0, 0, 0) -} - // bluepillArchEnter is called during bluepillEnter. // //go:nosplit diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 9d8af143e..a63a6a071 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -23,13 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) -// bluepillArchContext returns the arch-specific context. -// -//go:nosplit -func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { - return &((*arch.UContext64)(context).MContext) -} - // dieArchSetup initializes the state for dieTrampoline. // // The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go new file mode 100644 index 000000000..552341721 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -0,0 +1,79 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" +) + +var ( + // The action for bluepillSignal is changed by sigaction(). + bluepillSignal = syscall.SIGILL +) + +// bluepillArchEnter is called during bluepillEnter. +// +//go:nosplit +func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { + c = vCPUPtr(uintptr(context.Regs[8])) + regs := c.CPU.Registers() + regs.Regs = context.Regs + regs.Sp = context.Sp + regs.Pc = context.Pc + regs.Pstate = context.Pstate + regs.Pstate &^= uint64(ring0.KernelFlagsClear) + regs.Pstate |= ring0.KernelFlagsSet + return +} + +// bluepillArchExit is called during bluepillEnter. +// +//go:nosplit +func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { + regs := c.CPU.Registers() + context.Regs = regs.Regs + context.Sp = regs.Sp + context.Pc = regs.Pc + context.Pstate = regs.Pstate + context.Pstate &^= uint64(ring0.UserFlagsClear) + context.Pstate |= ring0.UserFlagsSet +} + +// KernelSyscall handles kernel syscalls. +// +//go:nosplit +func (c *vCPU) KernelSyscall() { + regs := c.Registers() + if regs.Regs[8] != ^uint64(0) { + regs.Pc -= 4 // Rewind. + } + ring0.Halt() +} + +// KernelException handles kernel exceptions. +// +//go:nosplit +func (c *vCPU) KernelException(vector ring0.Vector) { + regs := c.Registers() + if vector == ring0.Vector(bounce) { + regs.Pc = 0 + } + ring0.Halt() +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s new file mode 100644 index 000000000..c61700892 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -0,0 +1,87 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// VCPU_CPU is the location of the CPU in the vCPU struct. +// +// This is guaranteed to be zero. +#define VCPU_CPU 0x0 + +// CPU_SELF is the self reference in ring0's percpu. +// +// This is guaranteed to be zero. +#define CPU_SELF 0x0 + +// Context offsets. +// +// Only limited use of the context is done in the assembly stub below, most is +// done in the Go handlers. +#define SIGINFO_SIGNO 0x0 +#define CONTEXT_PC 0x1B8 +#define CONTEXT_R0 0xB8 + +// See bluepill.go. +TEXT ·bluepill(SB),NOSPLIT,$0 +begin: + MOVD vcpu+0(FP), R8 + MOVD $VCPU_CPU(R8), R9 + ORR $0xffff000000000000, R9, R9 + // Trigger sigill. + // In ring0.Start(), the value of R8 will be stored into tpidr_el1. + // When the context was loaded into vcpu successfully, + // we will check if the value of R10 and R9 are the same. + WORD $0xd538d08a // MRS TPIDR_EL1, R10 +check_vcpu: + CMP R10, R9 + BEQ right_vCPU +wrong_vcpu: + CALL ·redpill(SB) + B begin +right_vCPU: + RET + +// sighandler: see bluepill.go for documentation. +// +// The arguments are the following: +// +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +// +TEXT ·sighandler(SB),NOSPLIT,$0 + // si_signo should be sigill. + MOVD SIGINFO_SIGNO(R1), R7 + CMPW $4, R7 + BNE fallback + + MOVD CONTEXT_PC(R2), R7 + CMPW $0, R7 + BEQ fallback + + MOVD R2, 8(RSP) + BL ·bluepillHandler(SB) // Call the handler. + + RET + +fallback: + // Jump to the previous signal handler. + MOVD ·savedHandler(SB), R7 + B (R7) + +// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. +TEXT ·dieTrampoline(SB),NOSPLIT,$0 + // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. + MOVD R9, 8(RSP) + BL ·dieHandler(SB) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go new file mode 100644 index 000000000..e5fac0d6a --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" +) + +//go:nosplit +func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { + // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. +} diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index b97476053..f6459cda9 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -46,9 +46,9 @@ func yield() { // calculateBluepillFault calculates the fault address range. // //go:nosplit -func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { +func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { alignedPhysical := physical &^ uintptr(usermem.PageSize-1) - for _, pr := range physicalRegions { + for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { continue @@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng // The corresponding virtual address is returned. This may throw on error. // //go:nosplit -func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { +func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) { // Paging fault: we need to map the underlying physical pages for this // fault. This all has to be done in this function because we're in a // signal handler context. (We can't call any functions that might // split the stack.) - virtualStart, physicalStart, length, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { return 0, false } @@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { yield() // Race with another call. slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) } - errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { // Successfully added region; we can increment nextSlot and // allow another set to proceed here. diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 7e8e9f42a..9add7c944 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -23,6 +23,8 @@ import ( "sync/atomic" "syscall" "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) //go:linkname throw runtime.throw @@ -49,6 +51,13 @@ func uintptrValue(addr *byte) uintptr { return (uintptr)(unsafe.Pointer(addr)) } +// bluepillArchContext returns the UContext64. +// +//go:nosplit +func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { + return &((*arch.UContext64)(context).MContext) +} + // bluepillHandler is called from the signal stub. // // The world may be stopped while this is executing, and it executes on the @@ -80,13 +89,17 @@ func bluepillHandler(context unsafe.Pointer) { // interrupted KVM. Since we're in a signal handler // currently, all signals are masked and the signal // must have been delivered directly to this thread. + timeout := syscall.Timespec{} sig, _, errno := syscall.RawSyscall6( syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), - 0, // siginfo. - 0, // timeout. - 8, // sigset size. + 0, // siginfo. + uintptr(unsafe.Pointer(&timeout)), // timeout. + 8, // sigset size. 0, 0) + if errno == syscall.EAGAIN { + continue + } if errno != 0 { throw("error waiting for pending signal") } @@ -162,7 +175,7 @@ func bluepillHandler(context unsafe.Pointer) { // For MMIO, the physical address is the first data item. physical := uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical) + virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) if !ok { c.die(bluepillArchContext(context), "invalid physical address") return diff --git a/pkg/sentry/platform/kvm/filters.go b/pkg/sentry/platform/kvm/filters_amd64.go index 7d949f1dd..7d949f1dd 100644 --- a/pkg/sentry/platform/kvm/filters.go +++ b/pkg/sentry/platform/kvm/filters_amd64.go diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go new file mode 100644 index 000000000..9245d07c2 --- /dev/null +++ b/pkg/sentry/platform/kvm/filters_arm64.go @@ -0,0 +1,32 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +// SyscallFilters returns syscalls made exclusively by the KVM platform. +func (*KVM) SyscallFilters() seccomp.SyscallRules { + return seccomp.SyscallRules{ + syscall.SYS_IOCTL: {}, + syscall.SYS_MMAP: {}, + syscall.SYS_RT_SIGSUSPEND: {}, + syscall.SYS_RT_SIGTIMEDWAIT: {}, + 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host. + } +} diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ee4cd2f4d..f2c2c059e 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -21,7 +21,6 @@ import ( "sync" "syscall" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" @@ -56,9 +55,7 @@ func New(deviceFile *os.File) (*KVM, error) { // Ensure global initialization is done. globalOnce.Do(func() { - physicalInit() - globalErr = updateSystemValues(int(fd)) - ring0.Init(cpuid.HostFeatureSet()) + globalErr = updateGlobalOnce(int(fd)) }) if globalErr != nil { return nil, globalErr diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go index 5d8ef4761..c5a6f9c7d 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64.go +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -17,6 +17,7 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) @@ -211,3 +212,11 @@ type cpuidEntries struct { _ uint32 entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry } + +// updateGlobalOnce does global initialization. It has to be called only once. +func updateGlobalOnce(fd int) error { + physicalInit() + err := updateSystemValues(int(fd)) + ring0.Init(cpuid.HostFeatureSet()) + return err +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go new file mode 100644 index 000000000..2319c86d3 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -0,0 +1,83 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "syscall" +) + +// userMemoryRegion is a region of physical memory. +// +// This mirrors kvm_memory_region. +type userMemoryRegion struct { + slot uint32 + flags uint32 + guestPhysAddr uint64 + memorySize uint64 + userspaceAddr uint64 +} + +type kvmOneReg struct { + id uint64 + addr uint64 +} + +const KVM_NR_SPSR = 5 + +type userFpsimdState struct { + vregs [64]uint64 + fpsr uint32 + fpcr uint32 + reserved [2]uint32 +} + +type userRegs struct { + Regs syscall.PtraceRegs + sp_el1 uint64 + elr_el1 uint64 + spsr [KVM_NR_SPSR]uint64 + fpRegs userFpsimdState +} + +// runData is the run structure. This may be mapped for synchronous register +// access (although that doesn't appear to be supported by my kernel at least). +// +// This mirrors kvm_run. +type runData struct { + requestInterruptWindow uint8 + _ [7]uint8 + + exitReason uint32 + readyForInterruptInjection uint8 + ifFlag uint8 + _ [2]uint8 + + cr8 uint64 + apicBase uint64 + + // This is the union data for exits. Interpretation depends entirely on + // the exitReason above (see vCPU code for more information). + data [32]uint64 +} + +// updateGlobalOnce does global initialization. It has to be called only once. +func updateGlobalOnce(fd int) error { + physicalInit() + err := updateSystemValues(int(fd)) + updateVectorTable() + return err +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go new file mode 100644 index 000000000..6531bae1d --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go @@ -0,0 +1,39 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "fmt" + "syscall" +) + +var ( + runDataSize int +) + +func updateSystemValues(fd int) error { + // Extract the mmap size. + sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0) + if errno != 0 { + return fmt.Errorf("getting VCPU mmap size: %v", errno) + } + // Save the data. + runDataSize = int(sz) + + // Success. + return nil +} diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index d05f05c29..1d5c77ff4 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -49,11 +49,13 @@ const ( _KVM_EXIT_SHUTDOWN = 0x8 _KVM_EXIT_FAIL_ENTRY = 0x9 _KVM_EXIT_INTERNAL_ERROR = 0x11 + _KVM_EXIT_SYSTEM_EVENT = 0x18 ) // KVM capability options. const ( - _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 ) // KVM limits. @@ -62,3 +64,10 @@ const ( _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 ) + +// KVM kvm_memory_region::flags. +const ( + _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0 + _KVM_MEM_READONLY = uint32(1) << 1 + _KVM_MEM_FLAGS_NONE = 0 +) diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go new file mode 100644 index 000000000..5a74c6e36 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -0,0 +1,132 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// KVM ioctls for Arm64. +const ( + _KVM_GET_ONE_REG = 0x4010aeab + _KVM_SET_ONE_REG = 0x4010aeac + + _KVM_ARM_PREFERRED_TARGET = 0x8020aeaf + _KVM_ARM_VCPU_INIT = 0x4020aeae + _KVM_ARM64_REGS_PSTATE = 0x6030000000100042 + _KVM_ARM64_REGS_SP_EL1 = 0x6030000000100044 + _KVM_ARM64_REGS_R0 = 0x6030000000100000 + _KVM_ARM64_REGS_R1 = 0x6030000000100002 + _KVM_ARM64_REGS_R2 = 0x6030000000100004 + _KVM_ARM64_REGS_R3 = 0x6030000000100006 + _KVM_ARM64_REGS_R8 = 0x6030000000100010 + _KVM_ARM64_REGS_R18 = 0x6030000000100024 + _KVM_ARM64_REGS_PC = 0x6030000000100040 + _KVM_ARM64_REGS_MAIR_EL1 = 0x603000000013c510 + _KVM_ARM64_REGS_TCR_EL1 = 0x603000000013c102 + _KVM_ARM64_REGS_TTBR0_EL1 = 0x603000000013c100 + _KVM_ARM64_REGS_TTBR1_EL1 = 0x603000000013c101 + _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 + _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 + _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 +) + +// Arm64: Architectural Feature Access Control Register EL1. +const ( + _FPEN_NOTRAP = 0x3 + _FPEN_SHIFT = 0x20 +) + +// Arm64: System Control Register EL1. +const ( + _SCTLR_M = 1 << 0 + _SCTLR_C = 1 << 2 + _SCTLR_I = 1 << 12 +) + +// Arm64: Translation Control Register EL1. +const ( + _TCR_IPS_40BITS = 2 << 32 // PA=40 + _TCR_IPS_48BITS = 5 << 32 // PA=48 + + _TCR_T0SZ_OFFSET = 0 + _TCR_T1SZ_OFFSET = 16 + _TCR_IRGN0_SHIFT = 8 + _TCR_IRGN1_SHIFT = 24 + _TCR_ORGN0_SHIFT = 10 + _TCR_ORGN1_SHIFT = 26 + _TCR_SH0_SHIFT = 12 + _TCR_SH1_SHIFT = 28 + _TCR_TG0_SHIFT = 14 + _TCR_TG1_SHIFT = 30 + + _TCR_T0SZ_VA48 = 64 - 48 // VA=48 + _TCR_T1SZ_VA48 = 64 - 48 // VA=48 + + _TCR_ASID16 = 1 << 36 + _TCR_TBI0 = 1 << 37 + + _TCR_TXSZ_VA48 = (_TCR_T0SZ_VA48 << _TCR_T0SZ_OFFSET) | (_TCR_T1SZ_VA48 << _TCR_T1SZ_OFFSET) + + _TCR_TG0_4K = 0 << _TCR_TG0_SHIFT // 4K + _TCR_TG0_64K = 1 << _TCR_TG0_SHIFT // 64K + + _TCR_TG1_4K = 2 << _TCR_TG1_SHIFT + + _TCR_TG_FLAGS = _TCR_TG0_4K | _TCR_TG1_4K + + _TCR_IRGN0_WBWA = 1 << _TCR_IRGN0_SHIFT + _TCR_IRGN1_WBWA = 1 << _TCR_IRGN1_SHIFT + _TCR_IRGN_WBWA = _TCR_IRGN0_WBWA | _TCR_IRGN1_WBWA + + _TCR_ORGN0_WBWA = 1 << _TCR_ORGN0_SHIFT + _TCR_ORGN1_WBWA = 1 << _TCR_ORGN1_SHIFT + + _TCR_ORGN_WBWA = _TCR_ORGN0_WBWA | _TCR_ORGN1_WBWA + + _TCR_SHARED = (3 << _TCR_SH0_SHIFT) | (3 << _TCR_SH1_SHIFT) + + _TCR_CACHE_FLAGS = _TCR_IRGN_WBWA | _TCR_ORGN_WBWA +) + +// Arm64: Memory Attribute Indirection Register EL1. +const ( + _MT_DEVICE_nGnRnE = 0 + _MT_DEVICE_nGnRE = 1 + _MT_DEVICE_GRE = 2 + _MT_NORMAL_NC = 3 + _MT_NORMAL = 4 + _MT_NORMAL_WT = 5 + _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8) +) + +const ( + _KVM_ARM_VCPU_POWER_OFF = 0 // CPU is started in OFF state + _KVM_ARM_VCPU_PSCI_0_2 = 2 // CPU uses PSCI v0.2 +) + +// Arm64: Exception Syndrome Register EL1. +const ( + _ESR_ELx_FSC = 0x3F + + _ESR_SEGV_MAPERR_L0 = 0x4 + _ESR_SEGV_MAPERR_L1 = 0x5 + _ESR_SEGV_MAPERR_L2 = 0x6 + _ESR_SEGV_MAPERR_L3 = 0x7 + + _ESR_SEGV_ACCERR_L1 = 0x9 + _ESR_SEGV_ACCERR_L2 = 0xa + _ESR_SEGV_ACCERR_L3 = 0xb + + _ESR_SEGV_PEMERR_L1 = 0xd + _ESR_SEGV_PEMERR_L2 = 0xe + _ESR_SEGV_PEMERR_L3 = 0xf +) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index cc6c138b2..7d02ebf19 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) + var physicalRegionsReadOnly []physicalRegion + var physicalRegionsAvailable []physicalRegion + + physicalRegionsReadOnly = rdonlyRegionsForSetMem() + physicalRegionsAvailable = availableRegionsForSetMem() + + // Map all read-only regions. + for _, r := range physicalRegionsReadOnly { + m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) + } + // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to @@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) { if excludeVirtualRegion(vr) { return // skip region. } + + for _, r := range physicalRegionsReadOnly { + if vr.virtual == r.virtual { + return + } + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length) + m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) virtual += length } }) @@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. -func (m *machine) mapPhysical(physical, length uintptr) { +func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { - _, physicalStart, length, ok := calculateBluepillFault(physical) + _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { // Should never happen. panic("mapPhysical on unknown physical address") @@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) { if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { // Not present in the cache; requires setting the slot. - if _, ok := handleBluepillFault(m, physical); !ok { + if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c1cbe33be..b99fe425e 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) { } } } + +// On x86 platform, the flags for "setMemoryRegion" can always be set as 0. +// There is no need to return read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + return nil +} + +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 506ec9af1..7156c245f 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -26,30 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - // loadSegments copies the current segments. // // This may be called from within the signal context and throws on error. @@ -159,3 +135,43 @@ func (c *vCPU) setSignalMask() error { } return nil } + +// setUserRegisters sets user registers in the vCPU. +func (c *vCPU) setUserRegisters(uregs *userRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) + } + return nil +} + +// getUserRegisters reloads user registers in the vCPU. +// +// This is safe to call from a nosplit context. +// +//go:nosplit +func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return errno + } + return 0 +} + +// setSystemRegisters sets system registers. +func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SREGS, + uintptr(unsafe.Pointer(sregs))); errno != 0 { + return fmt.Errorf("error setting system registers: %v", errno) + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go new file mode 100644 index 000000000..7ae47f291 --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -0,0 +1,183 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +type vCPUArchState struct { + // PCIDs is the set of PCIDs for this vCPU. + // + // This starts above fixedKernelPCID. + PCIDs *pagetables.PCIDs +} + +const ( + // fixedKernelPCID is a fixed kernel PCID used for the kernel page + // tables. We must start allocating user PCIDs above this in order to + // avoid any conflict (see below). + fixedKernelPCID = 1 + + // poolPCIDs is the number of PCIDs to record in the database. As this + // grows, assignment can take longer, since it is a simple linear scan. + // Beyond a relatively small number, there are likely few perform + // benefits, since the TLB has likely long since lost any translations + // from more than a few PCIDs past. + poolPCIDs = 8 +) + +// Get all read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + var rdonlyRegions []region + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return + } + + if !vr.accessType.Write && vr.accessType.Read { + rdonlyRegions = append(rdonlyRegions, vr.region) + } + }) + + for _, r := range rdonlyRegions { + physical, _, ok := translateToPhysical(r.virtual) + if !ok { + continue + } + + phyRegions = append(phyRegions, physicalRegion{ + region: region{ + virtual: r.virtual, + length: r.length, + }, + physical: physical, + }) + } + + return phyRegions +} + +// Get all available physicalRegions. +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + var excludeRegions []region + applyVirtualRegions(func(vr virtualRegion) { + if !vr.accessType.Write { + excludeRegions = append(excludeRegions, vr.region) + } + }) + + phyRegions = computePhysicalRegions(excludeRegions) + + return phyRegions +} + +// dropPageTables drops cached page table entries. +func (m *machine) dropPageTables(pt *pagetables.PageTables) { + m.mu.Lock() + defer m.mu.Unlock() + + // Clear from all PCIDs. + for _, c := range m.vCPUs { + c.PCIDs.Drop(pt) + } +} + +// nonCanonical generates a canonical address return. +// +//go:nosplit +func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + *info = arch.SignalInfo{ + Signo: signal, + Code: arch.SignalInfoKernel, + } + info.SetAddr(addr) // Include address. + return usermem.NoAccess, platform.ErrContextSignal +} + +// fault generates an appropriate fault return. +// +//go:nosplit +func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + faultAddr := c.GetFaultAddr() + code, user := c.ErrorCode() + + // Reset the pointed SignalInfo. + *info = arch.SignalInfo{Signo: signal} + info.SetAddr(uint64(faultAddr)) + + read := true + write := false + execute := true + + ret := code & _ESR_ELx_FSC + switch ret { + case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3: + info.Code = 1 //SEGV_MAPERR + read = false + write = true + execute = false + case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3: + info.Code = 2 // SEGV_ACCERR. + read = true + write = false + execute = false + default: + info.Code = 2 + } + + if !user { + read = true + write = false + execute = true + + } + accessType := usermem.AccessType{ + Read: read, + Write: write, + Execute: execute, + } + + return accessType, platform.ErrContextSignal +} + +// retryInGuest runs the given function in guest mode. +// +// If the function does not complete in guest mode (due to execution of a +// system call due to a GC stall, for example), then it will be retried. The +// given function must be idempotent as a result of the retry mechanism. +func (m *machine) retryInGuest(fn func()) { + c := m.Get() + defer m.Put(c) + for { + c.ClearErrorCode() // See below. + bluepill(c) // Force guest mode. + fn() // Execute the given function. + _, user := c.ErrorCode() + if user { + // If user is set, then we haven't bailed back to host + // mode via a kernel exception or system call. We + // consider the full function to have executed in guest + // mode and we can return. + break + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go new file mode 100644 index 000000000..3f2f97a6b --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "fmt" + "reflect" + "sync/atomic" + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: 0, + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + +type kvmVcpuInit struct { + target uint32 + features [7]uint32 +} + +var vcpuInit kvmVcpuInit + +// initArchState initializes architecture-specific state. +func (m *machine) initArchState() error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_ARM_PREFERRED_TARGET, + uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { + panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno)) + } + return nil +} + +func getPageWithReflect(p uintptr) []byte { + return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()] +} + +// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. +// +// According to the design documentation of Arm64, +// the start address of exception vector table should be 11-bits aligned. +// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S +// But, we can't align a function's start address to a specific address by using golang. +// We have raised this question in golang community: +// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I +// This function will be removed when golang supports this feature. +// +// There are 2 jobs were implemented in this function: +// 1, move the start address of exception vector table into the specific address. +// 2, modify the offset of each instruction. +func updateVectorTable() { + fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() + offset := fromLocation & (1<<11 - 1) + if offset != 0 { + offset = 1<<11 - offset + } + + toLocation := fromLocation + offset + page := getPageWithReflect(toLocation) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil { + panic(err) + } + + page = getPageWithReflect(toLocation + 4096) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil { + panic(err) + } + + // Move exception-vector-table into the specific address. + var entry *uint32 + var entryFrom *uint32 + for i := 1; i <= 0x800; i++ { + entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i))) + entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i))) + *entry = *entryFrom + } + + // The offset from the address of each unconditionally branch is changed. + // We should modify the offset of each instruction. + nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780} + for _, num := range nums { + entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num))) + *entry = *entry - (uint32)(offset/4) + } + + page = getPageWithReflect(toLocation) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil { + panic(err) + } + + page = getPageWithReflect(toLocation + 4096) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil { + panic(err) + } +} + +// initArchState initializes architecture-specific state. +func (c *vCPU) initArchState() error { + var ( + reg kvmOneReg + data uint64 + regGet kvmOneReg + dataGet uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer()) + + vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_ARM_VCPU_INIT, + uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { + panic(fmt.Sprintf("error setting KVM_ARM_VCPU_INIT failed: %v", errno)) + } + + // cpacr_el1 + reg.id = _KVM_ARM64_REGS_CPACR_EL1 + data = (_FPEN_NOTRAP << _FPEN_SHIFT) + if err := c.setOneRegister(®); err != nil { + return err + } + + // sctlr_el1 + regGet.id = _KVM_ARM64_REGS_SCTLR_EL1 + if err := c.getOneRegister(®Get); err != nil { + return err + } + + dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I) + data = dataGet + reg.id = _KVM_ARM64_REGS_SCTLR_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // tcr_el1 + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS + reg.id = _KVM_ARM64_REGS_TCR_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // mair_el1 + data = _MT_EL1_INIT + reg.id = _KVM_ARM64_REGS_MAIR_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // ttbr0_el1 + data = c.machine.kernel.PageTables.TTBR0_EL1(false, 0) + + reg.id = _KVM_ARM64_REGS_TTBR0_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + c.SetTtbr0Kvm(uintptr(data)) + + // ttbr1_el1 + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) + + reg.id = _KVM_ARM64_REGS_TTBR1_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // sp_el1 + data = c.CPU.StackTop() + reg.id = _KVM_ARM64_REGS_SP_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // pc + reg.id = _KVM_ARM64_REGS_PC + data = uint64(reflect.ValueOf(ring0.Start).Pointer()) + if err := c.setOneRegister(®); err != nil { + return err + } + + // r8 + reg.id = _KVM_ARM64_REGS_R8 + data = uint64(reflect.ValueOf(&c.CPU).Pointer()) + if err := c.setOneRegister(®); err != nil { + return err + } + + // vbar_el1 + reg.id = _KVM_ARM64_REGS_VBAR_EL1 + + fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() + offset := fromLocation & (1<<11 - 1) + if offset != 0 { + offset = 1<<11 - offset + } + + toLocation := fromLocation + offset + data = uint64(ring0.KernelStartAddress | toLocation) + if err := c.setOneRegister(®); err != nil { + return err + } + + data = ring0.PsrDefaultSet | ring0.KernelFlagsSet + reg.id = _KVM_ARM64_REGS_PSTATE + if err := c.setOneRegister(®); err != nil { + return err + } + + return nil +} + +//go:nosplit +func (c *vCPU) loadSegments(tid uint64) { + // TODO(gvisor.dev/issue/1238): TLS is not supported. + // Get TLS from tpidr_el0. + atomic.StoreUint64(&c.tid, tid) +} + +func (c *vCPU) setOneRegister(reg *kvmOneReg) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_ONE_REG, + uintptr(unsafe.Pointer(reg))); errno != 0 { + return fmt.Errorf("error setting one register: %v", errno) + } + return nil +} + +func (c *vCPU) getOneRegister(reg *kvmOneReg) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_ONE_REG, + uintptr(unsafe.Pointer(reg))); errno != 0 { + return fmt.Errorf("error setting one register: %v", errno) + } + return nil +} + +// setCPUID sets the CPUID to be used by the guest. +func (c *vCPU) setCPUID() error { + return nil +} + +// setSystemTime sets the TSC for the vCPU. +func (c *vCPU) setSystemTime() error { + return nil +} + +// setSignalMask sets the vCPU signal mask. +// +// This must be called prior to running the vCPU. +func (c *vCPU) setSignalMask() error { + // The layout of this structure implies that it will not necessarily be + // the same layout chosen by the Go compiler. It gets fudged here. + var data struct { + length uint32 + mask1 uint32 + mask2 uint32 + _ uint32 + } + data.length = 8 // Fixed sigset size. + data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) + data.mask2 = ^uint32(bounceSignalMask >> 32) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SIGNAL_MASK, + uintptr(unsafe.Pointer(&data))); errno != 0 { + return fmt.Errorf("error setting signal mask: %v", errno) + } + + return nil +} + +// SwitchToUser unpacks architectural-details. +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) { + // Check for canonical addresses. + if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { + return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info) + } else if !ring0.IsCanonical(regs.Sp) { + return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) + } + + var vector ring0.Vector + ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0) + c.SetTtbr0App(uintptr(ttbr0App)) + + // TODO(gvisor.dev/issue/1238): full context-switch supporting for Arm64. + // The Arm64 user-mode execution state consists of: + // x0-x30 + // PC, SP, PSTATE + // V0-V31: 32 128-bit registers for floating point, and simd + // FPSR + // TPIDR_EL0, used for TLS + appRegs := switchOpts.Registers + c.SetAppAddr(ring0.KernelStartAddress | uintptr(unsafe.Pointer(appRegs))) + + entersyscall() + bluepill(c) + vector = c.CPU.SwitchToUser(switchOpts) + exitsyscall() + + switch vector { + case ring0.Syscall: + // Fast path: system call executed. + return usermem.NoAccess, nil + + case ring0.PageFault: + return c.fault(int32(syscall.SIGSEGV), info) + case 0xaa: + return usermem.NoAccess, nil + default: + return usermem.NoAccess, platform.ErrContextSignal + } + +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 405e00292..f04be2ab5 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -35,6 +35,30 @@ func entersyscall() //go:linkname exitsyscall runtime.exitsyscall func exitsyscall() +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: uint32(flags), + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + // mapRunData maps the vCPU run data. func mapRunData(fd int) (*runData, error) { r, _, errno := syscall.RawSyscall6( @@ -63,46 +87,6 @@ func unmapRunData(r *runData) error { return nil } -// setUserRegisters sets user registers in the vCPU. -func (c *vCPU) setUserRegisters(uregs *userRegs) error { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_REGS, - uintptr(unsafe.Pointer(uregs))); errno != 0 { - return fmt.Errorf("error setting user registers: %v", errno) - } - return nil -} - -// getUserRegisters reloads user registers in the vCPU. -// -// This is safe to call from a nosplit context. -// -//go:nosplit -func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_GET_REGS, - uintptr(unsafe.Pointer(uregs))); errno != 0 { - return errno - } - return 0 -} - -// setSystemRegisters sets system registers. -func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_SREGS, - uintptr(unsafe.Pointer(sregs))); errno != 0 { - return fmt.Errorf("error setting system registers: %v", errno) - } - return nil -} - // atomicAddressSpace is an atomic address space pointer. type atomicAddressSpace struct { pointer unsafe.Pointer diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index 586e91bb2..91de5dab1 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -24,15 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) -const ( - // reservedMemory is a chunk of physical memory reserved starting at - // physical address zero. There are some special pages in this region, - // so we just call the whole thing off. - // - // Other architectures may define this to be zero. - reservedMemory = 0x100000000 -) - type region struct { virtual uintptr length uintptr @@ -59,8 +50,7 @@ func fillAddressSpace() (excludedRegions []region) { // We can cut vSize in half, because the kernel will be using the top // half and we ignore it while constructing mappings. It's as if we've // already excluded half the possible addresses. - vSize := uintptr(1) << ring0.VirtualAddressBits() - vSize = vSize >> 1 + vSize := ring0.UserspaceSize // We exclude reservedMemory below from our physical memory size, so it // needs to be dropped here as well. Otherwise, we could end up with diff --git a/pkg/sentry/platform/kvm/physical_map_amd64.go b/pkg/sentry/platform/kvm/physical_map_amd64.go new file mode 100644 index 000000000..c5adfb577 --- /dev/null +++ b/pkg/sentry/platform/kvm/physical_map_amd64.go @@ -0,0 +1,22 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +const ( + // reservedMemory is a chunk of physical memory reserved starting at + // physical address zero. There are some special pages in this region, + // so we just call the whole thing off. + reservedMemory = 0x100000000 +) diff --git a/pkg/sentry/platform/kvm/physical_map_arm64.go b/pkg/sentry/platform/kvm/physical_map_arm64.go new file mode 100644 index 000000000..4d8561453 --- /dev/null +++ b/pkg/sentry/platform/kvm/physical_map_arm64.go @@ -0,0 +1,19 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +const ( + reservedMemory = 0 +) diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 2cd28b2d2..0bebee852 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -50,6 +50,21 @@ TEXT ·SpinLoop(SB),NOSPLIT,$0 start: B start +TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 + NO_LOCAL_POINTERS + FMOVD $(9.9), F0 + MOVD $SYS_GETPID, R8 // getpid + SVC + FMOVD $(9.9), F1 + FCMPD F0, F1 + BNE isNaN + MOVD $1, R0 + MOVD R0, ret+0(FP) + RET +isNaN: + MOVD $0, ret+0(FP) + RET + // MVN: bitwise logical NOT // This case simulates an application that modified R0-R30. #define TWIDDLE_REGS() \ diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index ebcc8c098..0df8cfa0f 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/procid", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/hostcpu", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/safecopy", diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 9f0ecfbe4..ddb1f41e3 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -327,6 +327,20 @@ func (t *thread) dumpAndPanic(message string) { panic(message) } +func (t *thread) unexpectedStubExit() { + msg, err := t.getEventMessage() + status := syscall.WaitStatus(msg) + if status.Signaled() && status.Signal() == syscall.SIGKILL { + // SIGKILL can be only sent by an user or OOM-killer. In both + // these cases, we don't need to panic. There is no reasons to + // think that something wrong in gVisor. + log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid) + pid := os.Getpid() + syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL)) + } + t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) +} + // wait waits for a stop event. // // Precondition: outcome is a valid waitOutcome. @@ -355,8 +369,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { } if stopSig == syscall.SIGTRAP { if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { - msg, err := t.getEventMessage() - t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) + t.unexpectedStubExit() } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 4649a94a7..606dc2b1d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -21,6 +21,8 @@ import ( "strings" "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" ) @@ -143,3 +145,49 @@ func (t *thread) adjustInitRegsRip() { func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { initregs.R15 = uint64(ppid) } + +// patchSignalInfo patches the signal info to account for hitting the seccomp +// filters from vsyscall emulation, specified below. We allow for SIGSYS as a +// synchronous trap, but patch the structure to appear like a SIGSEGV with the +// Rip as the faulting address. +// +// Note that this should only be called after verifying that the signalInfo has +// been generated by the kernel. +func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { + if linux.Signal(signalInfo.Signo) == linux.SIGSYS { + signalInfo.Signo = int32(linux.SIGSEGV) + + // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered + // with the si_call_addr field pointing to the current RIP. This field + // aligns with the si_addr field for a SIGSEGV, so we don't need to touch + // anything there. We do need to unwind emulation however, so we set the + // instruction pointer to the faulting value, and "unpop" the stack. + regs.Rip = signalInfo.Addr() + regs.Rsp -= 8 + } +} + +// enableCpuidFault enables cpuid-faulting. +// +// This may fail on older kernels or hardware, so we just disregard the result. +// Host CPUID will be enabled. +// +// This is safe to call in an afterFork context. +// +//go:nosplit +func enableCpuidFault() { + syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0) +} + +// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. +// Ref attachedThread() for more detail. +func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { + return append(rules, seccomp.RuleSet{ + Rules: seccomp.SyscallRules{ + syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ + {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }) +} diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index bec884ba5..62a686ee7 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -17,8 +17,12 @@ package ptrace import ( + "fmt" + "strings" "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" ) @@ -37,7 +41,7 @@ const ( // resetSysemuRegs sets up emulation registers. // // This should be called prior to calling sysemu. -func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) { +func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) { } // createSyscallRegs sets up syscall registers. @@ -124,3 +128,36 @@ func (t *thread) adjustInitRegsRip() { func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { initregs.Regs[7] = uint64(ppid) } + +// patchSignalInfo patches the signal info to account for hitting the seccomp +// filters from vsyscall emulation, specified below. We allow for SIGSYS as a +// synchronous trap, but patch the structure to appear like a SIGSEGV with the +// Rip as the faulting address. +// +// Note that this should only be called after verifying that the signalInfo has +// been generated by the kernel. +func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { + if linux.Signal(signalInfo.Signo) == linux.SIGSYS { + signalInfo.Signo = int32(linux.SIGSEGV) + + // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered + // with the si_call_addr field pointing to the current RIP. This field + // aligns with the si_addr field for a SIGSEGV, so we don't need to touch + // anything there. We do need to unwind emulation however, so we set the + // instruction pointer to the faulting value, and "unpop" the stack. + regs.Pc = signalInfo.Addr() + regs.Sp -= 8 + } +} + +// Noop on arm64. +// +//go:nosplit +func enableCpuidFault() { +} + +// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. +// Ref attachedThread() for more detail. +func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { + return rules +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index c075b5f91..cf13ea5e4 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -20,6 +20,7 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" @@ -77,27 +78,6 @@ func probeSeccomp() bool { } } -// patchSignalInfo patches the signal info to account for hitting the seccomp -// filters from vsyscall emulation, specified below. We allow for SIGSYS as a -// synchronous trap, but patch the structure to appear like a SIGSEGV with the -// Rip as the faulting address. -// -// Note that this should only be called after verifying that the signalInfo has -// been generated by the kernel. -func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { - if linux.Signal(signalInfo.Signo) == linux.SIGSYS { - signalInfo.Signo = int32(linux.SIGSEGV) - - // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered - // with the si_call_addr field pointing to the current RIP. This field - // aligns with the si_addr field for a SIGSEGV, so we don't need to touch - // anything there. We do need to unwind emulation however, so we set the - // instruction pointer to the faulting value, and "unpop" the stack. - regs.Rip = signalInfo.Addr() - regs.Rsp -= 8 - } -} - // createStub creates a fresh stub processes. // // Precondition: the runtime OS thread must be locked. @@ -129,6 +109,9 @@ func createStub() (*thread, error) { // transitively) will be killed as well. It's simply not possible to // safely handle a single stub getting killed: the exact state of // execution is unknown and not recoverable. + // + // In addition, we set the PTRACE_O_TRACEEXIT option to log more + // information about a stub process when it receives a fatal signal. return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction) } @@ -146,7 +129,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro Rules: seccomp.SyscallRules{ syscall.SYS_GETTIMEOFDAY: {}, syscall.SYS_TIME: {}, - 309: {}, // SYS_GETCPU. + unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64. }, Action: linux.SECCOMP_RET_TRAP, Vsyscall: true, @@ -170,10 +153,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // For the initial process creation. syscall.SYS_WAIT4: {}, - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, - }, - syscall.SYS_EXIT: {}, + syscall.SYS_EXIT: {}, // For the stub prctl dance (all). syscall.SYS_PRCTL: []seccomp.Rule{ @@ -193,6 +173,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }, Action: linux.SECCOMP_RET_ALLOW, }) + + rules = appendArchSeccompRules(rules) } instrs, err := seccomp.BuildProgram(rules, defaultAction) if err != nil { @@ -264,9 +246,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0) } - // Enable cpuid-faulting; this may fail on older kernels or hardware, - // so we just disregard the result. Host CPUID will be enabled. - syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0) + // Enable cpuid-faulting. + enableCpuidFault() // Call the stub; should not return. stubCall(stubStart, ppid) diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index de6783fb0..2e6fbe488 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -25,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/hostcpu" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, @@ -49,20 +50,6 @@ func unmaskAllSignals() syscall.Errno { return errno } -// getCPU gets the current CPU. -// -// Precondition: the current runtime thread should be locked. -func getCPU() (uint32, error) { - var cpu uintptr - if _, _, errno := syscall.RawSyscall( - unix.SYS_GETCPU, - uintptr(unsafe.Pointer(&cpu)), - 0, 0); errno != 0 { - return 0, errno - } - return uint32(cpu), nil -} - // setCPU sets the CPU affinity. func (t *thread) setCPU(cpu uint32) error { mask := maskPool.Get().([]uintptr) @@ -93,10 +80,8 @@ func (t *thread) setCPU(cpu uint32) error { // // Precondition: the current runtime thread should be locked. func (t *thread) bind() { - currentCPU, err := getCPU() - if err != nil { - return - } + currentCPU := hostcpu.GetCPU() + if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { // Set the affinity on the thread and save the CPU for next // round; we don't expect CPUs to bounce around too frequently. diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index b80a3604d..2ae6b9f9d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 48b0ceaec..87f4552b5 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -4,7 +4,7 @@ load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) go_template( - name = "defs", + name = "defs_amd64", srcs = [ "defs.go", "defs_amd64.go", @@ -14,11 +14,29 @@ go_template( visibility = [":__subpackages__"], ) +go_template( + name = "defs_arm64", + srcs = [ + "aarch64.go", + "defs.go", + "defs_arm64.go", + "offsets_arm64.go", + ], + visibility = [":__subpackages__"], +) + go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", package = "ring0", - template = ":defs", + template = ":defs_amd64", +) + +go_template_instance( + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", + package = "ring0", + template = ":defs_arm64", ) genrule( @@ -29,17 +47,31 @@ genrule( tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) +genrule( + name = "entry_impl_arm64", + srcs = ["entry_arm64.s"], + outs = ["entry_impl_arm64.s"], + cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) + go_library( name = "ring0", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "entry_amd64.go", + "entry_arm64.go", "entry_impl_amd64.s", + "entry_impl_arm64.s", "kernel.go", "kernel_amd64.go", + "kernel_arm64.go", "kernel_unsafe.go", "lib_amd64.go", "lib_amd64.s", + "lib_arm64.go", + "lib_arm64.s", "ring0.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go new file mode 100644 index 000000000..6b078cd1e --- /dev/null +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -0,0 +1,109 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// Useful bits. +const ( + _PGD_PGT_BASE = 0x1000 + _PGD_PGT_SIZE = 0x1000 + _PUD_PGT_BASE = 0x2000 + _PUD_PGT_SIZE = 0x1000 + _PMD_PGT_BASE = 0x3000 + _PMD_PGT_SIZE = 0x4000 + _PTE_PGT_BASE = 0x7000 + _PTE_PGT_SIZE = 0x1000 + + _PSR_MODE_EL0t = 0x0 + _PSR_MODE_EL1t = 0x4 + _PSR_MODE_EL1h = 0x5 + _PSR_EL_MASK = 0xf + + _PSR_D_BIT = 0x200 + _PSR_A_BIT = 0x100 + _PSR_I_BIT = 0x80 + _PSR_F_BIT = 0x40 +) + +const ( + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _PSR_MODE_EL1h + + // UserFlagsSet are always set in userspace. + UserFlagsSet = _PSR_MODE_EL0t + + KernelFlagsClear = _PSR_EL_MASK + UserFlagsClear = _PSR_EL_MASK + + PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT +) + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + El1SyncInvalid = iota + El1IrqInvalid + El1FiqInvalid + El1ErrorInvalid + El1Sync + El1Irq + El1Fiq + El1Error + El0Sync + El0Irq + El0Fiq + El0Error + El0Sync_invalid + El0Irq_invalid + El0Fiq_invalid + El0Error_invalid + El1Sync_da + El1Sync_ia + El1Sync_sp_pc + El1Sync_undef + El1Sync_dbg + El1Sync_inv + El0Sync_svc + El0Sync_da + El0Sync_ia + El0Sync_fpsimd_acc + El0Sync_sve_acc + El0Sync_sys + El0Sync_sp_pc + El0Sync_undef + El0Sync_dbg + El0Sync_inv + VirtualizationException + _NR_INTERRUPTS +) + +// System call vectors. +const ( + Syscall Vector = El0Sync_svc + PageFault Vector = El0Sync_da +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +func VirtualAddressBits() uint32 { + return 48 +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +func PhysicalAddressBits() uint32 { + return 40 +} diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 076063f85..3f094c2a7 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -20,17 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - // Kernel is a global kernel object. // // This contains global state, shared by multiple CPUs. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 7206322b1..10dbd381f 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -20,6 +20,17 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + // Segment indices and Selectors. const ( // Index into GDT array. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go new file mode 100644 index 000000000..dc0eeec01 --- /dev/null +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -0,0 +1,136 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits()) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [512]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // faultAddr is the value of far_el1. + faultAddr uintptr + + // ttbr0Kvm is the value of ttbr0_el1 for sentry. + ttbr0Kvm uintptr + + // ttbr0App is the value of ttbr0_el1 for applicaton. + ttbr0App uintptr + + // exception vector. + vecCode Vector + + // application context pointer. + appAddr uintptr + + // lazyVFP is the value of cpacr_el1. + lazyVFP uintptr +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 // No code. + c.errorType = 1 // User mode. +} + +//go:nosplit +func (c *CPU) GetFaultAddr() (value uintptr) { + return c.faultAddr +} + +//go:nosplit +func (c *CPU) SetTtbr0Kvm(value uintptr) { + c.ttbr0Kvm = value +} + +//go:nosplit +func (c *CPU) SetTtbr0App(value uintptr) { + c.ttbr0App = value +} + +//go:nosplit +func (c *CPU) GetVector() (value Vector) { + return c.vecCode +} + +//go:nosplit +func (c *CPU) SetAppAddr(value uintptr) { + c.appAddr = value +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserASID indicates that the application ASID to be used on switch, + UserASID uint16 + + // KernelASID indicates that the kernel ASID to be used on return, + KernelASID uint16 +} + +func init() { +} diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go new file mode 100644 index 000000000..0dfa42c36 --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_arm64.go @@ -0,0 +1,60 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// This is an assembly function. +// +// The sysenter function is invoked in two situations: +// +// (1) The guest kernel has executed a system call. +// (2) The guest application has executed a system call. +// +// The interrupt flag is examined to determine whether the system call was +// executed from kernel mode or not and the appropriate stub is called. + +func El1_sync_invalid() +func El1_irq_invalid() +func El1_fiq_invalid() +func El1_error_invalid() + +func El1_sync() +func El1_irq() +func El1_fiq() +func El1_error() + +func El0_sync() +func El0_irq() +func El0_fiq() +func El0_error() + +func El0_sync_invalid() +func El0_irq_invalid() +func El0_fiq_invalid() +func El0_error_invalid() + +func Vectors() + +// Start is the CPU entrypoint. +// +// The CPU state will be set to c.Registers(). +func Start() +func kernelExitToEl1() + +func kernelExitToEl0() + +// Shutdown execution +func Shutdown() diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s new file mode 100644 index 000000000..add2c3e08 --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -0,0 +1,591 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +// NB: Offsets are programatically generated (see BUILD). +// +// This file is concatenated with the definitions. + +// Saves a register set. +// +// This is a macro because it may need to executed in contents where a stack is +// not available for calls. +// + +#define ERET() \ + WORD $0xd69f03e0 + +#define RSV_REG R18_PLATFORM +#define RSV_REG_APP R9 + +#define FPEN_NOTRAP 0x3 +#define FPEN_SHIFT 20 + +#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) + +#define REGISTERS_SAVE(reg, offset) \ + MOVD R0, offset+PTRACE_R0(reg); \ + MOVD R1, offset+PTRACE_R1(reg); \ + MOVD R2, offset+PTRACE_R2(reg); \ + MOVD R3, offset+PTRACE_R3(reg); \ + MOVD R4, offset+PTRACE_R4(reg); \ + MOVD R5, offset+PTRACE_R5(reg); \ + MOVD R6, offset+PTRACE_R6(reg); \ + MOVD R7, offset+PTRACE_R7(reg); \ + MOVD R8, offset+PTRACE_R8(reg); \ + MOVD R10, offset+PTRACE_R10(reg); \ + MOVD R11, offset+PTRACE_R11(reg); \ + MOVD R12, offset+PTRACE_R12(reg); \ + MOVD R13, offset+PTRACE_R13(reg); \ + MOVD R14, offset+PTRACE_R14(reg); \ + MOVD R15, offset+PTRACE_R15(reg); \ + MOVD R16, offset+PTRACE_R16(reg); \ + MOVD R17, offset+PTRACE_R17(reg); \ + MOVD R19, offset+PTRACE_R19(reg); \ + MOVD R20, offset+PTRACE_R20(reg); \ + MOVD R21, offset+PTRACE_R21(reg); \ + MOVD R22, offset+PTRACE_R22(reg); \ + MOVD R23, offset+PTRACE_R23(reg); \ + MOVD R24, offset+PTRACE_R24(reg); \ + MOVD R25, offset+PTRACE_R25(reg); \ + MOVD R26, offset+PTRACE_R26(reg); \ + MOVD R27, offset+PTRACE_R27(reg); \ + MOVD g, offset+PTRACE_R28(reg); \ + MOVD R29, offset+PTRACE_R29(reg); \ + MOVD R30, offset+PTRACE_R30(reg); + +#define REGISTERS_LOAD(reg, offset) \ + MOVD offset+PTRACE_R0(reg), R0; \ + MOVD offset+PTRACE_R1(reg), R1; \ + MOVD offset+PTRACE_R2(reg), R2; \ + MOVD offset+PTRACE_R3(reg), R3; \ + MOVD offset+PTRACE_R4(reg), R4; \ + MOVD offset+PTRACE_R5(reg), R5; \ + MOVD offset+PTRACE_R6(reg), R6; \ + MOVD offset+PTRACE_R7(reg), R7; \ + MOVD offset+PTRACE_R8(reg), R8; \ + MOVD offset+PTRACE_R10(reg), R10; \ + MOVD offset+PTRACE_R11(reg), R11; \ + MOVD offset+PTRACE_R12(reg), R12; \ + MOVD offset+PTRACE_R13(reg), R13; \ + MOVD offset+PTRACE_R14(reg), R14; \ + MOVD offset+PTRACE_R15(reg), R15; \ + MOVD offset+PTRACE_R16(reg), R16; \ + MOVD offset+PTRACE_R17(reg), R17; \ + MOVD offset+PTRACE_R19(reg), R19; \ + MOVD offset+PTRACE_R20(reg), R20; \ + MOVD offset+PTRACE_R21(reg), R21; \ + MOVD offset+PTRACE_R22(reg), R22; \ + MOVD offset+PTRACE_R23(reg), R23; \ + MOVD offset+PTRACE_R24(reg), R24; \ + MOVD offset+PTRACE_R25(reg), R25; \ + MOVD offset+PTRACE_R26(reg), R26; \ + MOVD offset+PTRACE_R27(reg), R27; \ + MOVD offset+PTRACE_R28(reg), g; \ + MOVD offset+PTRACE_R29(reg), R29; \ + MOVD offset+PTRACE_R30(reg), R30; + +//NOP +#define nop31Instructions() \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; + +#define ESR_ELx_EC_UNKNOWN (0x00) +#define ESR_ELx_EC_WFx (0x01) +/* Unallocated EC: 0x02 */ +#define ESR_ELx_EC_CP15_32 (0x03) +#define ESR_ELx_EC_CP15_64 (0x04) +#define ESR_ELx_EC_CP14_MR (0x05) +#define ESR_ELx_EC_CP14_LS (0x06) +#define ESR_ELx_EC_FP_ASIMD (0x07) +#define ESR_ELx_EC_CP10_ID (0x08) /* EL2 only */ +#define ESR_ELx_EC_PAC (0x09) /* EL2 and above */ +/* Unallocated EC: 0x0A - 0x0B */ +#define ESR_ELx_EC_CP14_64 (0x0C) +/* Unallocated EC: 0x0d */ +#define ESR_ELx_EC_ILL (0x0E) +/* Unallocated EC: 0x0F - 0x10 */ +#define ESR_ELx_EC_SVC32 (0x11) +#define ESR_ELx_EC_HVC32 (0x12) /* EL2 only */ +#define ESR_ELx_EC_SMC32 (0x13) /* EL2 and above */ +/* Unallocated EC: 0x14 */ +#define ESR_ELx_EC_SVC64 (0x15) +#define ESR_ELx_EC_HVC64 (0x16) /* EL2 and above */ +#define ESR_ELx_EC_SMC64 (0x17) /* EL2 and above */ +#define ESR_ELx_EC_SYS64 (0x18) +#define ESR_ELx_EC_SVE (0x19) +/* Unallocated EC: 0x1A - 0x1E */ +#define ESR_ELx_EC_IMP_DEF (0x1f) /* EL3 only */ +#define ESR_ELx_EC_IABT_LOW (0x20) +#define ESR_ELx_EC_IABT_CUR (0x21) +#define ESR_ELx_EC_PC_ALIGN (0x22) +/* Unallocated EC: 0x23 */ +#define ESR_ELx_EC_DABT_LOW (0x24) +#define ESR_ELx_EC_DABT_CUR (0x25) +#define ESR_ELx_EC_SP_ALIGN (0x26) +/* Unallocated EC: 0x27 */ +#define ESR_ELx_EC_FP_EXC32 (0x28) +/* Unallocated EC: 0x29 - 0x2B */ +#define ESR_ELx_EC_FP_EXC64 (0x2C) +/* Unallocated EC: 0x2D - 0x2E */ +#define ESR_ELx_EC_SERROR (0x2F) +#define ESR_ELx_EC_BREAKPT_LOW (0x30) +#define ESR_ELx_EC_BREAKPT_CUR (0x31) +#define ESR_ELx_EC_SOFTSTP_LOW (0x32) +#define ESR_ELx_EC_SOFTSTP_CUR (0x33) +#define ESR_ELx_EC_WATCHPT_LOW (0x34) +#define ESR_ELx_EC_WATCHPT_CUR (0x35) +/* Unallocated EC: 0x36 - 0x37 */ +#define ESR_ELx_EC_BKPT32 (0x38) +/* Unallocated EC: 0x39 */ +#define ESR_ELx_EC_VECTOR32 (0x3A) /* EL2 only */ +/* Unallocted EC: 0x3B */ +#define ESR_ELx_EC_BRK64 (0x3C) +/* Unallocated EC: 0x3D - 0x3F */ +#define ESR_ELx_EC_MAX (0x3F) + +#define ESR_ELx_EC_SHIFT (26) +#define ESR_ELx_EC_MASK (UL(0x3F) << ESR_ELx_EC_SHIFT) +#define ESR_ELx_EC(esr) (((esr) & ESR_ELx_EC_MASK) >> ESR_ELx_EC_SHIFT) + +#define ESR_ELx_IL_SHIFT (25) +#define ESR_ELx_IL (UL(1) << ESR_ELx_IL_SHIFT) +#define ESR_ELx_ISS_MASK (ESR_ELx_IL - 1) + +/* ISS field definitions shared by different classes */ +#define ESR_ELx_WNR_SHIFT (6) +#define ESR_ELx_WNR (UL(1) << ESR_ELx_WNR_SHIFT) + +/* Asynchronous Error Type */ +#define ESR_ELx_IDS_SHIFT (24) +#define ESR_ELx_IDS (UL(1) << ESR_ELx_IDS_SHIFT) +#define ESR_ELx_AET_SHIFT (10) +#define ESR_ELx_AET (UL(0x7) << ESR_ELx_AET_SHIFT) + +#define ESR_ELx_AET_UC (UL(0) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UEU (UL(1) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UEO (UL(2) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UER (UL(3) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_CE (UL(6) << ESR_ELx_AET_SHIFT) + +/* Shared ISS field definitions for Data/Instruction aborts */ +#define ESR_ELx_SET_SHIFT (11) +#define ESR_ELx_SET_MASK (UL(3) << ESR_ELx_SET_SHIFT) +#define ESR_ELx_FnV_SHIFT (10) +#define ESR_ELx_FnV (UL(1) << ESR_ELx_FnV_SHIFT) +#define ESR_ELx_EA_SHIFT (9) +#define ESR_ELx_EA (UL(1) << ESR_ELx_EA_SHIFT) +#define ESR_ELx_S1PTW_SHIFT (7) +#define ESR_ELx_S1PTW (UL(1) << ESR_ELx_S1PTW_SHIFT) + +/* Shared ISS fault status code(IFSC/DFSC) for Data/Instruction aborts */ +#define ESR_ELx_FSC (0x3F) +#define ESR_ELx_FSC_TYPE (0x3C) +#define ESR_ELx_FSC_EXTABT (0x10) +#define ESR_ELx_FSC_SERROR (0x11) +#define ESR_ELx_FSC_ACCESS (0x08) +#define ESR_ELx_FSC_FAULT (0x04) +#define ESR_ELx_FSC_PERM (0x0C) + +/* ISS field definitions for Data Aborts */ +#define ESR_ELx_ISV_SHIFT (24) +#define ESR_ELx_ISV (UL(1) << ESR_ELx_ISV_SHIFT) +#define ESR_ELx_SAS_SHIFT (22) +#define ESR_ELx_SAS (UL(3) << ESR_ELx_SAS_SHIFT) +#define ESR_ELx_SSE_SHIFT (21) +#define ESR_ELx_SSE (UL(1) << ESR_ELx_SSE_SHIFT) +#define ESR_ELx_SRT_SHIFT (16) +#define ESR_ELx_SRT_MASK (UL(0x1F) << ESR_ELx_SRT_SHIFT) +#define ESR_ELx_SF_SHIFT (15) +#define ESR_ELx_SF (UL(1) << ESR_ELx_SF_SHIFT) +#define ESR_ELx_AR_SHIFT (14) +#define ESR_ELx_AR (UL(1) << ESR_ELx_AR_SHIFT) +#define ESR_ELx_CM_SHIFT (8) +#define ESR_ELx_CM (UL(1) << ESR_ELx_CM_SHIFT) + +/* ISS field definitions for exceptions taken in to Hyp */ +#define ESR_ELx_CV (UL(1) << 24) +#define ESR_ELx_COND_SHIFT (20) +#define ESR_ELx_COND_MASK (UL(0xF) << ESR_ELx_COND_SHIFT) +#define ESR_ELx_WFx_ISS_TI (UL(1) << 0) +#define ESR_ELx_WFx_ISS_WFI (UL(0) << 0) +#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) +#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) + +#define LOAD_KERNEL_ADDRESS(from, to) \ + MOVD from, to; \ + ORR $0xffff000000000000, to, to; + +// LOAD_KERNEL_STACK loads the kernel temporary stack. +#define LOAD_KERNEL_STACK(from) \ + LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \ + MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \ + MOVD RSV_REG, RSP; \ + ISB $15; \ + DSB $15; + +#define SWITCH_TO_APP_PAGETABLE(from) \ + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + ISB $15; \ + DSB $15; + +#define SWITCH_TO_KVM_PAGETABLE(from) \ + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + ISB $15; \ + DSB $15; + +#define IRQ_ENABLE \ + MSR $2, DAIFSet; + +#define IRQ_DISABLE \ + MSR $2, DAIFClr; + +#define VFP_ENABLE \ + MOVD $FPEN_ENABLE, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define VFP_DISABLE \ + MOVD $0x0, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define KERNEL_ENTRY_FROM_EL0 \ + SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. + STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable. + SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + MOVD RSV_REG_APP, R20; \ + LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ + ADD $16, RSP, RSP; \ + MOVD RSV_REG, PTRACE_R18(R20); \ + MOVD RSV_REG_APP, PTRACE_R9(R20); \ + MOVD R20, RSV_REG_APP; \ + WORD $0xd5384003; \ // MRS SPSR_EL1, R3 + MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MRS ELR_EL1, R3; \ + MOVD R3, PTRACE_PC(RSV_REG_APP); \ + WORD $0xd5384103; \ // MRS SP_EL0, R3 + MOVD R3, PTRACE_SP(RSV_REG_APP); + +#define KERNEL_ENTRY_FROM_EL1 \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context + MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \ + WORD $0xd5384004; \ // MRS SPSR_EL1, R4 + MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \ + MRS ELR_EL1, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \ + MOVD RSP, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); + +TEXT ·Halt(SB),NOSPLIT,$0 + // clear bluepill. + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + CMP RSV_REG, R9 + BNE mmio_exit + MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG) +mmio_exit: + // Disable fpsimd. + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, CPU_LAZY_VFP(RSV_REG) + VFP_DISABLE + + // MMIO_EXIT. + MOVD $0, R9 + MOVD R0, 0xffff000000001000(R9) + B ·kernelExitToEl1(SB) + +TEXT ·Shutdown(SB),NOSPLIT,$0 + // PSCI EVENT. + MOVD $0x84000009, R0 + HVC $0 + +// See kernel.go. +TEXT ·Current(SB),NOSPLIT,$0-8 + MOVD CPU_SELF(RSV_REG), R8 + MOVD R8, ret+0(FP) + RET + +#define STACK_FRAME_SIZE 16 + +TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 + ERET() + +TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 + ERET() + +TEXT ·Start(SB),NOSPLIT,$0 + IRQ_DISABLE + MOVD R8, RSV_REG + ORR $0xffff000000000000, RSV_REG, RSV_REG + WORD $0xd518d092 //MSR R18, TPIDR_EL1 + + B ·kernelExitToEl1(SB) + +TEXT ·El1_sync_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_irq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_error_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_sync(SB),NOSPLIT,$0 + KERNEL_ENTRY_FROM_EL1 + WORD $0xd5385219 // MRS ESR_EL1, R25 + LSR $ESR_ELx_EC_SHIFT, R25, R24 + CMP $ESR_ELx_EC_DABT_CUR, R24 + BEQ el1_da + CMP $ESR_ELx_EC_IABT_CUR, R24 + BEQ el1_ia + CMP $ESR_ELx_EC_SYS64, R24 + BEQ el1_undef + CMP $ESR_ELx_EC_SP_ALIGN, R24 + BEQ el1_sp_pc + CMP $ESR_ELx_EC_PC_ALIGN, R24 + BEQ el1_sp_pc + CMP $ESR_ELx_EC_UNKNOWN, R24 + BEQ el1_undef + CMP $ESR_ELx_EC_SVC64, R24 + BEQ el1_svc + CMP $ESR_ELx_EC_BREAKPT_CUR, R24 + BGE el1_dbg + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el1_fpsimd_acc + B el1_invalid + +el1_da: + B ·Halt(SB) + +el1_ia: + B ·Halt(SB) + +el1_sp_pc: + B ·Shutdown(SB) + +el1_undef: + B ·Shutdown(SB) + +el1_svc: + B ·Halt(SB) + +el1_dbg: + B ·Shutdown(SB) + +el1_fpsimd_acc: + VFP_ENABLE + B ·kernelExitToEl1(SB) // Resume. + +el1_invalid: + B ·Shutdown(SB) + +TEXT ·El1_irq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_fiq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_error(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_sync(SB),NOSPLIT,$0 + KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + LSR $ESR_ELx_EC_SHIFT, R25, R24 + CMP $ESR_ELx_EC_SVC64, R24 + BEQ el0_svc + CMP $ESR_ELx_EC_DABT_LOW, R24 + BEQ el0_da + CMP $ESR_ELx_EC_IABT_LOW, R24 + BEQ el0_ia + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el0_fpsimd_acc + CMP $ESR_ELx_EC_SVE, R24 + BEQ el0_sve_acc + CMP $ESR_ELx_EC_FP_EXC64, R24 + BEQ el0_fpsimd_exc + CMP $ESR_ELx_EC_SP_ALIGN, R24 + BEQ el0_sp_pc + CMP $ESR_ELx_EC_PC_ALIGN, R24 + BEQ el0_sp_pc + CMP $ESR_ELx_EC_UNKNOWN, R24 + BEQ el0_undef + CMP $ESR_ELx_EC_BREAKPT_LOW, R24 + BGE el0_dbg + B el0_invalid + +el0_svc: + B ·Halt(SB) + +el0_da: + B ·Halt(SB) + +el0_ia: + B ·Shutdown(SB) + +el0_fpsimd_acc: + B ·Shutdown(SB) + +el0_sve_acc: + B ·Shutdown(SB) + +el0_fpsimd_exc: + B ·Shutdown(SB) + +el0_sp_pc: + B ·Shutdown(SB) + +el0_undef: + B ·Shutdown(SB) + +el0_dbg: + B ·Shutdown(SB) + +el0_invalid: + B ·Shutdown(SB) + +TEXT ·El0_irq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_fiq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_error(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_irq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_error_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·Vectors(SB),NOSPLIT,$0 + B ·El1_sync_invalid(SB) + nop31Instructions() + B ·El1_irq_invalid(SB) + nop31Instructions() + B ·El1_fiq_invalid(SB) + nop31Instructions() + B ·El1_error_invalid(SB) + nop31Instructions() + + B ·El1_sync(SB) + nop31Instructions() + B ·El1_irq(SB) + nop31Instructions() + B ·El1_fiq(SB) + nop31Instructions() + B ·El1_error(SB) + nop31Instructions() + + B ·El0_sync(SB) + nop31Instructions() + B ·El0_irq(SB) + nop31Instructions() + B ·El0_fiq(SB) + nop31Instructions() + B ·El0_error(SB) + nop31Instructions() + + B ·El0_sync_invalid(SB) + nop31Instructions() + B ·El0_irq_invalid(SB) + nop31Instructions() + B ·El0_fiq_invalid(SB) + nop31Instructions() + B ·El0_error_invalid(SB) + nop31Instructions() + + WORD $0xd503201f //nop + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 780bf9a66..42076fb04 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -4,16 +4,24 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs", + template = "//pkg/sentry/platform/ring0:defs_arm64", +) + +go_template_instance( + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", + package = "main", + template = "//pkg/sentry/platform/ring0:defs_amd64", ) go_binary( name = "gen_offsets", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "main.go", ], visibility = ["//pkg/sentry/platform/ring0:__pkg__"], diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go new file mode 100644 index 000000000..ed82a131e --- /dev/null +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -0,0 +1,58 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// init initializes architecture-specific state. +func (k *Kernel) init(opts KernelOpts) { + // Save the root page tables. + k.PageTables = opts.PageTables +} + +// init initializes architecture-specific state. +func (c *CPU) init() { + // Set the kernel stack pointer(virtual address). + c.registers.Sp = uint64(c.StackTop()) + +} + +// StackTop returns the kernel's stack address. +// +//go:nosplit +func (c *CPU) StackTop() uint64 { + return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack)) +} + +// IsCanonical indicates whether addr is canonical per the arm64 spec. +// +//go:nosplit +func IsCanonical(addr uint64) bool { + return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 +} + +//go:nosplit +func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { + // Sanitize registers. + regs := switchOpts.Registers + + regs.Pstate &= ^uint64(UserFlagsClear) + regs.Pstate |= UserFlagsSet + kernelExitToEl0() + vector = c.vecCode + + // Perform the switch. + return +} diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go new file mode 100644 index 000000000..8bcfe1032 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -0,0 +1,39 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// CPACREL1 returns the value of the CPACR_EL1 register. +func CPACREL1() (value uintptr) + +// FPCR returns the value of FPCR register. +func FPCR() (value uintptr) + +// SetFPCR writes the FPCR value. +func SetFPCR(value uintptr) + +// FPSR returns the value of FPSR register. +func FPSR() (value uintptr) + +// SetFPSR writes the FPSR value. +func SetFPSR(value uintptr) + +// SaveVRegs saves V0-V31 registers. +// V0-V31: 32 128-bit registers for floating point and simd. +func SaveVRegs(*byte) + +// LoadVRegs loads V0-V31 registers. +func LoadVRegs(*byte) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s new file mode 100644 index 000000000..1c9171004 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -0,0 +1,118 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +TEXT ·CPACREL1(SB),NOSPLIT,$0-8 + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·FPCR(SB),NOSPLIT,$0-8 + WORD $0xd53b4201 // MRS NZCV, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·GetFPSR(SB),NOSPLIT,$0-8 + WORD $0xd53b4421 // MRS FPSR, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·FPCR(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + WORD $0xd51b4201 // MSR R1, NZCV + RET + +TEXT ·SetFPSR(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + WORD $0xd51b4421 // MSR R1, FPSR + RET + +TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD F0, 16*1(R0) + FMOVD F1, 16*2(R0) + FMOVD F2, 16*3(R0) + FMOVD F3, 16*4(R0) + FMOVD F4, 16*5(R0) + FMOVD F5, 16*6(R0) + FMOVD F6, 16*7(R0) + FMOVD F7, 16*8(R0) + FMOVD F8, 16*9(R0) + FMOVD F9, 16*10(R0) + FMOVD F10, 16*11(R0) + FMOVD F11, 16*12(R0) + FMOVD F12, 16*13(R0) + FMOVD F13, 16*14(R0) + FMOVD F14, 16*15(R0) + FMOVD F15, 16*16(R0) + FMOVD F16, 16*17(R0) + FMOVD F17, 16*18(R0) + FMOVD F18, 16*19(R0) + FMOVD F19, 16*20(R0) + FMOVD F20, 16*21(R0) + FMOVD F21, 16*22(R0) + FMOVD F22, 16*23(R0) + FMOVD F23, 16*24(R0) + FMOVD F24, 16*25(R0) + FMOVD F25, 16*26(R0) + FMOVD F26, 16*27(R0) + FMOVD F27, 16*28(R0) + FMOVD F28, 16*29(R0) + FMOVD F29, 16*30(R0) + FMOVD F30, 16*31(R0) + FMOVD F31, 16*32(R0) + ISB $15 + + RET + +TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD 16*1(R0), F0 + FMOVD 16*2(R0), F1 + FMOVD 16*3(R0), F2 + FMOVD 16*4(R0), F3 + FMOVD 16*5(R0), F4 + FMOVD 16*6(R0), F5 + FMOVD 16*7(R0), F6 + FMOVD 16*8(R0), F7 + FMOVD 16*9(R0), F8 + FMOVD 16*10(R0), F9 + FMOVD 16*11(R0), F10 + FMOVD 16*12(R0), F11 + FMOVD 16*13(R0), F12 + FMOVD 16*14(R0), F13 + FMOVD 16*15(R0), F14 + FMOVD 16*16(R0), F15 + FMOVD 16*17(R0), F16 + FMOVD 16*18(R0), F17 + FMOVD 16*19(R0), F18 + FMOVD 16*20(R0), F19 + FMOVD 16*21(R0), F20 + FMOVD 16*22(R0), F21 + FMOVD 16*23(R0), F22 + FMOVD 16*24(R0), F23 + FMOVD 16*25(R0), F24 + FMOVD 16*26(R0), F25 + FMOVD 16*27(R0), F26 + FMOVD 16*28(R0), F27 + FMOVD 16*29(R0), F28 + FMOVD 16*30(R0), F29 + FMOVD 16*31(R0), F30 + FMOVD 16*32(R0), F31 + ISB $15 + + RET diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go new file mode 100644 index 000000000..cd2a65f97 --- /dev/null +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -0,0 +1,125 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "fmt" + "io" + "reflect" + "syscall" +) + +// Emit prints architecture-specific offsets. +func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_FAULT_ADDR 0x%02x\n", reflect.ValueOf(&c.faultAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_TTBR0_KVM 0x%02x\n", reflect.ValueOf(&c.ttbr0Kvm).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) + fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) + fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) + fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) + + fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) + fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) + fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) + fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + + fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) + fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) + fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) + fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + + fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) + fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) + fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) + fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + + fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) + fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) + fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) + fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) + fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) + fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + + fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) + fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) + fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) + fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) + fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) + fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) + fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) + fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) + fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg) + fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &syscall.PtraceRegs{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R2 0x%02x\n", reflect.ValueOf(&p.Regs[2]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R3 0x%02x\n", reflect.ValueOf(&p.Regs[3]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R4 0x%02x\n", reflect.ValueOf(&p.Regs[4]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R5 0x%02x\n", reflect.ValueOf(&p.Regs[5]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R6 0x%02x\n", reflect.ValueOf(&p.Regs[6]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R7 0x%02x\n", reflect.ValueOf(&p.Regs[7]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.Regs[8]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.Regs[9]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.Regs[10]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.Regs[11]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.Regs[12]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.Regs[13]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.Regs[14]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.Regs[15]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R16 0x%02x\n", reflect.ValueOf(&p.Regs[16]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R17 0x%02x\n", reflect.ValueOf(&p.Regs[17]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R18 0x%02x\n", reflect.ValueOf(&p.Regs[18]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R19 0x%02x\n", reflect.ValueOf(&p.Regs[19]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R20 0x%02x\n", reflect.ValueOf(&p.Regs[20]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R21 0x%02x\n", reflect.ValueOf(&p.Regs[21]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R22 0x%02x\n", reflect.ValueOf(&p.Regs[22]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R23 0x%02x\n", reflect.ValueOf(&p.Regs[23]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R24 0x%02x\n", reflect.ValueOf(&p.Regs[24]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R25 0x%02x\n", reflect.ValueOf(&p.Regs[25]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R26 0x%02x\n", reflect.ValueOf(&p.Regs[26]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R27 0x%02x\n", reflect.ValueOf(&p.Regs[27]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R28 0x%02x\n", reflect.ValueOf(&p.Regs[28]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R29 0x%02x\n", reflect.ValueOf(&p.Regs[29]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R30 0x%02x\n", reflect.ValueOf(&p.Regs[30]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer()) +} diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 934a90378..e2e15ba5c 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,14 +1,17 @@ -load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +config_setting( + name = "aarch64", + constraint_values = ["@bazel_tools//platforms:aarch64"], +) + go_template( name = "generic_walker", - srcs = [ - "walker_amd64.go", - ], + srcs = ["walker_amd64.go"], opt_types = [ "Visitor", ], @@ -76,9 +79,13 @@ go_library( "allocator.go", "allocator_unsafe.go", "pagetables.go", + "pagetables_aarch64.go", "pagetables_amd64.go", + "pagetables_arm64.go", "pagetables_x86.go", "pcids_x86.go", + "walker_amd64.go", + "walker_arm64.go", "walker_empty.go", "walker_lookup.go", "walker_map.go", @@ -97,6 +104,7 @@ go_test( size = "small", srcs = [ "pagetables_amd64_test.go", + "pagetables_arm64_test.go", "pagetables_test.go", "walker_check.go", ], diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 904f1a6de..30c64a372 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -48,15 +48,6 @@ func New(a Allocator) *PageTables { return p } -// Init initializes a set of PageTables. -// -//go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) -} - // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go new file mode 100644 index 000000000..e78424766 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -0,0 +1,212 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// archPageTables is architecture-specific data. +type archPageTables struct { + // root is the pagetable root for kernel space. + root *PTEs + + // rootPhysical is the cached physical address of the root. + // + // This is saved only to prevent constant translation. + rootPhysical uintptr + + asid uint16 +} + +// TTBR0_EL1 returns the translation table base register 0. +// +//go:nosplit +func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { + return uint64(p.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset +} + +// TTBR1_EL1 returns the translation table base register 1. +// +//go:nosplit +func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { + return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset +} + +// Bits in page table entries. +const ( + typeTable = 0x3 << 0 + typeSect = 0x1 << 0 + typePage = 0x3 << 0 + pteValid = 0x1 << 0 + pteTableBit = 0x1 << 1 + pteTypeMask = 0x3 << 0 + present = pteValid | pteTableBit + user = 0x1 << 6 /* AP[1] */ + readOnly = 0x1 << 7 /* AP[2] */ + accessed = 0x1 << 10 + dbm = 0x1 << 51 + writable = dbm + cont = 0x1 << 52 + pxn = 0x1 << 53 + xn = 0x1 << 54 + dirty = 0x1 << 55 + nG = 0x1 << 11 + shared = 0x3 << 8 +) + +const ( + mtNormal = 0x4 << 2 +) + +const ( + executeDisable = xn + optionMask = 0xfff | 0xfff<<48 + protDefault = accessed | shared | mtNormal +) + +// MapOpts are x86 options. +type MapOpts struct { + // AccessType defines permissions. + AccessType usermem.AccessType + + // Global indicates the page is globally accessible. + Global bool + + // User indicates the page is a user page. + User bool +} + +// PTE is a page table entry. +type PTE uintptr + +// Clear clears this PTE, including sect page information. +// +//go:nosplit +func (p *PTE) Clear() { + atomic.StoreUintptr((*uintptr)(p), 0) +} + +// Valid returns true iff this entry is valid. +// +//go:nosplit +func (p *PTE) Valid() bool { + return atomic.LoadUintptr((*uintptr)(p))&present != 0 +} + +// Opts returns the PTE options. +// +// These are all options except Valid and Sect. +// +//go:nosplit +func (p *PTE) Opts() MapOpts { + v := atomic.LoadUintptr((*uintptr)(p)) + + return MapOpts{ + AccessType: usermem.AccessType{ + Read: true, + Write: v&readOnly == 0, + Execute: v&xn == 0, + }, + Global: v&nG == 0, + User: v&user != 0, + } +} + +// SetSect sets this page as a sect page. +// +// The page must not be valid or a panic will result. +// +//go:nosplit +func (p *PTE) SetSect() { + if p.Valid() { + // This is not allowed. + panic("SetSect called on valid page!") + } + atomic.StoreUintptr((*uintptr)(p), typeSect) +} + +// IsSect returns true iff this page is a sect page. +// +//go:nosplit +func (p *PTE) IsSect() bool { + return atomic.LoadUintptr((*uintptr)(p))&pteTypeMask == typeSect +} + +// Set sets this PTE value. +// +// This does not change the sect page property. +// +//go:nosplit +func (p *PTE) Set(addr uintptr, opts MapOpts) { + if !opts.AccessType.Any() { + p.Clear() + return + } + v := (addr &^ optionMask) | protDefault | nG | readOnly + + if p.IsSect() { + // Note that this is inherited from the previous instance. Set + // does not change the value of Sect. See above. + v |= typeSect + } else { + v |= typePage + } + + if opts.Global { + v = v &^ nG + } + + if opts.AccessType.Execute { + v = v &^ executeDisable + } else { + v |= executeDisable + } + if opts.AccessType.Write { + v = v &^ readOnly + } + + if opts.User { + v |= user + } else { + v = v &^ user + } + atomic.StoreUintptr((*uintptr)(p), v) +} + +// setPageTable sets this PTE value and forces the write bit and sect bit to +// be cleared. This is used explicitly for breaking sect pages. +// +//go:nosplit +func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) { + addr := pt.Allocator.PhysicalFor(ptes) + if addr&^optionMask != addr { + // This should never happen. + panic("unaligned physical address!") + } + v := addr | typeTable | protDefault + atomic.StoreUintptr((*uintptr)(p), v) +} + +// Address extracts the address. This should only be used if Valid returns true. +// +//go:nosplit +func (p *PTE) Address() uintptr { + return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 7aa6c524e..0c153cf8c 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -41,5 +41,14 @@ const ( entriesPerPage = 512 ) +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) +} + // PTEs is a collection of entries. type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go new file mode 100644 index 000000000..1a49f12a2 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -0,0 +1,57 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +// Address constraints. +// +// The lowerTop and upperBottom currently apply to four-level pagetables; +// additional refactoring would be necessary to support five-level pagetables. +const ( + lowerTop = 0x0000ffffffffffff + upperBottom = 0xffff000000000000 + pteShift = 12 + pmdShift = 21 + pudShift = 30 + pgdShift = 39 + + pteMask = 0x1ff << pteShift + pmdMask = 0x1ff << pmdShift + pudMask = 0x1ff << pudShift + pgdMask = 0x1ff << pgdShift + + pteSize = 1 << pteShift + pmdSize = 1 << pmdShift + pudSize = 1 << pudShift + pgdSize = 1 << pgdShift + + ttbrASIDOffset = 55 + ttbrASIDMask = 0xff + + entriesPerPage = 512 +) + +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) + p.archPageTables.root = p.Allocator.NewPTEs() + p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +} + +// PTEs is a collection of entries. +type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go new file mode 100644 index 000000000..254116233 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go @@ -0,0 +1,80 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func Test2MAnd4K(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a small page and a huge page. + pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47) + + pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42) + pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, + {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}}, + {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}}, + }) +} + +func Test1GAnd4K(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a small page and a super page. + pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, + {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} + +func TestSplit1GPage(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a super page and knock out the middle. + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42) + pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} + +func TestSplit2MPage(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a huge page and knock out the middle. + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42) + pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go new file mode 100644 index 000000000..c261d393a --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +// Visitor is a generic type. +type Visitor interface { + // visit is called on each PTE. + visit(start uintptr, pte *PTE, align uintptr) + + // requiresAlloc indicates that new entries should be allocated within + // the walked range. + requiresAlloc() bool + + // requiresSplit indicates that entries in the given range should be + // split if they are huge or jumbo pages. + requiresSplit() bool +} + +// Walker walks page tables. +type Walker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor Visitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is sect pages. If a valid sect page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of sect pages whenever +// possible. Whether a sect page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *Walker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func next(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *Walker) iterateRangeCanonical(start, end uintptr) { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + start = next(start, pgdSize) + continue + } + + // Allocate a new pgd. + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + // Map the next level. + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + clearPUDEntries++ + start = next(start, pudSize) + continue + } + + // This level has 1-GB sect pages. Is this + // entire region at least as large as a single + // PUD entry? If so, we can skip allocating a + // new page for the pmd. + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = next(start, pudSize) + continue + } + } + + // Allocate a new pud. + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + // Does this page need to be split? + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) { + // Install the relevant entries. + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + // A sect page to be checked directly. + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + // Might have been cleared. + if !pudEntry.Valid() { + clearPUDEntries++ + } + + // Note that the sect page was changed. + start = next(start, pudSize) + continue + } + + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + // Map the next level, since this is valid. + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + clearPMDEntries++ + start = next(start, pmdSize) + continue + } + + // This level has 2-MB huge pages. If this + // region is contined in a single PMD entry? + // As above, we can skip allocating a new page. + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSect() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = next(start, pmdSize) + continue + } + } + + // Allocate a new pmd. + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSect() { + // Does this page need to be split? + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) { + // Install the relevant entries. + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + // A huge page to be checked directly. + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + // Might have been cleared. + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + // Note that the huge page was changed. + start = next(start, pmdSize) + continue + } + + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + // Map the next level, since this is valid. + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + // At this point, we are guaranteed that start%pteSize == 0. + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + // Note that the pte was changed. + start += pteSize + continue + } + + // Check if we no longer need this page. + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + // Check if we no longer need this page. + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + // Check if we no longer need this page. + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go index 2f65db70b..ba1f9043d 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sentry/sighandling/sighandling.go @@ -16,7 +16,6 @@ package sighandling import ( - "fmt" "os" "os/signal" "reflect" @@ -31,37 +30,25 @@ const numSignals = 32 // handleSignals listens for incoming signals and calls the given handler // function. // -// It starts when the start channel is closed, stops when the stop channel -// is closed, and closes done once it will no longer deliver signals to k. -func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) { +// It stops when the stop channel is closed. The done channel is closed once it +// will no longer deliver signals to k. +func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) { // Build a select case. - sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}} + sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}} for _, sigchan := range sigchans { sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)}) } - started := false for { // Wait for a notification. index, _, ok := reflect.Select(sc) - // Was it the start / stop channel? + // Was it the stop channel? if index == 0 { if !ok { - if !started { - // start channel; start forwarding and - // swap this case for the stop channel - // to select stop requests. - started = true - sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)} - } else { - // stop channel; stop forwarding and - // clear this case so it is never - // selected again. - started = false - close(done) - sc[0].Chan = reflect.Value{} - } + // Stop forwarding and notify that it's done. + close(done) + return } continue } @@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, // Otherwise, it was a signal on channel N. Index 0 represents the stop // channel, so index N represents the channel for signal N. - signal := linux.Signal(index) - - if !started { - // Kernel cannot receive signals, either because it is - // not ready yet or is shutting down. - // - // Kill ourselves if this signal would have killed the - // process before PrepareForwarding was called. i.e., all - // _SigKill signals; see Go - // src/runtime/sigtab_linux_generic.go. - // - // Otherwise ignore the signal. - // - // TODO(b/114489875): Drop in Go 1.12, which uses tgkill - // in runtime.raise. - switch signal { - case linux.SIGHUP, linux.SIGINT, linux.SIGTERM: - dieFromSignal(signal) - panic(fmt.Sprintf("Failed to die from signal %d", signal)) - default: - continue - } - } - - // Pass the signal to the handler. - handler(signal) + handler(linux.Signal(index)) } } -// PrepareHandler ensures that synchronous signals are passed to the given -// handler function and returns a callback that starts signal delivery, which -// itself returns a callback that stops signal handling. +// StartSignalForwarding ensures that synchronous signals are passed to the +// given handler function and returns a callback that stops signal delivery. // // Note that this function permanently takes over signal handling. After the // stop callback, signals revert to the default Go runtime behavior, which // cannot be overridden with external calls to signal.Notify. -func PrepareHandler(handler func(linux.Signal)) func() func() { - start := make(chan struct{}) +func StartSignalForwarding(handler func(linux.Signal)) func() { stop := make(chan struct{}) done := make(chan struct{}) @@ -128,13 +88,10 @@ func PrepareHandler(handler func(linux.Signal)) func() func() { signal.Notify(sigchan, syscall.Signal(sig)) } // Start up our listener. - go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. + go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. - return func() func() { - close(start) - return func() { - close(stop) - <-done - } + return func() { + close(stop) + <-done } } diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index c303435d5..1ebe22d34 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -15,8 +15,6 @@ package sighandling import ( - "fmt" - "runtime" "syscall" "unsafe" @@ -48,27 +46,3 @@ func IgnoreChildStop() error { return nil } - -// dieFromSignal kills the current process with sig. -// -// Preconditions: The default action of sig is termination. -func dieFromSignal(sig linux.Signal) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - sa := sigaction{handler: linux.SIG_DFL} - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigaction failed: %v", e)) - } - - set := linux.MakeSignalSet(sig) - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigprocmask failed: %v", e)) - } - - if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil { - panic(fmt.Sprintf("tgkill failed: %v", err)) - } - - panic("failed to die") -} diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 4a6e83a8b..357517ed4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -17,6 +17,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4e95101b7..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" @@ -194,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ // the available space, we must align down. // // align must be >= 4 and each data int32 is 4 bytes. The length of the - // header is already aligned, so if we align to the with of the data there + // header is already aligned, so if we align to the width of the data there // are two cases: // 1. The aligned length is less than the length of the header. The // unaligned length was also less than the length of the header, so we // can't write anything. // 2. The aligned length is greater than or equal to the length of the - // header. We can write the header plus zero or more datas. We can't write - // a partial int32, so the length of the message will be - // min(aligned length, header + datas). + // header. We can write the header plus zero or more bytes of data. We can't + // write a partial int32, so the length of the message will be + // min(aligned length, header + data). if space < linux.SizeOfControlMessageHeader { flags |= linux.MSG_CTRUNC return buf, flags @@ -239,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = binary.Marshal(buf, usermem.ByteOrder, data) - // Check if we went over. + // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { return hdrBuf } - // Fix up length. + // Update control message length to include data. putUint64(ob, uint64(len(buf)-len(ob))) return alignSlice(buf, align) @@ -320,35 +321,109 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { buf, linux.SOL_TCP, linux.TCP_INQ, - 4, + t.Arch().Width(), inq, ) } +// PackTOS packs an IP_TOS socket control message. +func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_TOS, + t.Arch().Width(), + tos, + ) +} + +// PackTClass packs an IPV6_TCLASS socket control message. +func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_TCLASS, + t.Arch().Width(), + tClass, + ) +} + +// PackControlMessages packs control messages into the given buffer. +// +// We skip control messages specific to Unix domain sockets. +// +// Note that some control messages may be truncated if they do not fit under +// the capacity of buf. +func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { + if cmsgs.IP.HasTimestamp { + buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) + } + + if cmsgs.IP.HasInq { + // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. + buf = PackInq(t, cmsgs.IP.Inq, buf) + } + + if cmsgs.IP.HasTOS { + buf = PackTOS(t, cmsgs.IP.TOS, buf) + } + + if cmsgs.IP.HasTClass { + buf = PackTClass(t, cmsgs.IP.TClass, buf) + } + + return buf +} + +// cmsgSpace is equivalent to CMSG_SPACE in Linux. +func cmsgSpace(t *kernel.Task, dataLen int) int { + return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width()) +} + +// CmsgsSpace returns the number of bytes needed to fit the control messages +// represented in cmsgs. +func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { + space := 0 + + if cmsgs.IP.HasTimestamp { + space += cmsgSpace(t, linux.SizeOfTimeval) + } + + if cmsgs.IP.HasInq { + space += cmsgSpace(t, linux.SizeOfControlMessageInq) + } + + if cmsgs.IP.HasTOS { + space += cmsgSpace(t, linux.SizeOfControlMessageTOS) + } + + if cmsgs.IP.HasTClass { + space += cmsgSpace(t, linux.SizeOfControlMessageTClass) + } + + return space +} + // Parse parses a raw socket control message into portable objects. -func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { +func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( - fds linux.ControlMessageRights - haveCreds bool - creds linux.ControlMessageCredentials + cmsgs socket.ControlMessages + fds linux.ControlMessageRights ) for i := 0; i < len(buf); { if i+linux.SizeOfControlMessageHeader > len(buf) { - return transport.ControlMessages{}, syserror.EINVAL + return cmsgs, syserror.EINVAL } var h linux.ControlMessageHeader binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) if h.Length < uint64(linux.SizeOfControlMessageHeader) { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } if h.Length > uint64(len(buf)-i) { - return transport.ControlMessages{}, syserror.EINVAL - } - if h.Level != linux.SOL_SOCKET { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } i += linux.SizeOfControlMessageHeader @@ -358,59 +433,79 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport. // sizeof(long) in CMSG_ALIGN. width := t.Arch().Width() - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) - numRights := rightsSize / linux.SizeOfControlMessageRight - - if len(fds)+numRights > linux.SCM_MAX_FD { - return transport.ControlMessages{}, syserror.EINVAL + switch h.Level { + case linux.SOL_SOCKET: + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) + numRights := rightsSize / linux.SizeOfControlMessageRight + + if len(fds)+numRights > linux.SCM_MAX_FD { + return socket.ControlMessages{}, syserror.EINVAL + } + + for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { + fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + } + + i += AlignUp(length, width) + + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + return socket.ControlMessages{}, syserror.EINVAL + } + + var creds linux.ControlMessageCredentials + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) + scmCreds, err := NewSCMCredentials(t, creds) + if err != nil { + return socket.ControlMessages{}, err + } + cmsgs.Unix.Credentials = scmCreds + i += AlignUp(length, width) + + default: + // Unknown message type. + return socket.ControlMessages{}, syserror.EINVAL } - - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + case linux.SOL_IP: + switch h.Type { + case linux.IP_TOS: + cmsgs.IP.HasTOS = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) + i += AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - i += AlignUp(length, width) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { - return transport.ControlMessages{}, syserror.EINVAL + case linux.SOL_IPV6: + switch h.Type { + case linux.IPV6_TCLASS: + cmsgs.IP.HasTClass = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) + i += AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) - haveCreds = true - i += AlignUp(length, width) - default: - // Unknown message type. - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } } - var credentials SCMCredentials - if haveCreds { - var err error - if credentials, err = NewSCMCredentials(t, creds); err != nil { - return transport.ControlMessages{}, err - } - } else { - credentials = makeCreds(t, socketOrEndpoint) + if cmsgs.Unix.Credentials == nil { + cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint) } - var rights SCMRights if len(fds) > 0 { - var err error - if rights, err = NewSCMRights(t, fds); err != nil { - return transport.ControlMessages{}, err + rights, err := NewSCMRights(t, fds) + if err != nil { + return socket.ControlMessages{}, err } + cmsgs.Unix.Rights = rights } - if credentials == nil && rights == nil { - return transport.ControlMessages{}, nil - } - - return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil + return cmsgs, nil } func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4d174dda4..4c44c7c0f 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -29,9 +29,12 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/control", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip/stack", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 92beb1bcf..c957b0f1d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -18,6 +18,7 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/fdnotifier" @@ -29,6 +30,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -41,6 +43,10 @@ const ( // sizeofSockaddr is the size in bytes of the largest sockaddr type // supported by this package. sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) + + // maxControlLen is the maximum size of a control message buffer used in a + // recvmsg or sendmsg syscall. + maxControlLen = 1024 ) // socketOperations implements fs.FileOperations and socket.Socket for a socket @@ -281,26 +287,32 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt // Whitelist options and constrain option length. var optlen int switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 - case syscall.SO_LINGER: + case linux.SO_LINGER: optlen = syscall.SizeofLinger } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 - case syscall.TCP_INFO: + case linux.TCP_INFO: optlen = int(linux.SizeOfTCPInfo) } } + if optlen == 0 { return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT } @@ -320,19 +332,24 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ // Whitelist options and constrain option length. var optlen int switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 } } @@ -354,11 +371,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that netstack/tcpip/transport/unix doesn't understand. Kill the + // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the // Socket interface's dependence on netstack. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument @@ -370,6 +387,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags senderAddrBuf = make([]byte, sizeofSockaddr) } + var controlBuf []byte var msgFlags int recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -384,11 +402,6 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We always do a non-blocking recv*(). sysflags := flags | syscall.MSG_DONTWAIT - if dsts.NumBlocks() == 1 { - // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) - } - iovs := iovecsFromBlockSeq(dsts) msg := syscall.Msghdr{ Iov: &iovs[0], @@ -398,12 +411,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Name = &senderAddrBuf[0] msg.Namelen = uint32(len(senderAddrBuf)) } + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) + controlLen = uint64(msg.Controllen) return n, nil }) @@ -429,14 +451,38 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err = dst.CopyOutFrom(t, recvmsgToBlocks) } } - - // We don't allow control messages. - msgFlags &^= linux.MSG_CTRUNC + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) + + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } + + controlMessages := socket.ControlMessages{} + for _, unixCmsg := range unixControlMessages { + switch unixCmsg.Header.Level { + case syscall.SOL_IP: + switch unixCmsg.Header.Type { + case syscall.IP_TOS: + controlMessages.IP.HasTOS = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + } + case syscall.SOL_IPV6: + switch unixCmsg.Header.Type { + case syscall.IPV6_TCLASS: + controlMessages.IP.HasTClass = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + } + } + } + + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil } // SendMsg implements socket.Socket.SendMsg. @@ -446,6 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.ErrInvalidArgument } + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. if uint64(src.NumBytes()) != srcs.NumBytes() { @@ -458,7 +512,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // We always do a non-blocking send*(). sysflags := flags | syscall.MSG_DONTWAIT - if srcs.NumBlocks() == 1 { + if srcs.NumBlocks() == 1 && len(controlBuf) == 0 { // Skip allocating []syscall.Iovec. src := srcs.Head() n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to))) @@ -477,6 +531,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] msg.Name = &to[0] msg.Namelen = uint32(len(to)) } + if len(controlBuf) != 0 { + msg.Control = &controlBuf[0] + msg.Controllen = uint64(len(controlBuf)) + } return sendmsg(s.fd, &msg, sysflags) }) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 3a4fdec47..e67b46c9e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -16,8 +16,11 @@ package hostinet import ( "fmt" + "io" "io/ioutil" "os" + "reflect" + "strconv" "strings" "syscall" @@ -26,7 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -51,6 +56,8 @@ type Stack struct { tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool + netDevFile *os.File + netSNMPFile *os.File } // NewStack returns an empty Stack containing no configuration. @@ -98,6 +105,18 @@ func (s *Stack) Configure() error { log.Warningf("Failed to read if TCP SACK if enabled, setting to true") } + if f, err := os.Open("/proc/net/dev"); err != nil { + log.Warningf("Failed to open /proc/net/dev: %v", err) + } else { + s.netDevFile = f + } + + if f, err := os.Open("/proc/net/snmp"); err != nil { + log.Warningf("Failed to open /proc/net/snmp: %v", err) + } else { + s.netSNMPFile = f + } + return nil } @@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// getLine reads one line from proc file, with specified prefix. +// The last argument, withHeader, specifies if it contains line header. +func getLine(f *os.File, prefix string, withHeader bool) string { + data := make([]byte, 4096) + + if _, err := f.Seek(0, 0); err != nil { + return "" + } + + if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF { + return "" + } + + prefix = prefix + ":" + lines := strings.Split(string(data), "\n") + for _, l := range lines { + l = strings.TrimSpace(l) + if strings.HasPrefix(l, prefix) { + if withHeader { + withHeader = false + continue + } + return l + } + } + return "" +} + +func toSlice(i interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(i)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserror.EOPNOTSUPP + var ( + snmpTCP bool + rawLine string + sliceStat []uint64 + ) + + switch stat.(type) { + case *inet.StatDev: + if s.netDevFile == nil { + return fmt.Errorf("/proc/net/dev is not opened for hostinet") + } + rawLine = getLine(s.netDevFile, arg, false /* with no header */) + case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite: + if s.netSNMPFile == nil { + return fmt.Errorf("/proc/net/snmp is not opened for hostinet") + } + rawLine = getLine(s.netSNMPFile, arg, true) + default: + return syserr.ErrEndpointOperation.ToError() + } + + if rawLine == "" { + return fmt.Errorf("Failed to get raw line") + } + + parts := strings.SplitN(rawLine, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("Failed to get prefix from: %q", rawLine) + } + + sliceStat = toSlice(stat) + fields := strings.Fields(strings.TrimSpace(parts[1])) + if len(fields) != len(sliceStat) { + return fmt.Errorf("Failed to parse fields: %q", rawLine) + } + if _, ok := stat.(*inet.StatSNMPTCP); ok { + snmpTCP = true + } + for i := 0; i < len(sliceStat); i++ { + var err error + if snmpTCP && i == 3 { + var tmp int64 + // MaxConn field is signed, RFC 2012. + tmp, err = strconv.ParseInt(fields[i], 10, 64) + sliceStat[i] = uint64(tmp) // Convert back to int before use. + } else { + sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) + } + if err != nil { + return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + } + } + + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f95803f91..79589e3c8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 689cad997..be005df24 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -30,6 +30,13 @@ type Protocol interface { // Protocol returns the Linux netlink protocol value. Protocol() int + // CanSend returns true if this protocol may ever send messages. + // + // TODO(gvisor.dev/issue/1119): This is a workaround to allow + // advertising support for otherwise unimplemented features on sockets + // that will never send messages, thus making those features no-ops. + CanSend() bool + // ProcessMessage processes a single message from userspace. // // If err == nil, any messages added to ms will be sent back to the diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index cc70ac237..6b4a0ecf4 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -61,6 +61,11 @@ func (p *Protocol) Protocol() int { return linux.NETLINK_ROUTE } +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return true +} + // dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests. func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index b2732ca29..4a1b87a9a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -53,6 +54,8 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) +var errNoFilter = syserr.New("no filter attached", linux.ENOENT) + // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() @@ -61,7 +64,7 @@ var netlinkSocketDevice = device.NewAnonDevice() // This implementation only supports userspace sending and receiving messages // to/from the kernel. // -// Socket implements socket.Socket. +// Socket implements socket.Socket and transport.Credentialer. // // +stateify savable type Socket struct { @@ -104,9 +107,19 @@ type Socket struct { // sendBufferSize is the send buffer "size". We don't actually have a // fixed buffer but only consume this many bytes. sendBufferSize uint32 + + // passcred indicates if this socket wants SCM credentials. + passcred bool + + // filter indicates that this socket has a BPF filter "installed". + // + // TODO(gvisor.dev/issue/1119): We don't actually support filtering, + // this is just bookkeeping for tracking add/remove. + filter bool } var _ socket.Socket = (*Socket)(nil) +var _ transport.Credentialer = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -172,6 +185,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } +// Passcred implements transport.Credentialer.Passcred. +func (s *Socket) Passcred() bool { + s.mu.Lock() + passcred := s.passcred + s.mu.Unlock() + return passcred +} + +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. +func (s *Socket) ConnectedPasscred() bool { + // This socket is connected to the kernel, which doesn't need creds. + // + // This is arbitrary, as ConnectedPasscred on this type has no callers. + return false +} + // Ioctl implements fs.FileOperations.Ioctl. func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) { // TODO(b/68878065): no ioctls supported. @@ -309,9 +338,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. // We don't have limit on receiving size. return int32(math.MaxInt32), nil + case linux.SO_PASSCRED: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + var passcred int32 + if s.Passcred() { + passcred = 1 + } + return passcred, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } + case linux.SOL_NETLINK: switch name { case linux.NETLINK_BROADCAST_ERROR, @@ -348,6 +388,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy s.sendBufferSize = size s.mu.Unlock() return nil + case linux.SO_RCVBUF: if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -355,6 +396,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy // We don't have limit on receiving size. So just accept anything as // valid for compatibility. return nil + + case linux.SO_PASSCRED: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + passcred := usermem.ByteOrder.Uint32(opt) + + s.mu.Lock() + s.passcred = passcred != 0 + s.mu.Unlock() + return nil + + case linux.SO_ATTACH_FILTER: + // TODO(gvisor.dev/issue/1119): We don't actually + // support filtering. If this socket can't ever send + // messages, then there is nothing to filter and we can + // advertise support. Otherwise, be conservative and + // return an error. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + s.filter = true + s.mu.Unlock() + return nil + + case linux.SO_DETACH_FILTER: + // TODO(gvisor.dev/issue/1119): See above. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + filter := s.filter + s.filter = false + s.mu.Unlock() + + if !filter { + return errNoFilter + } + + return nil + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -483,6 +570,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ }) } +// kernelSCM implements control.SCMCredentials with credentials that represent +// the kernel itself rather than a Task. +// +// +stateify savable +type kernelSCM struct{} + +// Equals implements transport.CredentialsControlMessage.Equals. +func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool { + _, ok := oc.(kernelSCM) + return ok +} + +// Credentials implements control.SCMCredentials.Credentials. +func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { + return 0, auth.RootUID, auth.RootGID +} + +// kernelCreds is the concrete version of kernelSCM used in all creds. +var kernelCreds = &kernelSCM{} + // sendResponse sends the response messages in ms back to userspace. func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. @@ -491,10 +598,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error bufs = append(bufs, m.Finalize()) } + // All messages are from the kernel. + cms := transport.ControlMessages{ + Credentials: kernelCreds, + } + if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -521,7 +633,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error // Add the dump_done_errno payload. m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD new file mode 100644 index 000000000..0777f3baf --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "uevent", + srcs = ["protocol.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netlink", + "//pkg/syserr", + ], +) diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go new file mode 100644 index 000000000..b5d7808d7 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -0,0 +1,60 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol. +// +// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does +// not support any device events, so these sockets never send any messages. +package uevent + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/syserr" +) + +// Protocol implements netlink.Protocol. +// +// +stateify savable +type Protocol struct{} + +var _ netlink.Protocol = (*Protocol)(nil) + +// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol. +func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) { + return &Protocol{}, nil +} + +// Protocol implements netlink.Protocol.Protocol. +func (p *Protocol) Protocol() int { + return linux.NETLINK_KOBJECT_UEVENT +} + +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return false +} + +// ProcessMessage implements netlink.Protocol.ProcessMessage. +func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { + // Silently ignore all messages. + return nil +} + +// init registers the NETLINK_KOBJECT_UEVENT provider. +func init() { + netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol) +} diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0ae573b45..140851c17 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -53,6 +53,7 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -148,6 +149,10 @@ var Metrics = tcpip.Stats{ TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), + CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -278,7 +283,7 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { + if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { return nil, syserr.TranslateNetstackError(err) } } @@ -296,6 +301,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) +var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -307,12 +313,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address { } // AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and -// AF_INET6 addresses. +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. // // strict indicates whether addresses with the AF_UNSPEC family are accepted of not. // -// AddressAndFamily returns an address, its family. +// AddressAndFamily returns an address and its family. func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { @@ -320,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { + if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } @@ -371,6 +377,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } return out, family, nil + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(b/129292371): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + case linux.AF_UNSPEC: return tcpip.FullAddress{}, family, nil @@ -1035,8 +1057,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - var v tcpip.DelayOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1105,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Millisecond), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -1153,6 +1187,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa copy(b, v) return b, nil + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1477,11 +1523,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o tcpip.DelayOption + var o int if v == 0 { o = 1 } - return syserr.TranslateNetstackError(ep.SetSockOpt(o)) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1529,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { @@ -1536,6 +1593,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1951,12 +2016,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(2 + l) } return &out, uint32(3 + l) + case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return &out, uint32(binary.Size(out)) + return &out, uint32(sockAddrInetSize) + case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1972,7 +2039,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return &out, uint32(binary.Size(out)) + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(b/129292371): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + default: return nil, 0 } diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 357a664cc..2d2c1ba2a 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -62,6 +62,10 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in } case linux.SOCK_RAW: + // TODO(b/142504697): "In order to create a raw socket, a + // process must have the CAP_NET_RAW capability in the user + // namespace that governs its network namespace." - raw(7) + // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { @@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in return 0, true, syserr.ErrProtocolNotSupported } -// Socket creates a new socket object for the AF_INET or AF_INET6 family. +// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET +// family. func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() @@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, nil } + // Packet sockets are handled separately, since they are neither INET + // nor INET6 specific. + if p.family == linux.AF_PACKET { + return packetSocket(t, eps, stype, protocol) + } + // Figure out the transport protocol. transProto, associated, err := getTransportProtocol(t, stype, protocol) if err != nil { @@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return New(t, p.family, stype, int(transProto), wq, ep) } +func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { + // TODO(b/142504697): "In order to create a packet socket, a process + // must have the CAP_NET_RAW capability in the user namespace that + // governs its network namespace." - packet(7) + + // Packet sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(t) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return nil, syserr.ErrNotPermitted + } + + // "cooked" packets don't contain link layer information. + var cooked bool + switch stype { + case linux.SOCK_DGRAM: + cooked = true + case linux.SOCK_RAW: + cooked = false + default: + return nil, syserr.ErrProtocolNotSupported + } + + // protocol is passed in network byte order, but netstack wants it in + // host order. + netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + + wq := &waiter.Queue{} + ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return New(t, linux.AF_PACKET, stype, protocol, wq, ep) +} + // Pair just returns nil sockets (not supported). func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } -// init registers socket providers for AF_INET and AF_INET6. +// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET. func init() { // Providers backed by netstack. p := []provider{ @@ -138,6 +184,9 @@ func init() { family: linux.AF_INET6, netProto: ipv6.ProtocolNumber, }, + { + family: linux.AF_PACKET, + }, } for i := range p { diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fda0156e5..a0db2d4fd 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() + switch stats := stat.(type) { + case *inet.StatSNMPIP: + ip := Metrics.IP + *stats = inet.StatSNMPIP{ + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + } + case *inet.StatSNMPICMP: + in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + *stats = inet.StatSNMPICMP{ + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + in.DstUnreachable.Value(), // InDestUnreachs. + in.TimeExceeded.Value(), // InTimeExcds. + in.ParamProblem.Value(), // InParmProbs. + in.SrcQuench.Value(), // InSrcQuenchs. + in.Redirect.Value(), // InRedirects. + in.Echo.Value(), // InEchos. + in.EchoReply.Value(), // InEchoReps. + in.Timestamp.Value(), // InTimestamps. + in.TimestampReply.Value(), // InTimestampReps. + in.InfoRequest.Value(), // InAddrMasks. + in.InfoReply.Value(), // InAddrMaskReps. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. + } + case *inet.StatSNMPTCP: + tcp := Metrics.TCP + // RFC 2012 (updates 1213): SNMPv2-MIB-TCP. + *stats = inet.StatSNMPTCP{ + 1, // RtoAlgorithm. + 200, // RtoMin. + 120000, // RtoMax. + (1<<64 - 1), // MaxConn. + tcp.ActiveConnectionOpenings.Value(), // ActiveOpens. + tcp.PassiveConnectionOpenings.Value(), // PassiveOpens. + tcp.FailedConnectionAttempts.Value(), // AttemptFails. + tcp.EstablishedResets.Value(), // EstabResets. + tcp.CurrentEstablished.Value(), // CurrEstab. + tcp.ValidSegmentsReceived.Value(), // InSegs. + tcp.SegmentsSent.Value(), // OutSegs. + tcp.Retransmits.Value(), // RetransSegs. + tcp.InvalidSegmentsReceived.Value(), // InErrs. + tcp.ResetsSent.Value(), // OutRsts. + tcp.ChecksumErrors.Value(), // InCsumErrors. + } + case *inet.StatSNMPUDP: + udp := Metrics.UDP + *stats = inet.StatSNMPUDP{ + udp.PacketsReceived.Value(), // InDatagrams. + udp.UnknownPortErrors.Value(), // NoPorts. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + udp.PacketsSent.Value(), // OutDatagrams. + udp.ReceiveBufferErrors.Value(), // RcvbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + } + default: + return syserr.ErrEndpointOperation.ToError() + } + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 3a6baa308..4668b87d1 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -37,6 +37,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f7878a760 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto index 9586f5923..b677e9eb3 100644 --- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto +++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto @@ -3,7 +3,6 @@ syntax = "proto3"; // package syscall_rpc is a set of networking related system calls that can be // forwarded to a socket gofer. // -// TODO(b/77963526): Document individual RPCs. package syscall_rpc; message SendmsgRequest { diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 8c250c325..2389a9cdb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -43,6 +43,11 @@ type ControlMessages struct { IP tcpip.ControlMessages } +// Release releases Unix domain socket credentials and rights. +func (c *ControlMessages) Release() { + c.Unix.Release() +} + // Socket is the interface containing socket syscalls used by the syscall layer // to redirect them to the appropriate implementation. type Socket interface { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..885758054 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -118,6 +118,9 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { func extractPath(sockaddr []byte) (string, *syserr.Error) { addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 72ebf766d..d46421199 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -14,6 +14,7 @@ go_library( "open.go", "poll.go", "ptrace.go", + "select.go", "signal.go", "socket.go", "strace.go", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 5d57b75af..e603f858f 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -40,7 +40,7 @@ var linuxAMD64 = SyscallMap{ 20: makeSyscallInfo("writev", FD, WriteIOVec, Hex), 21: makeSyscallInfo("access", Path, Oct), 22: makeSyscallInfo("pipe", PipeFDs), - 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval), + 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval), 24: makeSyscallInfo("sched_yield"), 25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex), 26: makeSyscallInfo("msync", Hex, Hex, Hex), @@ -287,7 +287,7 @@ var linuxAMD64 = SyscallMap{ 267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex), 268: makeSyscallInfo("fchmodat", FD, Path, Mode), 269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex), - 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex), + 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet), 271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex), 272: makeSyscallInfo("unshare", CloneFlags), 273: makeSyscallInfo("set_robust_list", Hex, Hex), @@ -335,5 +335,33 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 318: makeSyscallInfo("getrandom", Hex, Hex, Hex), + 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close. + 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex), + 321: makeSyscallInfo("bpf", Hex, Hex, Hex), + 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex), + 323: makeSyscallInfo("userfaultfd", Hex), + 324: makeSyscallInfo("membarrier", Hex, Hex), + 325: makeSyscallInfo("mlock2", Hex, Hex, Hex), + 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex), + 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex), + 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex), + 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex), + 330: makeSyscallInfo("pkey_alloc", Hex, Hex), + 331: makeSyscallInfo("pkey_free", Hex), 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), + 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet), + 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex), + 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex), + 425: makeSyscallInfo("io_uring_setup", Hex, Hex), + 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex), + 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex), + 428: makeSyscallInfo("open_tree", FD, Path, Hex), + 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex), + 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close. + 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex), + 432: makeSyscallInfo("fsmount", FD, Hex, Hex), + 433: makeSyscallInfo("fspick", FD, Path, Hex), + 434: makeSyscallInfo("pidfd_open", Hex, Hex), + 435: makeSyscallInfo("clone3", Hex, Hex), } diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go new file mode 100644 index 000000000..c77d418e6 --- /dev/null +++ b/pkg/sentry/strace/select.go @@ -0,0 +1,56 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func fdsFromSet(t *kernel.Task, set []byte) []int { + var fds []int + // Append n if the n-th bit is 1. + for i, v := range set { + for j := 0; j < 8; j++ { + if (v>>j)&1 == 1 { + fds = append(fds, i*8+j) + } + } + } + return fds +} + +func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string { + if nfds < 0 { + return fmt.Sprintf("%#x (negative nfds)", addr) + } + if addr == 0 { + return "null" + } + + // Calculate the size of the fd set (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 + + set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte) + if err != nil { + return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err) + } + + return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set)) +} diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 94334f6d2..51f2efb39 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -208,6 +208,15 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) i += linux.SizeOfControlMessageHeader width := t.Arch().Width() length := int(h.Length) - linux.SizeOfControlMessageHeader + if length < 0 { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 311389547..629c1f308 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -439,6 +439,8 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer())) case PollFDs: output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false)) + case SelectFDSet: + output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) case Oct: output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8)) case Hex: diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto index 4b2f73a5f..906c52c51 100644 --- a/pkg/sentry/strace/strace.proto +++ b/pkg/sentry/strace/strace.proto @@ -32,8 +32,7 @@ message Strace { } } -message StraceEnter { -} +message StraceEnter {} message StraceExit { // Return value formatted as string. diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 3c389d375..e5d486c4e 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -206,6 +206,10 @@ const ( // PollFDs is an array of struct pollfd. The number of entries in the // array is in the next argument. PollFDs + + // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of + // fds represented must be the first argument. + SelectFDSet ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index fb2c1777f..6766ba587 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -49,6 +49,7 @@ go_library( "sys_tls.go", "sys_utsname.go", "sys_write.go", + "sys_xattr.go", "timespec.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", @@ -79,6 +80,7 @@ go_library( "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", + "//pkg/sentry/loader", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/safemem", diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go index 444f2b004..07961dad9 100644 --- a/pkg/sentry/syscalls/linux/flags.go +++ b/pkg/sentry/syscalls/linux/flags.go @@ -50,5 +50,6 @@ func linuxToFlags(mask uint) fs.FileFlags { Directory: mask&linux.O_DIRECTORY != 0, Async: mask&linux.O_ASYNC != 0, LargeFile: mask&linux.O_LARGEFILE != 0, + Truncate: mask&linux.O_TRUNC != 0, } } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index b317cb99d..68589a377 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -16,7 +16,12 @@ package linux const ( - _LINUX_SYSNAME = "Linux" - _LINUX_RELEASE = "4.4.0" - _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016" + // LinuxSysname is the OS name advertised by gVisor. + LinuxSysname = "Linux" + + // LinuxRelease is the Linux release version number advertised by gVisor. + LinuxRelease = "4.4.0" + + // LinuxVersion is the version info advertised by gVisor. + LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016" ) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index e215ac049..272ae9991 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -34,9 +34,9 @@ var AMD64 = &kernel.SyscallTable{ // guides the interface provided by this syscall table. The build // version is that for a clean build with default kernel config, at 5 // minutes after v4.4 was tagged. - Sysname: _LINUX_SYSNAME, - Release: _LINUX_RELEASE, - Version: _LINUX_VERSION, + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, }, AuditNumber: linux.AUDIT_ARCH_X86_64, Table: map[uintptr]kernel.Syscall{ @@ -228,10 +228,10 @@ var AMD64 = &kernel.SyscallTable{ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), 187: syscalls.Supported("readahead", Readahead), - 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 188: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil), 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 191: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil), 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), @@ -260,7 +260,7 @@ var AMD64 = &kernel.SyscallTable{ 217: syscalls.Supported("getdents64", Getdents64), 218: syscalls.Supported("set_tid_address", SetTidAddress), 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), 222: syscalls.Supported("timer_create", TimerCreate), 223: syscalls.Supported("timer_settime", TimerSettime), @@ -362,16 +362,36 @@ var AMD64 = &kernel.SyscallTable{ 319: syscalls.Supported("memfd_create", MemfdCreate), 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 322: syscalls.Supported("execveat", Execveat), 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) + 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - // Syscalls after 325 are "backports" from versions of Linux after 4.4. + // Syscalls implemented after 325 are "backports" from versions + // of Linux after 4.4. 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), 327: syscalls.Supported("preadv2", Preadv2), 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 332: syscalls.Supported("statx", Statx), + 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, Emulate: map[usermem.Addr]uintptr{ diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index f82dfac31..3b584eed9 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -30,9 +30,9 @@ var ARM64 = &kernel.SyscallTable{ OS: abi.Linux, Arch: arch.ARM64, Version: kernel.Version{ - Sysname: _LINUX_SYSNAME, - Release: _LINUX_RELEASE, - Version: _LINUX_VERSION, + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, }, AuditNumber: linux.AUDIT_ARCH_AARCH64, Table: map[uintptr]kernel.Syscall{ @@ -41,10 +41,10 @@ var ARM64 = &kernel.SyscallTable{ 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 5: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil), 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 8: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil), 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), @@ -224,7 +224,7 @@ var ARM64 = &kernel.SyscallTable{ 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), - 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), @@ -295,14 +295,33 @@ var ARM64 = &kernel.SyscallTable{ 278: syscalls.Supported("getrandom", GetRandom), 279: syscalls.Supported("memfd_create", MemfdCreate), 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 281: syscalls.Supported("execveat", Execveat), 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) + 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), 286: syscalls.Supported("preadv2", Preadv2), 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 291: syscalls.Supported("statx", Statx), + 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, Emulate: map[usermem.Addr]uintptr{}, diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index b9a8e3e21..9bc2445a5 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -169,10 +169,14 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint if dirPath { return syserror.ENOTDIR } - if flags&linux.O_TRUNC != 0 { - if err := d.Inode.Truncate(t, d, 0); err != nil { - return err - } + } + + // Truncate is called when O_TRUNC is specified for any kind of + // existing Dirent. Behavior is delegated to the entry's Truncate + // implementation. + if flags&linux.O_TRUNC != 0 { + if err := d.Inode.Truncate(t, d, 0); err != nil { + return err } } @@ -396,7 +400,9 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l return err } - // Should we truncate the file? + // Truncate is called when O_TRUNC is specified for any kind of + // existing Dirent. Behavior is delegated to the entry's Truncate + // implementation. if flags&linux.O_TRUNC != 0 { if err := found.Inode.Truncate(t, found, 0); err != nil { return err @@ -834,23 +840,40 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(newfd), nil, nil } -func fGetOwn(t *kernel.Task, file *fs.File) int32 { +func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx { ma := file.Async(nil) if ma == nil { - return 0 + return linux.FOwnerEx{} } a := ma.(*fasync.FileAsync) ot, otg, opg := a.Owner() switch { case ot != nil: - return int32(t.PIDNamespace().IDOfTask(ot)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_TID, + PID: int32(t.PIDNamespace().IDOfTask(ot)), + } case otg != nil: - return int32(t.PIDNamespace().IDOfThreadGroup(otg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PID, + PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)), + } case opg != nil: - return int32(-t.PIDNamespace().IDOfProcessGroup(opg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PGRP, + PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)), + } default: - return 0 + return linux.FOwnerEx{} + } +} + +func fGetOwn(t *kernel.Task, file *fs.File) int32 { + owner := fGetOwnEx(t, file) + if owner.Type == linux.F_OWNER_PGRP { + return -owner.PID } + return owner.PID } // fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux. @@ -895,11 +918,13 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall t.FDTable().SetFlags(fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) + return 0, nil, nil case linux.F_GETFL: return uintptr(file.Flags().ToLinux()), nil, nil case linux.F_SETFL: flags := uint(args[2].Uint()) file.SetFlags(linuxToFlags(flags).Settable()) + return 0, nil, nil case linux.F_SETLK, linux.F_SETLKW: // In Linux the file system can choose to provide lock operations for an inode. // Normally pipe and socket types lack lock operations. We diverge and use a heavy @@ -1002,6 +1027,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_SETOWN: fSetOwn(t, file, args[2].Int()) return 0, nil, nil + case linux.F_GETOWN_EX: + addr := args[2].Pointer() + owner := fGetOwnEx(t, file) + _, err := t.CopyOut(addr, &owner) + return 0, nil, err + case linux.F_SETOWN_EX: + addr := args[2].Pointer() + var owner linux.FOwnerEx + n, err := t.CopyIn(addr, &owner) + if err != nil { + return 0, nil, err + } + a := file.Async(fasync.New).(*fasync.FileAsync) + switch owner.Type { + case linux.F_OWNER_TID: + task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) + if task == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerTask(t, task) + return uintptr(n), nil, nil + case linux.F_OWNER_PID: + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID)) + if tg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerThreadGroup(t, tg) + return uintptr(n), nil, nil + case linux.F_OWNER_PGRP: + pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID)) + if pg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerProcessGroup(t, pg) + return uintptr(n), nil, nil + default: + return 0, nil, syserror.EINVAL + } case linux.F_GET_SEALS: val, err := tmpfs.GetSeals(file.Dirent.Inode) return uintptr(val), nil, err @@ -1029,7 +1092,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Everything else is not yet supported. return 0, nil, syserror.EINVAL } - return 0, nil, nil } const ( @@ -1483,6 +1545,8 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if fs.IsDir(d.Inode.StableAttr) { return syserror.EISDIR } + // In contrast to open(O_TRUNC), truncate(2) is only valid for file + // types. if !fs.IsFile(d.Inode.StableAttr) { return syserror.EINVAL } @@ -1521,7 +1585,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EINVAL } - // Note that this is different from truncate(2) above, where a + // In contrast to open(O_TRUNC), truncate(2) is only valid for file + // types. Note that this is different from truncate(2) above, where a // directory returns EISDIR. if !fs.IsFile(file.Dirent.Inode.StableAttr) { return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b9bd25464..bde17a767 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -226,6 +226,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if mask == 0 { return 0, nil, syserror.EINVAL } + if val <= 0 { + // The Linux kernel wakes one waiter even if val is + // non-positive. + val = 1 + } n, err := t.Futex().Wake(t, addr, private, mask, val) return uintptr(n), nil, err @@ -242,6 +247,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FUTEX_WAKE_OP: op := uint32(val3) + if val <= 0 { + // The Linux kernel wakes one waiter even if val is + // non-positive. + val = 1 + } n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op) return uintptr(n), nil, err diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 7a13beac2..2b2df989a 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) return remainingTimeout, n, err } +// CopyInFDSet copies an fd set from select(2)/pselect(2). +func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { + set := make([]byte, nBytes) + + if addr != 0 { + if _, err := t.CopyIn(addr, &set); err != nil { + return nil, err + } + // If we only use part of the last byte, mask out the extraneous bits. + // + // N.B. This only works on little-endian architectures. + if nBitsInLastPartialByte != 0 { + set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte + } + } + return set, nil +} + func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { return 0, syserror.EINVAL } - // Capture all the provided input vectors. - // - // N.B. This only works on little-endian architectures. - byteCount := (nfds + 7) / 8 - - bitsInLastPartialByte := uint(nfds % 8) - r := make([]byte, byteCount) - w := make([]byte, byteCount) - e := make([]byte, byteCount) + // Calculate the size of the fd sets (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 - if readFDs != 0 { - if _, err := t.CopyIn(readFDs, &r); err != nil { - return 0, err - } - // Mask out bits above nfds. - if bitsInLastPartialByte != 0 { - r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + // Capture all the provided input vectors. + r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if writeFDs != 0 { - if _, err := t.CopyIn(writeFDs, &w); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if exceptFDs != 0 { - if _, err := t.CopyIn(exceptFDs, &e); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } // Count how many FDs are actually being requested so that we can build // a PollFD array. fdCount := 0 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { v := r[i] | w[i] | e[i] for v != 0 { v &= (v - 1) @@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Build the PollFD array. pfd := make([]linux.PollFD, 0, fdCount) var fd int32 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { rV, wV, eV := r[i], w[i], e[i] v := rV | wV | eV m := byte(1) @@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add } // Do the syscall, then count the number of bits set. - _, _, err := pollBlock(t, pfd, timeout) - if err != nil { + if _, _, err = pollBlock(t, pfd, timeout); err != nil { return 0, syserror.ConvertIntr(err, syserror.EINTR) } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index b5a72ce63..4b5aafcc0 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -447,16 +447,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - // Read the length if present. Reject negative values. + // Read the length. Reject negative values. optLen := int32(0) - if optLenAddr != 0 { - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { - return 0, nil, err - } - - if optLen < 0 { - return 0, nil, syserror.EINVAL - } + if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + return 0, nil, err + } + if optLen < 0 { + return 0, nil, syserror.EINVAL } // Call syscall implementation then copy both value and value len out. @@ -465,11 +462,9 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - if optLenAddr != 0 { - vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { - return 0, nil, err - } + vLen := int32(binary.Size(v)) + if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + return 0, nil, err } if v != nil { @@ -769,7 +764,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Unix.Release() + cms.Release() } if int(msg.Flags) != mflags { @@ -789,24 +784,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } - defer cms.Unix.Release() + defer cms.Release() controlData := make([]byte, 0, msg.ControlLen) + controlData = control.PackControlMessages(t, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) controlData, mflags = control.PackCredentials(t, creds, controlData, mflags) } - if cms.IP.HasTimestamp { - controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData) - } - - if cms.IP.HasInq { - // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. - controlData = control.PackInq(t, cms.IP.Inq, controlData) - } - if cms.Unix.Rights != nil { controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) } @@ -882,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Unix.Release() + cm.Release() if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -1065,7 +1052,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme } // Call the syscall implementation. - n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages}) + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { controlMessages.Release() diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 8ab7ffa25..4115116ff 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -15,12 +15,15 @@ package linux import ( + "path" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" + "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -67,8 +70,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal argvAddr := args[1].Pointer() envvAddr := args[2].Pointer() - // Extract our arguments. - filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX) + return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0) +} + +// Execveat implements linux syscall execveat(2). +func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirFD := args[0].Int() + pathnameAddr := args[1].Pointer() + argvAddr := args[2].Pointer() + envvAddr := args[3].Pointer() + flags := args[4].Int() + + return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags) +} + +func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { + pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) if err != nil { return 0, nil, err } @@ -89,14 +106,66 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } } + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 + if !atEmptyPath && len(pathname) == 0 { + return 0, nil, syserror.ENOENT + } + resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 + root := t.FSContext().RootDirectory() defer root.DecRef() - wd := t.FSContext().WorkingDirectory() - defer wd.DecRef() + + var wd *fs.Dirent + var executable *fs.File + var closeOnExec bool + if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) { + // Even if the pathname is absolute, we may still need the wd + // for interpreter scripts if the path of the interpreter is + // relative. + wd = t.FSContext().WorkingDirectory() + } else { + // Need to extract the given FD. + f, fdFlags := t.FDTable().Get(dirFD) + if f == nil { + return 0, nil, syserror.EBADF + } + defer f.DecRef() + closeOnExec = fdFlags.CloseOnExec + + if atEmptyPath && len(pathname) == 0 { + executable = f + } else { + wd = f.Dirent + wd.IncRef() + if !fs.IsDir(wd.Inode.StableAttr) { + return 0, nil, syserror.ENOTDIR + } + } + } + if wd != nil { + defer wd.DecRef() + } // Load the new TaskContext. - maxTraversals := uint(linux.MaxSymlinkTraversals) - tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet()) + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: t.MountNamespace(), + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: resolveFinal, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 27cd2c336..ad4b67806 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Pwritev2 implements linux syscall pwritev2(2). -// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality. func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { // While the syscall is // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go new file mode 100644 index 000000000..97d9a65ea --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -0,0 +1,169 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Getxattr implements linux syscall getxattr(2). +func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */) + if err != nil { + return 0, nil, err + } + + valueLen := 0 + err = fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + value, err := getxattr(t, d, dirPath, nameAddr) + if err != nil { + return err + } + + valueLen = len(value) + if size == 0 { + return nil + } + if size > linux.XATTR_SIZE_MAX { + size = linux.XATTR_SIZE_MAX + } + if valueLen > int(size) { + return syserror.ERANGE + } + + _, err = t.CopyOutBytes(valueAddr, []byte(value)) + return err + }) + if err != nil { + return 0, nil, err + } + return uintptr(valueLen), nil, nil +} + +// getxattr implements getxattr from the given *fs.Dirent. +func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr) (string, error) { + if dirPath && !fs.IsDir(d.Inode.StableAttr) { + return "", syserror.ENOTDIR + } + + if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil { + return "", err + } + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return "", err + } + + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + + return d.Inode.Getxattr(name) +} + +// Setxattr implements linux syscall setxattr(2). +func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Uint() + + path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */) + if err != nil { + return 0, nil, err + } + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return 0, nil, syserror.EINVAL + } + + return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + return setxattr(t, d, dirPath, nameAddr, valueAddr, size, flags) + }) +} + +// setxattr implements setxattr from the given *fs.Dirent. +func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint, flags uint32) error { + if dirPath && !fs.IsDir(d.Inode.StableAttr) { + return syserror.ENOTDIR + } + + if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil { + return err + } + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + + if size > linux.XATTR_SIZE_MAX { + return syserror.E2BIG + } + buf := make([]byte, size) + if _, err = t.CopyInBytes(valueAddr, buf); err != nil { + return err + } + value := string(buf) + + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + + return d.Inode.Setxattr(name, value) +} + +func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { + name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) + if err != nil { + if err == syserror.ENAMETOOLONG { + return "", syserror.ERANGE + } + return "", err + } + if len(name) == 0 { + return "", syserror.ERANGE + } + return name, nil +} + +func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error { + // Restrict xattrs to regular files and directories. + // + // In Linux, this restriction technically only applies to xattrs in the + // "user.*" namespace, but we don't allow any other xattr prefixes anyway. + if !fs.IsRegular(i.StableAttr) && !fs.IsDir(i.StableAttr) { + if perms.Write { + return syserror.EPERM + } + return syserror.ENODATA + } + + return i.CheckPermission(t, perms) +} diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index d3a4cd943..18e212dff 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//third_party/gvsync:generic_seqatomic", + template = "//pkg/syncutil:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,8 +36,8 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", + "//pkg/syncutil", "//pkg/syserror", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go index 8d88396ba..7898851b3 100644 --- a/pkg/sentry/usermem/bytes_io.go +++ b/pkg/sentry/usermem/bytes_io.go @@ -102,19 +102,34 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { } func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { - blocks := make([]safemem.Block, 0, ars.NumRanges()) - for !ars.IsEmpty() { - ar := ars.Head() - n, err := b.rangeCheck(ar.Start, int(ar.Length())) - if n != 0 { - blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n])) + switch ars.NumRanges() { + case 0: + return safemem.BlockSeq{}, nil + case 1: + block, err := b.blockFromAddrRange(ars.Head()) + return safemem.BlockSeqOf(block), err + default: + blocks := make([]safemem.Block, 0, ars.NumRanges()) + for !ars.IsEmpty() { + block, err := b.blockFromAddrRange(ars.Head()) + if block.Len() != 0 { + blocks = append(blocks, block) + } + if err != nil { + return safemem.BlockSeqFromSlice(blocks), err + } + ars = ars.Tail() } - if err != nil { - return safemem.BlockSeqFromSlice(blocks), err - } - ars = ars.Tail() + return safemem.BlockSeqFromSlice(blocks), nil + } +} + +func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { + n, err := b.rangeCheck(ar.Start, int(ar.Length())) + if n == 0 { + return safemem.Block{}, err } - return safemem.BlockSeqFromSlice(blocks), nil + return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err } // BytesIOSequence returns an IOSequence representing the given byte slice. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index eff4b44f6..e3e554b88 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -12,13 +12,14 @@ go_library( "file_description.go", "file_description_impl_util.go", "filesystem.go", + "filesystem_impl_util.go", "filesystem_type.go", "mount.go", "mount_unsafe.go", "options.go", + "pathname.go", "permissions.go", "resolving_path.go", - "syscalls.go", "testutil.go", "vfs.go", ], @@ -32,9 +33,9 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/syncutil", "//pkg/syserror", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 7847854bc..9aa133bcb 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -39,8 +39,8 @@ Mount references are held by: - Mount: Each referenced Mount holds a reference on its parent, which is the mount containing its mount point. -- VirtualFilesystem: A reference is held on all Mounts that are attached - (reachable by Mount traversal). +- VirtualFilesystem: A reference is held on each Mount that has not been + umounted. MountNamespace and FileDescription references are held by users of VFS. The expectation is that each `kernel.Task` holds a reference on its corresponding diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index 32cf9151b..705194ebc 100644 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go @@ -24,6 +24,9 @@ type contextID int const ( // CtxMountNamespace is a Context.Value key for a MountNamespace. CtxMountNamespace contextID = iota + + // CtxRoot is a Context.Value key for a VFS root. + CtxRoot ) // MountNamespaceFromContext returns the MountNamespace used by ctx. It does @@ -35,3 +38,13 @@ func MountNamespaceFromContext(ctx context.Context) *MountNamespace { } return nil } + +// RootFromContext returns the VFS root used by ctx. It takes a reference on +// the returned VirtualDentry. If ctx does not have a specific VFS root, +// RootFromContext returns a zero-value VirtualDentry. +func RootFromContext(ctx context.Context) VirtualDentry { + if v := ctx.Value(CtxRoot); v != nil { + return v.(VirtualDentry) + } + return VirtualDentry{} +} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 45912fc58..6209eb053 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,6 +16,7 @@ package vfs import ( "fmt" + "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/syserror" @@ -50,7 +51,7 @@ import ( // and not inodes. Furthermore, when parties outside the scope of VFS can // rename inodes on such filesystems, VFS generally cannot "follow" the rename, // both due to synchronization issues and because it may not even be able to -// name the destination path; this implies that it would in fact be *incorrect* +// name the destination path; this implies that it would in fact be incorrect // for Dentries to be associated with inodes on such filesystems. Consequently, // operations that are inode operations in Linux are FilesystemImpl methods // and/or FileDescriptionImpl methods in gVisor's VFS. Filesystems that do @@ -87,6 +88,9 @@ type Dentry struct { // children are child Dentries. children map[string]*Dentry + // mu synchronizes disowning and mounting over this Dentry. + mu sync.Mutex + // impl is the DentryImpl associated with this Dentry. impl is immutable. // This should be the last field in Dentry. impl DentryImpl @@ -114,7 +118,7 @@ func (d *Dentry) Impl() DentryImpl { type DentryImpl interface { // IncRef increments the Dentry's reference count. A Dentry with a non-zero // reference count must remain coherent with the state of the filesystem. - IncRef(fs *Filesystem) + IncRef() // TryIncRef increments the Dentry's reference count and returns true. If // the Dentry's reference count is zero, TryIncRef may do nothing and @@ -122,10 +126,10 @@ type DentryImpl interface { // guarantee that the Dentry is coherent with the state of the filesystem.) // // TryIncRef does not require that a reference is held on the Dentry. - TryIncRef(fs *Filesystem) bool + TryIncRef() bool // DecRef decrements the Dentry's reference count. - DecRef(fs *Filesystem) + DecRef() } // IsDisowned returns true if d is disowned. @@ -142,16 +146,20 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } -func (d *Dentry) incRef(fs *Filesystem) { - d.impl.IncRef(fs) +// IncRef increments d's reference count. +func (d *Dentry) IncRef() { + d.impl.IncRef() } -func (d *Dentry) tryIncRef(fs *Filesystem) bool { - return d.impl.TryIncRef(fs) +// TryIncRef increments d's reference count and returns true. If d's reference +// count is zero, TryIncRef may instead do nothing and return false. +func (d *Dentry) TryIncRef() bool { + return d.impl.TryIncRef() } -func (d *Dentry) decRef(fs *Filesystem) { - d.impl.DecRef(fs) +// DecRef decrements d's reference count. +func (d *Dentry) DecRef() { + d.impl.DecRef() } // These functions are exported so that filesystem implementations can use @@ -191,6 +199,18 @@ func (d *Dentry) HasChildren() bool { return len(d.children) != 0 } +// Children returns a map containing all of d's children. +func (d *Dentry) Children() map[string]*Dentry { + if !d.HasChildren() { + return nil + } + m := make(map[string]*Dentry) + for name, child := range d.children { + m[name] = child + } + return m +} + // InsertChild makes child a child of d with the given name. // // InsertChild is a mutator of d and child. @@ -228,36 +248,48 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent panic("d is already disowned") } } - vfs.mountMu.RLock() - if _, ok := mntns.mountpoints[d]; ok { - vfs.mountMu.RUnlock() + vfs.mountMu.Lock() + if mntns.mountpoints[d] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } - // Return with vfs.mountMu locked, which will be unlocked by - // AbortDeleteDentry or CommitDeleteDentry. + d.mu.Lock() + vfs.mountMu.Unlock() + // Return with d.mu locked to block attempts to mount over it; it will be + // unlocked by AbortDeleteDentry or CommitDeleteDentry. return nil } // AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion // fails. -func (vfs *VirtualFilesystem) AbortDeleteDentry() { - vfs.mountMu.RUnlock() +func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) { + d.mu.Unlock() } // CommitDeleteDentry must be called after the file represented by d is // deleted, and causes d to become disowned. // +// CommitDeleteDentry is a mutator of d and d.Parent(). +// // Preconditions: PrepareDeleteDentry was previously called on d. func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { - delete(d.parent.children, d.name) + if d.parent != nil { + delete(d.parent.children, d.name) + } d.setDisowned() - // TODO: lazily unmount mounts at d - vfs.mountMu.RUnlock() + d.mu.Unlock() + if d.isMounted() { + vfs.forgetDisownedMountpoint(d) + } } // DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as // appropriate for in-memory filesystems that don't need to ensure that some // external state change succeeds before committing the deletion. +// +// DeleteDentry is a mutator of d and d.Parent(). +// +// Preconditions: d is a child Dentry. func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error { if err := vfs.PrepareDeleteDentry(mntns, d); err != nil { return err @@ -266,6 +298,27 @@ func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) err return nil } +// ForceDeleteDentry causes d to become disowned. It should only be used in +// cases where VFS has no ability to stop the deletion (e.g. d represents the +// local state of a file on a remote filesystem on which the file has already +// been deleted). +// +// ForceDeleteDentry is a mutator of d and d.Parent(). +// +// Preconditions: d is a child Dentry. +func (vfs *VirtualFilesystem) ForceDeleteDentry(d *Dentry) { + if checkInvariants { + if d.parent == nil { + panic("d is independent") + } + if d.IsDisowned() { + panic("d is already disowned") + } + } + d.mu.Lock() + vfs.CommitDeleteDentry(d) +} + // PrepareRenameDentry must be called before attempting to rename the file // represented by from. If to is not nil, it represents the file that will be // replaced or exchanged by the rename. If PrepareRenameDentry succeeds, the @@ -291,18 +344,21 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t } } } - vfs.mountMu.RLock() - if _, ok := mntns.mountpoints[from]; ok { - vfs.mountMu.RUnlock() + vfs.mountMu.Lock() + if mntns.mountpoints[from] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } if to != nil { - if _, ok := mntns.mountpoints[to]; ok { - vfs.mountMu.RUnlock() + if mntns.mountpoints[to] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } + to.mu.Lock() } - // Return with vfs.mountMu locked, which will be unlocked by + from.mu.Lock() + vfs.mountMu.Unlock() + // Return with from.mu and to.mu locked, which will be unlocked by // AbortRenameDentry, CommitRenameReplaceDentry, or // CommitRenameExchangeDentry. return nil @@ -310,38 +366,76 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t // AbortRenameDentry must be called after PrepareRenameDentry if the rename // fails. -func (vfs *VirtualFilesystem) AbortRenameDentry() { - vfs.mountMu.RUnlock() +func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { + from.mu.Unlock() + if to != nil { + to.mu.Unlock() + } } // CommitRenameReplaceDentry must be called after the file represented by from // is renamed without RENAME_EXCHANGE. If to is not nil, it represents the file // that was replaced by from. // +// CommitRenameReplaceDentry is a mutator of from, to, from.Parent(), and +// to.Parent(). +// // Preconditions: PrepareRenameDentry was previously called on from and to. // newParent.Child(newName) == to. func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, newParent *Dentry, newName string, to *Dentry) { - if to != nil { - to.setDisowned() - // TODO: lazily unmount mounts at d - } if newParent.children == nil { newParent.children = make(map[string]*Dentry) } newParent.children[newName] = from from.parent = newParent from.name = newName - vfs.mountMu.RUnlock() + from.mu.Unlock() + if to != nil { + to.setDisowned() + to.mu.Unlock() + if to.isMounted() { + vfs.forgetDisownedMountpoint(to) + } + } } // CommitRenameExchangeDentry must be called after the files represented by // from and to are exchanged by rename(RENAME_EXCHANGE). // +// CommitRenameExchangeDentry is a mutator of from, to, from.Parent(), and +// to.Parent(). +// // Preconditions: PrepareRenameDentry was previously called on from and to. func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) { from.parent, to.parent = to.parent, from.parent from.name, to.name = to.name, from.name from.parent.children[from.name] = from to.parent.children[to.name] = to - vfs.mountMu.RUnlock() + from.mu.Unlock() + to.mu.Unlock() +} + +// forgetDisownedMountpoint is called when a mount point is deleted to umount +// all mounts using it in all other mount namespaces. +// +// forgetDisownedMountpoint is analogous to Linux's +// fs/namespace.c:__detach_mounts(). +func (vfs *VirtualFilesystem) forgetDisownedMountpoint(d *Dentry) { + var ( + vdsToDecRef []VirtualDentry + mountsToDecRef []*Mount + ) + vfs.mountMu.Lock() + vfs.mounts.seq.BeginWrite() + for mnt := range vfs.mountpoints[d] { + vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(mnt, &umountRecursiveOptions{}, vdsToDecRef, mountsToDecRef) + } + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 3a9665800..df03886c3 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -20,8 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -38,43 +40,64 @@ type FileDescription struct { // operations. refs int64 + // statusFlags contains status flags, "initialized by open(2) and possibly + // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic + // memory operations. + statusFlags uint32 + // vd is the filesystem location at which this FileDescription was opened. // A reference is held on vd. vd is immutable. vd VirtualDentry + opts FileDescriptionOptions + // impl is the FileDescriptionImpl associated with this Filesystem. impl is // immutable. This should be the last field in FileDescription. impl FileDescriptionImpl } -// Init must be called before first use of fd. It takes references on mnt and -// d. -func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) { +// FileDescriptionOptions contains options to FileDescription.Init(). +type FileDescriptionOptions struct { + // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is + // usually only the case if O_DIRECT would actually have an effect. + AllowDirectIO bool +} + +// Init must be called before first use of fd. It takes ownership of references +// on mnt and d held by the caller. statusFlags is the initial file description +// status flags, which is usually the full set of flags passed to open(2). +func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) { fd.refs = 1 + fd.statusFlags = statusFlags | linux.O_LARGEFILE fd.vd = VirtualDentry{ mount: mnt, dentry: d, } - fd.vd.IncRef() + fd.opts = *opts fd.impl = impl } -// Impl returns the FileDescriptionImpl associated with fd. -func (fd *FileDescription) Impl() FileDescriptionImpl { - return fd.impl -} - -// VirtualDentry returns the location at which fd was opened. It does not take -// a reference on the returned VirtualDentry. -func (fd *FileDescription) VirtualDentry() VirtualDentry { - return fd.vd -} - // IncRef increments fd's reference count. func (fd *FileDescription) IncRef() { atomic.AddInt64(&fd.refs, 1) } +// TryIncRef increments fd's reference count and returns true. If fd's +// reference count is already zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fd. +func (fd *FileDescription) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fd.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) { + return true + } + } +} + // DecRef decrements fd's reference count. func (fd *FileDescription) DecRef() { if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 { @@ -85,6 +108,82 @@ func (fd *FileDescription) DecRef() { } } +// Mount returns the mount on which fd was opened. It does not take a reference +// on the returned Mount. +func (fd *FileDescription) Mount() *Mount { + return fd.vd.mount +} + +// Dentry returns the dentry at which fd was opened. It does not take a +// reference on the returned Dentry. +func (fd *FileDescription) Dentry() *Dentry { + return fd.vd.dentry +} + +// VirtualDentry returns the location at which fd was opened. It does not take +// a reference on the returned VirtualDentry. +func (fd *FileDescription) VirtualDentry() VirtualDentry { + return fd.vd +} + +// StatusFlags returns file description status flags, as for fcntl(F_GETFL). +func (fd *FileDescription) StatusFlags() uint32 { + return atomic.LoadUint32(&fd.statusFlags) +} + +// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL). +func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error { + // Compare Linux's fs/fcntl.c:setfl(). + oldFlags := fd.StatusFlags() + // Linux documents this check as "O_APPEND cannot be cleared if the file is + // marked as append-only and the file is open for write", which would make + // sense. However, the check as actually implemented seems to be "O_APPEND + // cannot be changed if the file is marked as append-only". + if (flags^oldFlags)&linux.O_APPEND != 0 { + stat, err := fd.impl.Stat(ctx, StatOptions{ + // There is no mask bit for stx_attributes. + Mask: 0, + // Linux just reads inode::i_flags directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return err + } + if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) { + return syserror.EPERM + } + } + if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) { + stat, err := fd.impl.Stat(ctx, StatOptions{ + Mask: linux.STATX_UID, + // Linux's inode_owner_or_capable() just reads inode::i_uid + // directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return err + } + if stat.Mask&linux.STATX_UID == 0 { + return syserror.EPERM + } + if !CanActAsOwner(creds, auth.KUID(stat.UID)) { + return syserror.EPERM + } + } + if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { + return syserror.EINVAL + } + // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? + const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK + atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) + return nil +} + +// Impl returns the FileDescriptionImpl associated with fd. +func (fd *FileDescription) Impl() FileDescriptionImpl { + return fd.impl +} + // FileDescriptionImpl contains implementation details for an FileDescription. // Implementations of FileDescriptionImpl should contain their associated // FileDescription by value as their first field. @@ -104,14 +203,6 @@ type FileDescriptionImpl interface { // prevent the file descriptor from being closed. OnClose(ctx context.Context) error - // StatusFlags returns file description status flags, as for - // fcntl(F_GETFL). - StatusFlags(ctx context.Context) (uint32, error) - - // SetStatusFlags sets file description status flags, as for - // fcntl(F_SETFL). - SetStatusFlags(ctx context.Context, flags uint32) error - // Stat returns metadata for the file represented by the FileDescription. Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) @@ -185,7 +276,21 @@ type FileDescriptionImpl interface { // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // TODO: extended attributes; file locking + // Listxattr returns all extended attribute names for the file. + Listxattr(ctx context.Context) ([]string, error) + + // Getxattr returns the value associated with the given extended attribute + // for the file. + Getxattr(ctx context.Context, name string) (string, error) + + // Setxattr changes the value associated with the given extended attribute + // for the file. + Setxattr(ctx context.Context, opts SetxattrOptions) error + + // Removexattr removes the given extended attribute from the file. + Removexattr(ctx context.Context, name string) error + + // TODO: file locking } // Dirent holds the information contained in struct linux_dirent64. @@ -214,3 +319,160 @@ type IterDirentsCallback interface { // called. Handle(dirent Dirent) bool } + +// OnClose is called when a file descriptor representing the FileDescription is +// closed. Returning a non-nil error should not prevent the file descriptor +// from being closed. +func (fd *FileDescription) OnClose(ctx context.Context) error { + return fd.impl.OnClose(ctx) +} + +// Stat returns metadata for the file represented by fd. +func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + return fd.impl.Stat(ctx, opts) +} + +// SetStat updates metadata for the file represented by fd. +func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error { + return fd.impl.SetStat(ctx, opts) +} + +// StatFS returns metadata for the filesystem containing the file represented +// by fd. +func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + return fd.impl.StatFS(ctx) +} + +// PRead reads from the file represented by fd into dst, starting at the given +// offset, and returns the number of bytes read. PRead is permitted to return +// partial reads with a nil error. +func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return fd.impl.PRead(ctx, dst, offset, opts) +} + +// Read is similar to PRead, but does not specify an offset. +func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + return fd.impl.Read(ctx, dst, opts) +} + +// PWrite writes src to the file represented by fd, starting at the given +// offset, and returns the number of bytes written. PWrite is permitted to +// return partial writes with a nil error. +func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return fd.impl.PWrite(ctx, src, offset, opts) +} + +// Write is similar to PWrite, but does not specify an offset. +func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return fd.impl.Write(ctx, src, opts) +} + +// IterDirents invokes cb on each entry in the directory represented by fd. If +// IterDirents has been called since the last call to Seek, it continues +// iteration from the end of the last call. +func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error { + return fd.impl.IterDirents(ctx, cb) +} + +// Seek changes fd's offset (assuming one exists) and returns its new value. +func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.impl.Seek(ctx, offset, whence) +} + +// Sync has the semantics of fsync(2). +func (fd *FileDescription) Sync(ctx context.Context) error { + return fd.impl.Sync(ctx) +} + +// ConfigureMMap mutates opts to implement mmap(2) for the file represented by +// fd. +func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.impl.ConfigureMMap(ctx, opts) +} + +// Ioctl implements the ioctl(2) syscall. +func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.impl.Ioctl(ctx, uio, args) +} + +// Listxattr returns all extended attribute names for the file represented by +// fd. +func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { + names, err := fd.impl.Listxattr(ctx) + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by default + // don't exist. + return nil, nil + } + return names, err +} + +// Getxattr returns the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) { + return fd.impl.Getxattr(ctx, name) +} + +// Setxattr changes the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return fd.impl.Setxattr(ctx, opts) +} + +// Removexattr removes the given extended attribute from the file represented +// by fd. +func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { + return fd.impl.Removexattr(ctx, name) +} + +// SyncFS instructs the filesystem containing fd to execute the semantics of +// syncfs(2). +func (fd *FileDescription) SyncFS(ctx context.Context) error { + return fd.vd.mount.fs.impl.Sync(ctx) +} + +// MappedName implements memmap.MappingIdentity.MappedName. +func (fd *FileDescription) MappedName(ctx context.Context) string { + vfsroot := RootFromContext(ctx) + s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd) + if vfsroot.Ok() { + vfsroot.DecRef() + } + return s +} + +// DeviceID implements memmap.MappingIdentity.DeviceID. +func (fd *FileDescription) DeviceID() uint64 { + stat, err := fd.impl.Stat(context.Background(), StatOptions{ + // There is no STATX_DEV; we assume that Stat will return it if it's + // available regardless of mask. + Mask: 0, + // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_sb->s_dev + // directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return 0 + } + return uint64(linux.MakeDeviceID(uint16(stat.DevMajor), stat.DevMinor)) +} + +// InodeID implements memmap.MappingIdentity.InodeID. +func (fd *FileDescription) InodeID() uint64 { + stat, err := fd.impl.Stat(context.Background(), StatOptions{ + Mask: linux.STATX_INO, + // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil || stat.Mask&linux.STATX_INO == 0 { + return 0 + } + return stat.Ino +} + +// Msync implements memmap.MappingIdentity.Msync. +func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error { + return fd.impl.Sync(ctx) +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 4fbad7840..3df49991c 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -127,6 +127,31 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } +// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// inode_operations::listxattr == NULL in Linux. +func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) { + // This isn't exactly accurate; see FileDescription.Listxattr. + return nil, syserror.ENOTSUP +} + +// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) { + return "", syserror.ENOTSUP +} + +// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return syserror.ENOTSUP +} + +// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { + return syserror.ENOTSUP +} + // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. @@ -252,3 +277,12 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6 fd.off = offset return offset, nil } + +// GenericConfigureMMap may be used by most implementations of +// FileDescriptionImpl.ConfigureMMap. +func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error { + opts.Mappable = m + opts.MappingIdentity = fd + fd.IncRef() + return nil +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 511b829fc..678be07fe 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -48,7 +48,7 @@ type genCountFD struct { func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription { var fd genCountFD - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, 0 /* statusFlags */, mnt, vfsd, &FileDescriptionOptions{}) fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd) return &fd.vfsfd } @@ -90,7 +90,7 @@ func TestGenCountFD(t *testing.T) { vfsObj := New() // vfs.New() vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{}) if err != nil { t.Fatalf("failed to create testfs root mount: %v", err) } @@ -103,7 +103,7 @@ func TestGenCountFD(t *testing.T) { // The first read causes Generate to be called to fill the FD's buffer. buf := make([]byte, 2) ioseq := usermem.BytesIOSequence(buf) - n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err := fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -112,17 +112,17 @@ func TestGenCountFD(t *testing.T) { } // A second read without seeking is still at EOF. - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 0 || err != io.EOF { t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) } // Seeking to the beginning of the file causes it to be regenerated. - n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET) + n, err = fd.Seek(ctx, 0, linux.SEEK_SET) if n != 0 || err != nil { t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) } - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -131,7 +131,7 @@ func TestGenCountFD(t *testing.T) { } // PRead at the beginning of the file also causes it to be regenerated. - n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{}) + n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 7a074b718..b766614e7 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -33,15 +34,28 @@ type Filesystem struct { // operations. refs int64 + // vfs is the VirtualFilesystem that uses this Filesystem. vfs is + // immutable. + vfs *VirtualFilesystem + // impl is the FilesystemImpl associated with this Filesystem. impl is // immutable. This should be the last field in Dentry. impl FilesystemImpl } // Init must be called before first use of fs. -func (fs *Filesystem) Init(impl FilesystemImpl) { +func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) { fs.refs = 1 + fs.vfs = vfsObj fs.impl = impl + vfsObj.filesystemsMu.Lock() + vfsObj.filesystems[fs] = struct{}{} + vfsObj.filesystemsMu.Unlock() +} + +// VirtualFilesystem returns the containing VirtualFilesystem. +func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem { + return fs.vfs } // Impl returns the FilesystemImpl associated with fs. @@ -49,14 +63,35 @@ func (fs *Filesystem) Impl() FilesystemImpl { return fs.impl } -func (fs *Filesystem) incRef() { +// IncRef increments fs' reference count. +func (fs *Filesystem) IncRef() { if atomic.AddInt64(&fs.refs, 1) <= 1 { - panic("Filesystem.incRef() called without holding a reference") + panic("Filesystem.IncRef() called without holding a reference") + } +} + +// TryIncRef increments fs' reference count and returns true. If fs' reference +// count is zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fs. +func (fs *Filesystem) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fs.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) { + return true + } } } -func (fs *Filesystem) decRef() { +// DecRef decrements fs' reference count. +func (fs *Filesystem) DecRef() { if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { + fs.vfs.filesystemsMu.Lock() + delete(fs.vfs.filesystems, fs) + fs.vfs.filesystemsMu.Unlock() fs.impl.Release() } else if refs < 0 { panic("Filesystem.decRef() called without holding a reference") @@ -151,5 +186,70 @@ type FilesystemImpl interface { // UnlinkAt removes the non-directory file at rp. UnlinkAt(ctx context.Context, rp *ResolvingPath) error - // TODO: d_path(); extended attributes; inotify_add_watch(); bind() + // ListxattrAt returns all extended attribute names for the file at rp. + ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) + + // GetxattrAt returns the value associated with the given extended + // attribute for the file at rp. + GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) + + // SetxattrAt changes the value associated with the given extended + // attribute for the file at rp. + SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + + // RemovexattrAt removes the given extended attribute from the file at rp. + RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + + // PrependPath prepends a path from vd to vd.Mount().Root() to b. + // + // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered + // before vd.Mount().Root(), PrependPath should stop prepending path + // components and return a PrependPathAtVFSRootError. + // + // If traversal of vd.Dentry()'s ancestors encounters an independent + // ("root") Dentry that is not vd.Mount().Root() (i.e. vd.Dentry() is not a + // descendant of vd.Mount().Root()), PrependPath should stop prepending + // path components and return a PrependPathAtNonMountRootError. + // + // Filesystems for which Dentries do not have meaningful paths may prepend + // an arbitrary descriptive string to b and then return a + // PrependPathSyntheticError. + // + // Most implementations can acquire the appropriate locks to ensure that + // Dentry.Name() and Dentry.Parent() are fixed for vd.Dentry() and all of + // its ancestors, then call GenericPrependPath. + // + // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. + PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error + + // TODO: inotify_add_watch(); bind() +} + +// PrependPathAtVFSRootError is returned by implementations of +// FilesystemImpl.PrependPath() when they encounter the contextual VFS root. +type PrependPathAtVFSRootError struct{} + +// Error implements error.Error. +func (PrependPathAtVFSRootError) Error() string { + return "vfs.FilesystemImpl.PrependPath() reached VFS root" +} + +// PrependPathAtNonMountRootError is returned by implementations of +// FilesystemImpl.PrependPath() when they encounter an independent ancestor +// Dentry that is not the Mount root. +type PrependPathAtNonMountRootError struct{} + +// Error implements error.Error. +func (PrependPathAtNonMountRootError) Error() string { + return "vfs.FilesystemImpl.PrependPath() reached root other than Mount root" +} + +// PrependPathSyntheticError is returned by implementations of +// FilesystemImpl.PrependPath() for which prepended names do not represent real +// paths. +type PrependPathSyntheticError struct{} + +// Error implements error.Error. +func (PrependPathSyntheticError) Error() string { + return "vfs.FilesystemImpl.PrependPath() prepended synthetic name" } diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go new file mode 100644 index 000000000..7315a588e --- /dev/null +++ b/pkg/sentry/vfs/filesystem_impl_util.go @@ -0,0 +1,69 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "strings" + + "gvisor.dev/gvisor/pkg/fspath" +) + +// GenericParseMountOptions parses a comma-separated list of options of the +// form "key" or "key=value", where neither key nor value contain commas, and +// returns it as a map. If str contains duplicate keys, then the last value +// wins. For example: +// +// str = "key0=value0,key1,key2=value2,key0=value3" -> map{'key0':'value3','key1':'','key2':'value2'} +// +// GenericParseMountOptions is not appropriate if values may contain commas, +// e.g. in the case of the mpol mount option for tmpfs(5). +func GenericParseMountOptions(str string) map[string]string { + m := make(map[string]string) + for _, opt := range strings.Split(str, ",") { + if len(opt) > 0 { + res := strings.SplitN(opt, "=", 2) + if len(res) == 2 { + m[res[0]] = res[1] + } else { + m[opt] = "" + } + } + } + return m +} + +// GenericPrependPath may be used by implementations of +// FilesystemImpl.PrependPath() for which a single statically-determined lock +// or set of locks is sufficient to ensure its preconditions (as opposed to +// e.g. per-Dentry locks). +// +// Preconditions: Dentry.Name() and Dentry.Parent() must be held constant for +// vd.Dentry() and all of its ancestors. +func GenericPrependPath(vfsroot, vd VirtualDentry, b *fspath.Builder) error { + mnt, d := vd.mount, vd.dentry + for { + if mnt == vfsroot.mount && d == vfsroot.dentry { + return PrependPathAtVFSRootError{} + } + if d == mnt.root { + return nil + } + if d.parent == nil { + return PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index f401ad7f3..c335e206d 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -25,21 +25,21 @@ import ( // // FilesystemType is analogous to Linux's struct file_system_type. type FilesystemType interface { - // NewFilesystem returns a Filesystem configured by the given options, + // GetFilesystem returns a Filesystem configured by the given options, // along with its mount root. A reference is taken on the returned // Filesystem and Dentry. - NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) + GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) } -// NewFilesystemOptions contains options to FilesystemType.NewFilesystem. -type NewFilesystemOptions struct { +// GetFilesystemOptions contains options to FilesystemType.GetFilesystem. +type GetFilesystemOptions struct { // Data is the string passed as the 5th argument to mount(2), which is // usually a comma-separated list of filesystem-specific mount options. Data string // InternalData holds opaque FilesystemType-specific data. There is // intentionally no way for applications to specify InternalData; if it is - // not nil, the call to NewFilesystem originates from within the sentry. + // not nil, the call to GetFilesystem originates from within the sentry. InternalData interface{} } diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 11702f720..ec23ab0dd 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -18,6 +18,7 @@ import ( "math" "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -38,16 +39,12 @@ import ( // Mount is analogous to Linux's struct mount. (gVisor does not distinguish // between struct mount and struct vfsmount.) type Mount struct { - // The lower 63 bits of refs are a reference count. The MSB of refs is set - // if the Mount has been eagerly unmounted, as by umount(2) without the - // MNT_DETACH flag. refs is accessed using atomic memory operations. - refs int64 - - // The lower 63 bits of writers is the number of calls to - // Mount.CheckBeginWrite() that have not yet been paired with a call to - // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. - // writers is accessed using atomic memory operations. - writers int64 + // vfs, fs, and root are immutable. References are held on fs and root. + // + // Invariant: root belongs to fs. + vfs *VirtualFilesystem + fs *Filesystem + root *Dentry // key is protected by VirtualFilesystem.mountMu and // VirtualFilesystem.mounts.seq, and may be nil. References are held on @@ -57,13 +54,29 @@ type Mount struct { // key.parent.fs. key mountKey - // fs, root, and ns are immutable. References are held on fs and root (but - // not ns). - // - // Invariant: root belongs to fs. - fs *Filesystem - root *Dentry - ns *MountNamespace + // ns is the namespace in which this Mount was mounted. ns is protected by + // VirtualFilesystem.mountMu. + ns *MountNamespace + + // The lower 63 bits of refs are a reference count. The MSB of refs is set + // if the Mount has been eagerly umounted, as by umount(2) without the + // MNT_DETACH flag. refs is accessed using atomic memory operations. + refs int64 + + // children is the set of all Mounts for which Mount.key.parent is this + // Mount. children is protected by VirtualFilesystem.mountMu. + children map[*Mount]struct{} + + // umounted is true if VFS.umountRecursiveLocked() has been called on this + // Mount. VirtualFilesystem does not hold a reference on Mounts for which + // umounted is true. umounted is protected by VirtualFilesystem.mountMu. + umounted bool + + // The lower 63 bits of writers is the number of calls to + // Mount.CheckBeginWrite() that have not yet been paired with a call to + // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. + // writers is accessed using atomic memory operations. + writers int64 } // A MountNamespace is a collection of Mounts. @@ -73,13 +86,16 @@ type Mount struct { // // MountNamespace is analogous to Linux's struct mnt_namespace. type MountNamespace struct { - refs int64 // accessed using atomic memory operations - // root is the MountNamespace's root mount. root is immutable. root *Mount - // mountpoints contains all Dentries which are mount points in this - // namespace. mountpoints is protected by VirtualFilesystem.mountMu. + // refs is the reference count. refs is accessed using atomic memory + // operations. + refs int64 + + // mountpoints maps all Dentries which are mount points in this namespace + // to the number of Mounts for which they are mount points. mountpoints is + // protected by VirtualFilesystem.mountMu. // // mountpoints is used to determine if a Dentry can be moved or removed // (which requires that the Dentry is not a mount point in the calling @@ -89,26 +105,27 @@ type MountNamespace struct { // MountNamespace; this is required to ensure that // VFS.PrepareDeleteDentry() and VFS.PrepareRemoveDentry() operate // correctly on unreferenced MountNamespaces. - mountpoints map[*Dentry]struct{} + mountpoints map[*Dentry]uint32 } // NewMountNamespace returns a new mount namespace with a root filesystem // configured by the given arguments. A reference is taken on the returned // MountNamespace. -func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *NewFilesystemOptions) (*MountNamespace, error) { +func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return nil, syserror.ENODEV } - fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts) if err != nil { return nil, err } mntns := &MountNamespace{ refs: 1, - mountpoints: make(map[*Dentry]struct{}), + mountpoints: make(map[*Dentry]uint32), } mntns.root = &Mount{ + vfs: vfs, fs: fs, root: root, ns: mntns, @@ -117,13 +134,13 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return mntns, nil } -// NewMount creates and mounts a new Filesystem. -func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *NewFilesystemOptions) error { +// MountAt creates and mounts a Filesystem configured by the given arguments. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return syserror.ENODEV } - fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return err } @@ -131,17 +148,19 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti // lock ordering. vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{}) if err != nil { - root.decRef(fs) - fs.decRef() + root.DecRef() + fs.DecRef() return err } vfs.mountMu.Lock() + vd.dentry.mu.Lock() for { if vd.dentry.IsDisowned() { + vd.dentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef() - root.decRef(fs) - fs.decRef() + root.DecRef() + fs.DecRef() return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -153,36 +172,272 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti if nextmnt == nil { break } - nextmnt.incRef() - nextmnt.root.incRef(nextmnt.fs) + // It's possible that nextmnt has been umounted but not disconnected, + // in which case vfs no longer holds a reference on it, and the last + // reference may be concurrently dropped even though we're holding + // vfs.mountMu. + if !nextmnt.tryIncMountedRef() { + break + } + // This can't fail since we're holding vfs.mountMu. + nextmnt.root.IncRef() + vd.dentry.mu.Unlock() vd.DecRef() vd = VirtualDentry{ mount: nextmnt, dentry: nextmnt.root, } + vd.dentry.mu.Lock() } // TODO: Linux requires that either both the mount point and the mount root // are directories, or neither are, and returns ENOTDIR if this is not the // case. mntns := vd.mount.ns mnt := &Mount{ + vfs: vfs, fs: fs, root: root, ns: mntns, refs: 1, } - mnt.storeKey(vd.mount, vd.dentry) + vfs.mounts.seq.BeginWrite() + vfs.connectLocked(mnt, vd, mntns) + vfs.mounts.seq.EndWrite() + vd.dentry.mu.Unlock() + vfs.mountMu.Unlock() + return nil +} + +// UmountAt removes the Mount at the given path. +func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { + if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { + return syserror.EINVAL + } + + // MNT_FORCE is currently unimplemented except for the permission check. + if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) { + return syserror.EPERM + } + + vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{}) + if err != nil { + return err + } + defer vd.DecRef() + if vd.dentry != vd.mount.root { + return syserror.EINVAL + } + vfs.mountMu.Lock() + if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns { + vfs.mountMu.Unlock() + return syserror.EINVAL + } + + // TODO(jamieliu): Linux special-cases umount of the caller's root, which + // we don't implement yet (we'll just fail it since the caller holds a + // reference on it). + + vfs.mounts.seq.BeginWrite() + if opts.Flags&linux.MNT_DETACH == 0 { + if len(vd.mount.children) != 0 { + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + // We are holding a reference on vd.mount. + expectedRefs := int64(1) + if !vd.mount.umounted { + expectedRefs = 2 + } + if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + } + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{ + eager: opts.Flags&linux.MNT_DETACH == 0, + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + return nil +} + +type umountRecursiveOptions struct { + // If eager is true, ensure that future calls to Mount.tryIncMountedRef() + // on umounted mounts fail. + // + // eager is analogous to Linux's UMOUNT_SYNC. + eager bool + + // If disconnectHierarchy is true, Mounts that are umounted hierarchically + // should be disconnected from their parents. (Mounts whose parents are not + // umounted, which in most cases means the Mount passed to the initial call + // to umountRecursiveLocked, are unconditionally disconnected for + // consistency with Linux.) + // + // disconnectHierarchy is analogous to Linux's !UMOUNT_CONNECTED. + disconnectHierarchy bool +} + +// umountRecursiveLocked marks mnt and its descendants as umounted. It does not +// release mount or dentry references; instead, it appends VirtualDentries and +// Mounts on which references must be dropped to vdsToDecRef and mountsToDecRef +// respectively, and returns updated slices. (This is necessary because +// filesystem locks possibly taken by DentryImpl.DecRef() may precede +// vfs.mountMu in the lock order, and Mount.DecRef() may lock vfs.mountMu.) +// +// umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree(). +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. +func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) { + if !mnt.umounted { + mnt.umounted = true + mountsToDecRef = append(mountsToDecRef, mnt) + if parent := mnt.parent(); parent != nil && (opts.disconnectHierarchy || !parent.umounted) { + vdsToDecRef = append(vdsToDecRef, vfs.disconnectLocked(mnt)) + } + } + if opts.eager { + for { + refs := atomic.LoadInt64(&mnt.refs) + if refs < 0 { + break + } + if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs|math.MinInt64) { + break + } + } + } + for child := range mnt.children { + vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(child, opts, vdsToDecRef, mountsToDecRef) + } + return vdsToDecRef, mountsToDecRef +} + +// connectLocked makes vd the mount parent/point for mnt. It consumes +// references held by vd. +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. d.mu must be locked. mnt.parent() == nil. +func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { + mnt.storeKey(vd) + if vd.mount.children == nil { + vd.mount.children = make(map[*Mount]struct{}) + } + vd.mount.children[mnt] = struct{}{} atomic.AddUint32(&vd.dentry.mounts, 1) - mntns.mountpoints[vd.dentry] = struct{}{} + mntns.mountpoints[vd.dentry]++ + vfs.mounts.insertSeqed(mnt) vfsmpmounts, ok := vfs.mountpoints[vd.dentry] if !ok { vfsmpmounts = make(map[*Mount]struct{}) vfs.mountpoints[vd.dentry] = vfsmpmounts } vfsmpmounts[mnt] = struct{}{} - vfs.mounts.Insert(mnt) - vfs.mountMu.Unlock() - return nil +} + +// disconnectLocked makes vd have no mount parent/point and returns its old +// mount parent/point with a reference held. +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. mnt.parent() != nil. +func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { + vd := mnt.loadKey() + mnt.storeKey(VirtualDentry{}) + delete(vd.mount.children, mnt) + atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 + mnt.ns.mountpoints[vd.dentry]-- + if mnt.ns.mountpoints[vd.dentry] == 0 { + delete(mnt.ns.mountpoints, vd.dentry) + } + vfs.mounts.removeSeqed(mnt) + vfsmpmounts := vfs.mountpoints[vd.dentry] + delete(vfsmpmounts, mnt) + if len(vfsmpmounts) == 0 { + delete(vfs.mountpoints, vd.dentry) + } + return vd +} + +// tryIncMountedRef increments mnt's reference count and returns true. If mnt's +// reference count is already zero, or has been eagerly umounted, +// tryIncMountedRef does nothing and returns false. +// +// tryIncMountedRef does not require that a reference is held on mnt. +func (mnt *Mount) tryIncMountedRef() bool { + for { + refs := atomic.LoadInt64(&mnt.refs) + if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted + return false + } + if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { + return true + } + } +} + +// IncRef increments mnt's reference count. +func (mnt *Mount) IncRef() { + // In general, negative values for mnt.refs are valid because the MSB is + // the eager-unmount bit. + atomic.AddInt64(&mnt.refs, 1) +} + +// DecRef decrements mnt's reference count. +func (mnt *Mount) DecRef() { + refs := atomic.AddInt64(&mnt.refs, -1) + if refs&^math.MinInt64 == 0 { // mask out MSB + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + mnt.root.DecRef() + mnt.fs.DecRef() + if vd.Ok() { + vd.DecRef() + } + } +} + +// IncRef increments mntns' reference count. +func (mntns *MountNamespace) IncRef() { + if atomic.AddInt64(&mntns.refs, 1) <= 1 { + panic("MountNamespace.IncRef() called without holding a reference") + } +} + +// DecRef decrements mntns' reference count. +func (mntns *MountNamespace) DecRef(vfs *VirtualFilesystem) { + if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 { + vfs.mountMu.Lock() + vfs.mounts.seq.BeginWrite() + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{ + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + } else if refs < 0 { + panic("MountNamespace.DecRef() called without holding a reference") + } } // getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes @@ -223,7 +478,7 @@ retryFirst: // Raced with umount. continue } - mnt.decRef() + mnt.DecRef() mnt = next d = next.root } @@ -231,12 +486,12 @@ retryFirst: } // getMountpointAt returns the mount point for the stack of Mounts including -// mnt. It takes a reference on the returned Mount and Dentry. If no such mount +// mnt. It takes a reference on the returned VirtualDentry. If no such mount // point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil). // // Preconditions: References are held on mnt and root. vfsroot is not (mnt, // mnt.root). -func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) (*Mount, *Dentry) { +func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry { // The first mount is special-cased: // // - The caller must have already checked mnt against vfsroot. @@ -246,21 +501,26 @@ func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) // - We don't drop the caller's reference on mnt. retryFirst: epoch := vfs.mounts.seq.BeginRead() - parent, point := mnt.loadKey() + parent, point := mnt.parent(), mnt.point() if !vfs.mounts.seq.ReadOk(epoch) { goto retryFirst } if parent == nil { - return nil, nil + return VirtualDentry{} } if !parent.tryIncMountedRef() { // Raced with umount. goto retryFirst } - if !point.tryIncRef(parent.fs) { + if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can only // happen due to a racing change to Mount.key. - parent.decRef() + parent.DecRef() + goto retryFirst + } + if !vfs.mounts.seq.ReadOk(epoch) { + point.DecRef() + parent.DecRef() goto retryFirst } mnt = parent @@ -274,7 +534,7 @@ retryFirst: } retryNotFirst: epoch := vfs.mounts.seq.BeginRead() - parent, point := mnt.loadKey() + parent, point := mnt.parent(), mnt.point() if !vfs.mounts.seq.ReadOk(epoch) { goto retryNotFirst } @@ -285,59 +545,23 @@ retryFirst: // Raced with umount. goto retryNotFirst } - if !point.tryIncRef(parent.fs) { + if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can // only happen due to a racing change to Mount.key. - parent.decRef() + parent.DecRef() goto retryNotFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.decRef(parent.fs) - parent.decRef() + point.DecRef() + parent.DecRef() goto retryNotFirst } - d.decRef(mnt.fs) - mnt.decRef() + d.DecRef() + mnt.DecRef() mnt = parent d = point } - return mnt, d -} - -// tryIncMountedRef increments mnt's reference count and returns true. If mnt's -// reference count is already zero, or has been eagerly unmounted, -// tryIncMountedRef does nothing and returns false. -// -// tryIncMountedRef does not require that a reference is held on mnt. -func (mnt *Mount) tryIncMountedRef() bool { - for { - refs := atomic.LoadInt64(&mnt.refs) - if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted - return false - } - if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { - return true - } - } -} - -func (mnt *Mount) incRef() { - // In general, negative values for mnt.refs are valid because the MSB is - // the eager-unmount bit. - atomic.AddInt64(&mnt.refs, 1) -} - -func (mnt *Mount) decRef() { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - parent, point := mnt.loadKey() - if point != nil { - point.decRef(parent.fs) - parent.decRef() - } - mnt.root.decRef(mnt.fs) - mnt.fs.decRef() - } + return VirtualDentry{mnt, d} } // CheckBeginWrite increments the counter of in-progress write operations on @@ -360,7 +584,7 @@ func (mnt *Mount) EndWrite() { atomic.AddInt64(&mnt.writers, -1) } -// Preconditions: VirtualFilesystem.mountMu must be locked for writing. +// Preconditions: VirtualFilesystem.mountMu must be locked. func (mnt *Mount) setReadOnlyLocked(ro bool) error { if oldRO := atomic.LoadInt64(&mnt.writers) < 0; oldRO == ro { return nil @@ -383,22 +607,6 @@ func (mnt *Mount) Filesystem() *Filesystem { return mnt.fs } -// IncRef increments mntns' reference count. -func (mntns *MountNamespace) IncRef() { - if atomic.AddInt64(&mntns.refs, 1) <= 1 { - panic("MountNamespace.IncRef() called without holding a reference") - } -} - -// DecRef decrements mntns' reference count. -func (mntns *MountNamespace) DecRef() { - if refs := atomic.AddInt64(&mntns.refs, 0); refs == 0 { - // TODO: unmount mntns.root - } else if refs < 0 { - panic("MountNamespace.DecRef() called without holding a reference") - } -} - // Root returns mntns' root. A reference is taken on the returned // VirtualDentry. func (mntns *MountNamespace) Root() VirtualDentry { diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index f394d7483..adff0b94b 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -37,7 +37,7 @@ func TestMountTableInsertLookup(t *testing.T) { mt.Init() mount := &Mount{} - mount.storeKey(&Mount{}, &Dentry{}) + mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) mt.Insert(mount) if m := mt.Lookup(mount.parent(), mount.point()); m != mount { @@ -78,18 +78,10 @@ const enableComparativeBenchmarks = false func newBenchMount() *Mount { mount := &Mount{} - mount.storeKey(&Mount{}, &Dentry{}) + mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) return mount } -func vdkey(mnt *Mount) VirtualDentry { - parent, point := mnt.loadKey() - return VirtualDentry{ - mount: parent, - dentry: point, - } -} - func BenchmarkMountTableParallelLookup(b *testing.B) { for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { for _, numMounts := range benchNumMounts { @@ -101,7 +93,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { for i := 0; i < numMounts; i++ { mount := newBenchMount() mt.Insert(mount) - keys = append(keys, vdkey(mount)) + keys = append(keys, mount.loadKey()) } var ready sync.WaitGroup @@ -153,7 +145,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := vdkey(mount) + key := mount.loadKey() ms[key] = mount keys = append(keys, key) } @@ -208,7 +200,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := vdkey(mount) + key := mount.loadKey() ms.Store(key, mount) keys = append(keys, key) } @@ -290,7 +282,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -325,7 +317,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { var ms sync.Map for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -379,7 +371,7 @@ func BenchmarkMountMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } } @@ -399,7 +391,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } } @@ -432,13 +424,13 @@ func BenchmarkMountMapRemove(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := range mounts { mount := mounts[i] - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } b.ResetTimer() for i := range mounts { mount := mounts[i] - delete(ms, vdkey(mount)) + delete(ms, mount.loadKey()) } } @@ -454,12 +446,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) { var ms sync.Map for i := range mounts { mount := mounts[i] - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Delete(vdkey(mount)) + ms.Delete(mount.loadKey()) } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b0511aa40..ab13fa461 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // mountKey represents the location at which a Mount is mounted. It is @@ -38,16 +38,6 @@ type mountKey struct { point unsafe.Pointer // *Dentry } -// Invariant: mnt.key's fields are nil. parent and point are non-nil. -func (mnt *Mount) storeKey(parent *Mount, point *Dentry) { - atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(parent)) - atomic.StorePointer(&mnt.key.point, unsafe.Pointer(point)) -} - -func (mnt *Mount) loadKey() (*Mount, *Dentry) { - return (*Mount)(atomic.LoadPointer(&mnt.key.parent)), (*Dentry)(atomic.LoadPointer(&mnt.key.point)) -} - func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,6 +46,19 @@ func (mnt *Mount) point() *Dentry { return (*Dentry)(atomic.LoadPointer(&mnt.key.point)) } +func (mnt *Mount) loadKey() VirtualDentry { + return VirtualDentry{ + mount: mnt.parent(), + dentry: mnt.point(), + } +} + +// Invariant: mnt.key.parent == nil. vd.Ok(). +func (mnt *Mount) storeKey(vd VirtualDentry) { + atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) + atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) +} + // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). @@ -72,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq gvsync.SeqCount + seq syncutil.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of @@ -201,9 +204,19 @@ loop: // Insert inserts the given mount into mt. // -// Preconditions: There are no concurrent mutators of mt. mt must not already -// contain a Mount with the same mount point and parent. +// Preconditions: mt must not already contain a Mount with the same mount point +// and parent. func (mt *mountTable) Insert(mount *Mount) { + mt.seq.BeginWrite() + mt.insertSeqed(mount) + mt.seq.EndWrite() +} + +// insertSeqed inserts the given mount into mt. +// +// Preconditions: mt.seq must be in a writer critical section. mt must not +// already contain a Mount with the same mount point and parent. +func (mt *mountTable) insertSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) // We're under the maximum load factor if: @@ -215,10 +228,8 @@ func (mt *mountTable) Insert(mount *Mount) { tcap := uintptr(1) << order if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) { // Atomically insert the new element into the table. - mt.seq.BeginWrite() atomic.AddUint64(&mt.size, mtSizeLenOne) mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash) - mt.seq.EndWrite() return } @@ -241,8 +252,6 @@ func (mt *mountTable) Insert(mount *Mount) { for { oldSlot := (*mountSlot)(oldCur) if oldSlot.value != nil { - // Don't need to lock mt.seq yet since newSlots isn't visible - // to readers. mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash) } if oldCur == oldLast { @@ -252,11 +261,9 @@ func (mt *mountTable) Insert(mount *Mount) { } // Insert the new element into the new table. mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash) - // Atomically switch to the new table. - mt.seq.BeginWrite() + // Switch to the new table. atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne) atomic.StorePointer(&mt.slots, newSlots) - mt.seq.EndWrite() } // Preconditions: There are no concurrent mutators of the table (slots, cap). @@ -294,9 +301,18 @@ func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, has // Remove removes the given mount from mt. // -// Preconditions: There are no concurrent mutators of mt. mt must contain -// mount. +// Preconditions: mt must contain mount. func (mt *mountTable) Remove(mount *Mount) { + mt.seq.BeginWrite() + mt.removeSeqed(mount) + mt.seq.EndWrite() +} + +// removeSeqed removes the given mount from mt. +// +// Preconditions: mt.seq must be in a writer critical section. mt must contain +// mount. +func (mt *mountTable) removeSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 @@ -311,7 +327,6 @@ func (mt *mountTable) Remove(mount *Mount) { // backward until we either find an empty slot, or an element that // is already in its first-probed slot. (This is backward shift // deletion.) - mt.seq.BeginWrite() for { nextOff := (off + mountSlotBytes) & offmask nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff)) @@ -330,7 +345,6 @@ func (mt *mountTable) Remove(mount *Mount) { } atomic.StorePointer(&slot.value, nil) atomic.AddUint64(&mt.size, mtSizeLenNegOne) - mt.seq.EndWrite() return } if checkInvariants && slotValue == nil { diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3aa73d911..97ee4a446 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -46,6 +46,12 @@ type MknodOptions struct { DevMinor uint32 } +// MountOptions contains options to VirtualFilesystem.MountAt(). +type MountOptions struct { + // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). + GetFilesystemOptions GetFilesystemOptions +} + // OpenOptions contains options to VirtualFilesystem.OpenAt() and // FilesystemImpl.OpenAt(). type OpenOptions struct { @@ -95,6 +101,20 @@ type SetStatOptions struct { Stat linux.Statx } +// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), +// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and +// FileDescriptionImpl.Setxattr(). +type SetxattrOptions struct { + // Name is the name of the extended attribute being mutated. + Name string + + // Value is the extended attribute's new value. + Value string + + // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2). + Flags uint32 +} + // StatOptions contains options to VirtualFilesystem.StatAt(), // FilesystemImpl.StatAt(), FileDescription.Stat(), and // FileDescriptionImpl.Stat(). @@ -114,6 +134,12 @@ type StatOptions struct { Sync uint32 } +// UmountOptions contains options to VirtualFilesystem.UmountAt(). +type UmountOptions struct { + // Flags contains flags as specified for umount2(2). + Flags uint32 +} + // WriteOptions contains options to FileDescription.PWrite(), // FileDescriptionImpl.PWrite(), FileDescription.Write(), and // FileDescriptionImpl.Write(). diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go new file mode 100644 index 000000000..8e155654f --- /dev/null +++ b/pkg/sentry/vfs/pathname.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/syserror" +) + +var fspathBuilderPool = sync.Pool{ + New: func() interface{} { + return &fspath.Builder{} + }, +} + +func getFSPathBuilder() *fspath.Builder { + return fspathBuilderPool.Get().(*fspath.Builder) +} + +func putFSPathBuilder(b *fspath.Builder) { + // No methods can be called on b after b.String(), so reset it to its zero + // value (as returned by fspathBuilderPool.New) instead. + *b = fspath.Builder{} + fspathBuilderPool.Put(b) +} + +// PathnameWithDeleted returns an absolute pathname to vd, consistent with +// Linux's d_path(). In particular, if vd.Dentry() has been disowned, +// PathnameWithDeleted appends " (deleted)" to the returned pathname. +func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { + b := getFSPathBuilder() + defer putFSPathBuilder(b) + haveRef := false + defer func() { + if haveRef { + vd.DecRef() + } + }() + + origD := vd.dentry +loop: + for { + err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b) + switch err.(type) { + case nil: + if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { + // GenericPrependPath() will have returned + // PrependPathAtVFSRootError in this case since it checks + // against vfsroot before mnt.root, but other implementations + // of FilesystemImpl.PrependPath() may return nil instead. + break loop + } + nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + if !nextVD.Ok() { + break loop + } + if haveRef { + vd.DecRef() + } + vd = nextVD + haveRef = true + // continue loop + case PrependPathSyntheticError: + // Skip prepending "/" and appending " (deleted)". + return b.String(), nil + case PrependPathAtVFSRootError, PrependPathAtNonMountRootError: + break loop + default: + return "", err + } + } + b.PrependByte('/') + if origD.IsDisowned() { + b.AppendString(" (deleted)") + } + return b.String(), nil +} + +// PathnameForGetcwd returns an absolute pathname to vd, consistent with +// Linux's sys_getcwd(). +func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { + if vd.dentry.IsDisowned() { + return "", syserror.ENOENT + } + + b := getFSPathBuilder() + defer putFSPathBuilder(b) + haveRef := false + defer func() { + if haveRef { + vd.DecRef() + } + }() + unreachable := false +loop: + for { + err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b) + switch err.(type) { + case nil: + if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { + break loop + } + nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + if !nextVD.Ok() { + unreachable = true + break loop + } + if haveRef { + vd.DecRef() + } + vd = nextVD + haveRef = true + case PrependPathAtVFSRootError: + break loop + case PrependPathAtNonMountRootError, PrependPathSyntheticError: + unreachable = true + break loop + default: + return "", err + } + } + b.PrependByte('/') + if unreachable { + b.PrependString("(unreachable)") + } + return b.String(), nil +} + +// As of this writing, we do not have equivalents to: +// +// - d_absolute_path(), which returns EINVAL if (effectively) any call to +// FilesystemImpl.PrependPath() would return PrependPathAtNonMountRootError. +// +// - dentry_path(), which does not walk up mounts (and only returns the path +// relative to Filesystem root), but also appends "//deleted" for disowned +// Dentries. +// +// These should be added as necessary. diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index f8e74355c..f1edb0680 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -119,3 +119,65 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { return false } } + +// CheckSetStat checks that creds has permission to change the metadata of a +// file with the given permissions, UID, and GID as specified by stat, subject +// to the rules of Linux's fs/attr.c:setattr_prepare(). +func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { + if stat.Mask&linux.STATX_MODE != 0 { + if !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + // TODO(b/30815691): "If the calling process is not privileged (Linux: + // does not have the CAP_FSETID capability), and the group of the file + // does not match the effective group ID of the process or one of its + // supplementary group IDs, the S_ISGID bit will be turned off, but + // this will not cause an error to be returned." - chmod(2) + } + if stat.Mask&linux.STATX_UID != 0 { + if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&linux.STATX_GID != 0 { + if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { + if !CanActAsOwner(creds, kuid) { + if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) { + return syserror.EPERM + } + // isDir is irrelevant in the following call to + // GenericCheckPermissions since ats == MayWrite means that + // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE + // applies, regardless of isDir. + if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil { + return err + } + } + } + return nil +} + +// CanActAsOwner returns true if creds can act as the owner of a file with the +// given owning UID, consistent with Linux's +// fs/inode.c:inode_owner_or_capable(). +func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool { + if creds.EffectiveKUID == kuid { + return true + } + return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok() +} + +// HasCapabilityOnFile returns true if creds has the given capability with +// respect to a file with the given owning UID and GID, consistent with Linux's +// kernel/capability.c:capable_wrt_inode_uidgid(). +func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool { + return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok() +} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 8d05c8583..d580fd39e 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -85,11 +85,11 @@ func init() { // so error "constants" are really mutable vars, necessitating somewhat // expensive interface object comparisons. -type resolveMountRootError struct{} +type resolveMountRootOrJumpError struct{} // Error implements error.Error. -func (resolveMountRootError) Error() string { - return "resolving mount root" +func (resolveMountRootOrJumpError) Error() string { + return "resolving mount root or jump" } type resolveMountPointError struct{} @@ -149,20 +149,20 @@ func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) { func (rp *ResolvingPath) decRefStartAndMount() { if rp.flags&rpflagsHaveStartRef != 0 { - rp.start.decRef(rp.mount.fs) + rp.start.DecRef() } if rp.flags&rpflagsHaveMountRef != 0 { - rp.mount.decRef() + rp.mount.DecRef() } } func (rp *ResolvingPath) releaseErrorState() { if rp.nextStart != nil { - rp.nextStart.decRef(rp.nextMount.fs) + rp.nextStart.DecRef() rp.nextStart = nil } if rp.nextMount != nil { - rp.nextMount.decRef() + rp.nextMount.DecRef() rp.nextMount = nil } } @@ -269,12 +269,12 @@ func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) { parent = d } else if d == rp.mount.root { // At mount root ... - mnt, mntpt := rp.vfs.getMountpointAt(rp.mount, rp.root) - if mnt != nil { + vd := rp.vfs.getMountpointAt(rp.mount, rp.root) + if vd.Ok() { // ... of non-root mount. - rp.nextMount = mnt - rp.nextStart = mntpt - return nil, resolveMountRootError{} + rp.nextMount = vd.mount + rp.nextStart = vd.dentry + return nil, resolveMountRootOrJumpError{} } // ... of root mount. parent = d @@ -385,11 +385,32 @@ func (rp *ResolvingPath) relpathPrepend(path fspath.Path) { } } +// HandleJump is called when the current path component is a "magic" link to +// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem +// method should continue path traversal, HandleMagicSymlink updates the path +// component stream to reflect the magic link target and returns nil. Otherwise +// it returns a non-nil error. +// +// Preconditions: !rp.Done(). +func (rp *ResolvingPath) HandleJump(target VirtualDentry) error { + if rp.symlinks >= linux.MaxSymlinkTraversals { + return syserror.ELOOP + } + rp.symlinks++ + // Consume the path component that represented the magic link. + rp.Advance() + // Unconditionally return a resolveMountRootOrJumpError, even if the Mount + // isn't changing, to force restarting at the new Dentry. + target.IncRef() + rp.nextMount = target.mount + rp.nextStart = target.dentry + return resolveMountRootOrJumpError{} +} + func (rp *ResolvingPath) handleError(err error) bool { switch err.(type) { - case resolveMountRootError: - // Switch to the new Mount. We hold references on the Mount and Dentry - // (from VFS.getMountpointAt()). + case resolveMountRootOrJumpError: + // Switch to the new Mount. We hold references on the Mount and Dentry. rp.decRefStartAndMount() rp.mount = rp.nextMount rp.start = rp.nextStart @@ -407,9 +428,8 @@ func (rp *ResolvingPath) handleError(err error) bool { return true case resolveMountPointError: - // Switch to the new Mount. We hold a reference on the Mount (from - // VFS.getMountAt()), but borrow the reference on the mount root from - // the Mount. + // Switch to the new Mount. We hold a reference on the Mount, but + // borrow the reference on the mount root from the Mount. rp.decRefStartAndMount() rp.mount = rp.nextMount rp.start = rp.nextMount.root diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go deleted file mode 100644 index 23f2b9e08..000000000 --- a/pkg/sentry/vfs/syscalls.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" -) - -// PathOperation specifies the path operated on by a VFS method. -// -// PathOperation is passed to VFS methods by pointer to reduce memory copying: -// it's somewhat large and should never escape. (Options structs are passed by -// pointer to VFS and FileDescription methods for the same reason.) -type PathOperation struct { - // Root is the VFS root. References on Root are borrowed from the provider - // of the PathOperation. - // - // Invariants: Root.Ok(). - Root VirtualDentry - - // Start is the starting point for the path traversal. References on Start - // are borrowed from the provider of the PathOperation (i.e. the caller of - // the VFS method to which the PathOperation was passed). - // - // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. - Start VirtualDentry - - // Path is the pathname traversed by this operation. - Pathname string - - // If FollowFinalSymlink is true, and the Dentry traversed by the final - // path component represents a symbolic link, the symbolic link should be - // followed. - FollowFinalSymlink bool -} - -// GetDentryAt returns a VirtualDentry representing the given path, at which a -// file must exist. A reference is taken on the returned VirtualDentry. -func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return VirtualDentry{}, err - } - for { - d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) - if err == nil { - vd := VirtualDentry{ - mount: rp.mount, - dentry: d, - } - rp.mount.incRef() - vfs.putResolvingPath(rp) - return vd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return VirtualDentry{}, err - } - } -} - -// MkdirAt creates a directory at the given path. -func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { - // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is - // also honored." - mkdir(2) - opts.Mode &= 01777 - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err - } - for { - err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return err - } - } -} - -// OpenAt returns a FileDescription providing access to the file at the given -// path. A reference is taken on the returned FileDescription. -func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { - // Remove: - // - // - O_LARGEFILE, which we always report in FileDescription status flags - // since only 64-bit architectures are supported at this time. - // - // - O_CLOEXEC, which affects file descriptors and therefore must be - // handled outside of VFS. - // - // - Unknown flags. - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE - // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. - if opts.Flags&linux.O_SYNC != 0 { - opts.Flags |= linux.O_DSYNC - } - // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified - // with O_DIRECTORY and a writable access mode (to ensure that it fails on - // filesystem implementations that do not support it). - if opts.Flags&linux.O_TMPFILE != 0 { - if opts.Flags&linux.O_DIRECTORY == 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { - return nil, syserror.EINVAL - } - } - // O_PATH causes most other flags to be ignored. - if opts.Flags&linux.O_PATH != 0 { - opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH - } - // "On Linux, the following bits are also honored in mode: [S_ISUID, - // S_ISGID, S_ISVTX]" - open(2) - opts.Mode &= 07777 - - if opts.Flags&linux.O_NOFOLLOW != 0 { - pop.FollowFinalSymlink = false - } - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil, err - } - if opts.Flags&linux.O_DIRECTORY != 0 { - rp.mustBeDir = true - rp.mustBeDirOrig = true - } - for { - fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return fd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return nil, err - } - } -} - -// StatAt returns metadata for the file at the given path. -func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return linux.Statx{}, err - } - for { - stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return stat, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return linux.Statx{}, err - } - } -} - -// StatusFlags returns file description status flags. -func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { - flags, err := fd.impl.StatusFlags(ctx) - flags |= linux.O_LARGEFILE - return flags, err -} - -// SetStatusFlags sets file description status flags. -func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - return fd.impl.SetStatusFlags(ctx, flags) -} - -// TODO: -// -// - VFS.SyncAllFilesystems() for sync(2) -// -// - Something for syncfs(2) -// -// - VFS.LinkAt() -// -// - VFS.MknodAt() -// -// - VFS.ReadlinkAt() -// -// - VFS.RenameAt() -// -// - VFS.RmdirAt() -// -// - VFS.SetStatAt() -// -// - VFS.StatFSAt() -// -// - VFS.SymlinkAt() -// -// - VFS.UnlinkAt() -// -// - FileDescription.(almost everything) diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go index 70b192ece..d94117bce 100644 --- a/pkg/sentry/vfs/testutil.go +++ b/pkg/sentry/vfs/testutil.go @@ -15,7 +15,10 @@ package vfs import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -33,10 +36,10 @@ type FDTestFilesystem struct { vfsfs Filesystem } -// NewFilesystem implements FilesystemType.NewFilesystem. -func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) { +// GetFilesystem implements FilesystemType.GetFilesystem. +func (fstype FDTestFilesystemType) GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) { var fs FDTestFilesystem - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) return &fs.vfsfs, fs.NewDentry(), nil } @@ -114,6 +117,32 @@ func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) err return syserror.EPERM } +// ListxattrAt implements FilesystemImpl.ListxattrAt. +func (fs *FDTestFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) { + return nil, syserror.EPERM +} + +// GetxattrAt implements FilesystemImpl.GetxattrAt. +func (fs *FDTestFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) { + return "", syserror.EPERM +} + +// SetxattrAt implements FilesystemImpl.SetxattrAt. +func (fs *FDTestFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { + return syserror.EPERM +} + +// RemovexattrAt implements FilesystemImpl.RemovexattrAt. +func (fs *FDTestFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { + return syserror.EPERM +} + +// PrependPath implements FilesystemImpl.PrependPath. +func (fs *FDTestFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error { + b.PrependComponent(fmt.Sprintf("vfs.fdTestDentry:%p", vd.dentry.impl.(*fdTestDentry))) + return PrependPathSyntheticError{} +} + type fdTestDentry struct { vfsd Dentry } @@ -126,14 +155,14 @@ func (fs *FDTestFilesystem) NewDentry() *Dentry { } // IncRef implements DentryImpl.IncRef. -func (d *fdTestDentry) IncRef(vfsfs *Filesystem) { +func (d *fdTestDentry) IncRef() { } // TryIncRef implements DentryImpl.TryIncRef. -func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool { +func (d *fdTestDentry) TryIncRef() bool { return true } // DecRef implements DentryImpl.DecRef. -func (d *fdTestDentry) DecRef(vfsfs *Filesystem) { +func (d *fdTestDentry) DecRef() { } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 4a8a69540..e60898d7c 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -16,13 +16,24 @@ // // Lock order: // -// Filesystem implementation locks +// FilesystemImpl/FileDescriptionImpl locks // VirtualFilesystem.mountMu +// Dentry.mu +// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry +// VirtualFilesystem.filesystemsMu // VirtualFilesystem.fsTypesMu +// +// Locking Dentry.mu in multiple Dentries requires holding +// VirtualFilesystem.mountMu. package vfs import ( "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" ) // A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts. @@ -33,7 +44,7 @@ type VirtualFilesystem struct { // mountMu serializes mount mutations. // // mountMu is analogous to Linux's namespace_sem. - mountMu sync.RWMutex + mountMu sync.Mutex // mounts maps (mount parent, mount point) pairs to mounts. (Since mounts // are uniquely namespaced, including mount parent in the key correctly @@ -52,7 +63,7 @@ type VirtualFilesystem struct { // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. // - // mountpoints is used to find mounts that must be unmounted due to + // mountpoints is used to find mounts that must be umounted due to // removal of a mount point Dentry from another mount namespace. ("A file // or directory that is a mount point in one namespace that is not a mount // point in another namespace, may be renamed, unlinked, or removed @@ -62,6 +73,11 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // filesystems contains all Filesystems. filesystems is protected by + // filesystemsMu. + filesystemsMu sync.Mutex + filesystems map[*Filesystem]struct{} + // fsTypes contains all FilesystemTypes that are usable in the // VirtualFilesystem. fsTypes is protected by fsTypesMu. fsTypesMu sync.RWMutex @@ -72,12 +88,466 @@ type VirtualFilesystem struct { func New() *VirtualFilesystem { vfs := &VirtualFilesystem{ mountpoints: make(map[*Dentry]map[*Mount]struct{}), + filesystems: make(map[*Filesystem]struct{}), fsTypes: make(map[string]FilesystemType), } vfs.mounts.Init() return vfs } +// PathOperation specifies the path operated on by a VFS method. +// +// PathOperation is passed to VFS methods by pointer to reduce memory copying: +// it's somewhat large and should never escape. (Options structs are passed by +// pointer to VFS and FileDescription methods for the same reason.) +type PathOperation struct { + // Root is the VFS root. References on Root are borrowed from the provider + // of the PathOperation. + // + // Invariants: Root.Ok(). + Root VirtualDentry + + // Start is the starting point for the path traversal. References on Start + // are borrowed from the provider of the PathOperation (i.e. the caller of + // the VFS method to which the PathOperation was passed). + // + // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. + Start VirtualDentry + + // Path is the pathname traversed by this operation. + Pathname string + + // If FollowFinalSymlink is true, and the Dentry traversed by the final + // path component represents a symbolic link, the symbolic link should be + // followed. + FollowFinalSymlink bool +} + +// GetDentryAt returns a VirtualDentry representing the given path, at which a +// file must exist. A reference is taken on the returned VirtualDentry. +func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return VirtualDentry{}, err + } + for { + d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) + if err == nil { + vd := VirtualDentry{ + mount: rp.mount, + dentry: d, + } + rp.mount.IncRef() + vfs.putResolvingPath(rp) + return vd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return VirtualDentry{}, err + } + } +} + +// LinkAt creates a hard link at newpop representing the existing file at +// oldpop. +func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// MkdirAt creates a directory at the given path. +func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { + // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is + // also honored." - mkdir(2) + opts.Mode &= 0777 | linux.S_ISVTX + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// MknodAt creates a file of the given mode at the given path. It returns an +// error from the syserror package. +func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil + } + for { + if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { + vfs.putResolvingPath(rp) + return nil + } + // Handle mount traversals. + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// OpenAt returns a FileDescription providing access to the file at the given +// path. A reference is taken on the returned FileDescription. +func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { + // Remove: + // + // - O_LARGEFILE, which we always report in FileDescription status flags + // since only 64-bit architectures are supported at this time. + // + // - O_CLOEXEC, which affects file descriptors and therefore must be + // handled outside of VFS. + // + // - Unknown flags. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE + // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. + if opts.Flags&linux.O_SYNC != 0 { + opts.Flags |= linux.O_DSYNC + } + // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified + // with O_DIRECTORY and a writable access mode (to ensure that it fails on + // filesystem implementations that do not support it). + if opts.Flags&linux.O_TMPFILE != 0 { + if opts.Flags&linux.O_DIRECTORY == 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_CREAT != 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { + return nil, syserror.EINVAL + } + } + // O_PATH causes most other flags to be ignored. + if opts.Flags&linux.O_PATH != 0 { + opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH + } + // "On Linux, the following bits are also honored in mode: [S_ISUID, + // S_ISGID, S_ISVTX]" - open(2) + opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX + + if opts.Flags&linux.O_NOFOLLOW != 0 { + pop.FollowFinalSymlink = false + } + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + if opts.Flags&linux.O_DIRECTORY != 0 { + rp.mustBeDir = true + rp.mustBeDirOrig = true + } + for { + fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return fd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// ReadlinkAt returns the target of the symbolic link at the given path. +func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return target, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// RenameAt renames the file at oldpop to newpop. +func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// RmdirAt removes the directory at the given path. +func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RmdirAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SetStatAt changes metadata for the file at the given path. +func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// StatAt returns metadata for the file at the given path. +func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statx{}, err + } + for { + stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return stat, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statx{}, err + } + } +} + +// StatFSAt returns metadata for the filesystem containing the file at the +// given path. +func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statfs{}, err + } + for { + statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return statfs, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statfs{}, err + } + } +} + +// SymlinkAt creates a symbolic link at the given path with the given target. +func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// UnlinkAt deletes the non-directory file at the given path. +func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.UnlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// ListxattrAt returns all extended attribute names for the file at the given +// path. +func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + for { + names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return names, nil + } + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by + // default don't exist. + vfs.putResolvingPath(rp) + return nil, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// GetxattrAt returns the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return val, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// SetxattrAt changes the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// RemovexattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SyncAllFilesystems has the semantics of Linux's sync(2). +func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + fss := make(map[*Filesystem]struct{}) + vfs.filesystemsMu.Lock() + for fs := range vfs.filesystems { + if !fs.TryIncRef() { + continue + } + fss[fs] = struct{}{} + } + vfs.filesystemsMu.Unlock() + var retErr error + for fs := range fss { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef() + } + return retErr +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). @@ -111,15 +581,15 @@ func (vd VirtualDentry) Ok() bool { // IncRef increments the reference counts on the Mount and Dentry represented // by vd. func (vd VirtualDentry) IncRef() { - vd.mount.incRef() - vd.dentry.incRef(vd.mount.fs) + vd.mount.IncRef() + vd.dentry.IncRef() } // DecRef decrements the reference counts on the Mount and Dentry represented // by vd. func (vd VirtualDentry) DecRef() { - vd.dentry.decRef(vd.mount.fs) - vd.mount.decRef() + vd.dentry.DecRef() + vd.mount.DecRef() } // Mount returns the Mount associated with vd. It does not take a reference on diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 145102c0d..5e4611333 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -42,8 +42,35 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) -// DefaultTimeout is a resonable timeout value for most applications. -const DefaultTimeout = 3 * time.Minute +// Opts configures the watchdog. +type Opts struct { + // TaskTimeout is the amount of time to allow a task to execute the + // same syscall without blocking before it's declared stuck. + TaskTimeout time.Duration + + // TaskTimeoutAction indicates what action to take when a stuck tasks + // is detected. + TaskTimeoutAction Action + + // StartupTimeout is the amount of time to allow between watchdog + // creation and calling watchdog.Start. + StartupTimeout time.Duration + + // StartupTimeoutAction indicates what action to take when + // watchdog.Start is not called within the timeout. + StartupTimeoutAction Action +} + +// DefaultOpts is a default set of options for the watchdog. +var DefaultOpts = Opts{ + // Task timeout. + TaskTimeout: 3 * time.Minute, + TaskTimeoutAction: LogWarning, + + // Startup timeout. + StartupTimeout: 30 * time.Second, + StartupTimeoutAction: LogWarning, +} // descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period // is discounted from task's last update time. It's set high enough that small scheduling delays won't @@ -61,6 +88,7 @@ type Action int const ( // LogWarning logs warning message followed by stack trace. LogWarning Action = iota + // Panic will do the same logging as LogWarning and panic(). Panic ) @@ -80,17 +108,13 @@ func (a Action) String() string { // Watchdog is the main watchdog class. It controls a goroutine that periodically // analyses all tasks and reports if any of them appear to be stuck. type Watchdog struct { + // Configuration options are embedded. + Opts + // period indicates how often to check all tasks. It's calculated based on - // 'taskTimeout'. + // opts.TaskTimeout. period time.Duration - // taskTimeout is the amount of time to allow a task to execute the same syscall - // without blocking before it's declared stuck. - taskTimeout time.Duration - - // timeoutAction indicates what action to take when a stuck tasks is detected. - timeoutAction Action - // k is where the tasks come from. k *kernel.Kernel @@ -113,8 +137,12 @@ type Watchdog struct { // mu protects the fields below. mu sync.Mutex - // started is true if the watchdog has been started before. - started bool + // running is true if the watchdog is running. + running bool + + // startCalled is true if Start has ever been called. It remains true + // even if Stop is called. + startCalled bool } type offender struct { @@ -122,58 +150,81 @@ type offender struct { } // New creates a new watchdog. -func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog { - // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much. - period := taskTimeout / 4 - return &Watchdog{ - k: k, - period: period, - taskTimeout: taskTimeout, - timeoutAction: a, - offenders: make(map[*kernel.Task]*offender), - stop: make(chan struct{}), - done: make(chan struct{}), +func New(k *kernel.Kernel, opts Opts) *Watchdog { + // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much. + period := opts.TaskTimeout / 4 + w := &Watchdog{ + Opts: opts, + k: k, + period: period, + offenders: make(map[*kernel.Task]*offender), + stop: make(chan struct{}), + done: make(chan struct{}), + } + + // Handle StartupTimeout if it exists. + if w.StartupTimeout > 0 { + log.Infof("Watchdog waiting %v for startup", w.StartupTimeout) + go w.waitForStart() // S/R-SAFE: watchdog is stopped buring save and restarted after restore. } + + return w } // Start starts the watchdog. func (w *Watchdog) Start() { - if w.taskTimeout == 0 { - log.Infof("Watchdog disabled") - return - } - w.mu.Lock() defer w.mu.Unlock() - if w.started { + w.startCalled = true + + if w.running { return } + if w.TaskTimeout == 0 { + log.Infof("Watchdog task timeout disabled") + return + } w.lastRun = w.k.MonotonicClock().Now() - log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction) + log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction) go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore. - w.started = true + w.running = true } // Stop requests the watchdog to stop and wait for it. func (w *Watchdog) Stop() { - if w.taskTimeout == 0 { + if w.TaskTimeout == 0 { return } w.mu.Lock() defer w.mu.Unlock() - if !w.started { + if !w.running { return } log.Infof("Stopping watchdog") w.stop <- struct{}{} <-w.done - w.started = false + w.running = false log.Infof("Watchdog stopped") } +// waitForStart waits for Start to be called and takes action if it does not +// happen within the startup timeout. +func (w *Watchdog) waitForStart() { + <-time.After(w.StartupTimeout) + w.mu.Lock() + defer w.mu.Unlock() + if w.startCalled { + // We are fine. + return + } + var buf bytes.Buffer + buf.WriteString("Watchdog.Start() not called within %s:\n") + w.doAction(w.StartupTimeoutAction, false, &buf) +} + // loop is the main watchdog routine. It only returns when 'Stop()' is called. func (w *Watchdog) loop() { // Loop until someone stops it. @@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() { select { case <-done: - case <-time.After(w.taskTimeout): + case <-time.After(w.TaskTimeout): // Report if the watchdog is not making progress. // No one is wathching the watchdog watcher though. w.reportStuckWatchdog() @@ -231,12 +282,14 @@ func (w *Watchdog) runTurn() { if tsched.State == kernel.TaskGoroutineRunningSys { lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick))) elapsed := now.Sub(lastUpdateTime) - discount - if elapsed > w.taskTimeout { + if elapsed > w.TaskTimeout { tc, ok := w.offenders[t] if !ok { // New stuck task detected. // - // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel. + // Note that tasks blocked doing IO may be considered stuck in kernel, + // unless they are surrounded b + // Task.UninterruptibleSleepStart/Finish. tc = &offender{lastUpdateTime: lastUpdateTime} stuckTasks.Increment() newTaskFound = true @@ -261,28 +314,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo tid := w.k.TaskSet().Root.IDOfTask(t) buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) } + buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - w.onStuckTask(newTaskFound, &buf) + + // Dump stack only if a new task is detected or if it sometime has + // passed since the last time a stack dump was generated. + skipStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod + w.doAction(w.TaskTimeoutAction, skipStack, &buf) } func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer buf.WriteString("Watchdog goroutine is stuck:\n") - w.onStuckTask(true, &buf) + w.doAction(w.TaskTimeoutAction, false, &buf) } -func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { - switch w.timeoutAction { +// doAction will take the given action. If the action is LogWarnind and +// skipStack is true, then the stack printing will be skipped. +func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { + switch action { case LogWarning: - // Dump stack only if a new task is detected or if it sometime has passed since - // the last time a stack dump was generated. - if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod { + if skipStack { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) - } else { - log.TracebackAll(msg.String()) - w.lastStackDump = time.Now() + return + } + log.TracebackAll(msg.String()) + w.lastStackDump = time.Now() case Panic: // Panic will skip over running tasks, which is likely the culprit here. So manually @@ -301,5 +360,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { case <-time.After(1 * time.Second): } panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + default: + panic(fmt.Sprintf("Unknown watchdog action %v", action)) + } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 8f5e60a25..acbf0229b 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 47e6b878a..590c241a3 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState { // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { + // ctx is the decode context. + ctx context.Context + // objectByID is the set of objects in progress. objectsByID map[uint64]*objectState diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 5d9409a45..c5118d3a9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -16,6 +16,7 @@ package state import ( "container/list" + "context" "encoding/binary" "fmt" "io" @@ -38,6 +39,9 @@ type queuedObject struct { // The encoding process is a breadth-first traversal of the object graph. The // inherent races and dependencies are much simpler than the decode case. type encodeState struct { + // ctx is the encode context. + ctx context.Context + // lastID is the last object ID. // // See idsByObject for context. Because of the special zero encoding diff --git a/pkg/state/map.go b/pkg/state/map.go index 7e6fefed4..4f3ebb0da 100644 --- a/pkg/state/map.go +++ b/pkg/state/map.go @@ -15,6 +15,7 @@ package state import ( + "context" "fmt" "reflect" "sort" @@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) { // data dependencies have been cleared. m.os.callbacks = append(m.os.callbacks, fn) } + +// Context returns the current context object. +func (m Map) Context() context.Context { + if m.es != nil { + return m.es.ctx + } else if m.ds != nil { + return m.ds.ctx + } + return context.Background() // No context. +} diff --git a/pkg/state/object.proto b/pkg/state/object.proto index 952289069..5ebcfb151 100644 --- a/pkg/state/object.proto +++ b/pkg/state/object.proto @@ -18,8 +18,8 @@ package gvisor.state.statefile; // Slice is a slice value. message Slice { - uint32 length = 1; - uint32 capacity = 2; + uint32 length = 1; + uint32 capacity = 2; uint64 ref_value = 3; } @@ -30,13 +30,13 @@ message Array { // Map is a map value. message Map { - repeated Object keys = 1; + repeated Object keys = 1; repeated Object values = 2; } // Interface is an interface value. message Interface { - string type = 1; + string type = 1; Object value = 2; } @@ -47,7 +47,7 @@ message Struct { // Field encodes a single field. message Field { - string name = 1; + string name = 1; Object value = 2; } @@ -113,28 +113,28 @@ message Float32s { // Note that ref_value references an Object.id, below. message Object { oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; + bool bool_value = 1; + bytes string_value = 2; + int64 int64_value = 3; + uint64 uint64_value = 4; + double double_value = 5; + uint64 ref_value = 6; + Slice slice_value = 7; + Array array_value = 8; + Interface interface_value = 9; + Struct struct_value = 10; + Map map_value = 11; + bytes byte_array_value = 12; + Uint16s uint16_array_value = 13; + Uint32s uint32_array_value = 14; + Uint64s uint64_array_value = 15; + Uintptrs uintptr_array_value = 16; + Int8s int8_array_value = 17; + Int16s int16_array_value = 18; + Int32s int32_array_value = 19; + Int64s int64_array_value = 20; + Bools bool_array_value = 21; + Float64s float64_array_value = 22; + Float32s float32_array_value = 23; } } diff --git a/pkg/state/state.go b/pkg/state/state.go index d408ff84a..dbe507ab4 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -50,6 +50,7 @@ package state import ( + "context" "fmt" "io" "reflect" @@ -86,9 +87,10 @@ func UnwrapErrState(err error) error { } // Save saves the given object state. -func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { // Create the encoding state. es := &encodeState{ + ctx: ctx, idsByObject: make(map[uintptr]uint64), w: w, stats: stats, @@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { } // Load loads a checkpoint. -func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { // Create the decoding state. ds := &decodeState{ + ctx: ctx, objectsByID: make(map[uint64]*objectState), deferred: make(map[uint64]*pb.Object), r: r, diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 7c24bbcda..d7221e9e8 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "io/ioutil" "math" "reflect" @@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) { saveBuffer := &bytes.Buffer{} saveObjectPtr := reflect.New(reflect.TypeOf(root)) saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Save failed unexpectedly: %v", err) continue } else if err != nil { @@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) { // Load a new copy of the object. loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Load failed unexpectedly: %v", err) continue } else if err != nil { @@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) { bs := buildObject(b.N) var stats Stats b.StartTimer() - if err := Save(ioutil.Discard, bs, &stats); err != nil { + if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { b.Errorf("save failed: %v", err) } b.StopTimer() @@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) { bs := buildObject(b.N) var newBS benchStruct buf := &bytes.Buffer{} - if err := Save(buf, bs, nil); err != nil { + if err := Save(context.Background(), buf, bs, nil); err != nil { b.Errorf("save failed: %v", err) } var stats Stats b.StartTimer() - if err := Load(buf, &newBS, &stats); err != nil { + if err := Load(context.Background(), buf, &newBS, &stats); err != nil { b.Errorf("load failed: %v", err) } b.StopTimer() diff --git a/third_party/gvsync/BUILD b/pkg/syncutil/BUILD index 7d6d59c48..cb1f41628 100644 --- a/third_party/gvsync/BUILD +++ b/pkg/syncutil/BUILD @@ -1,4 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template") package( @@ -28,26 +29,24 @@ go_template( ) go_library( - name = "gvsync", + name = "syncutil", srcs = [ - "downgradable_rwmutex_1_12_unsafe.go", - "downgradable_rwmutex_1_13_unsafe.go", "downgradable_rwmutex_unsafe.go", - "gvsync.go", "memmove_unsafe.go", "norace_unsafe.go", "race_unsafe.go", "seqcount.go", + "syncutil.go", ], - importpath = "gvisor.dev/gvisor/third_party/gvsync", + importpath = "gvisor.dev/gvisor/pkg/syncutil", ) go_test( - name = "gvsync_test", + name = "syncutil_test", size = "small", srcs = [ "downgradable_rwmutex_test.go", "seqcount_test.go", ], - embed = [":gvsync"], + embed = [":syncutil"], ) diff --git a/third_party/gvsync/LICENSE b/pkg/syncutil/LICENSE index 6a66aea5e..6a66aea5e 100644 --- a/third_party/gvsync/LICENSE +++ b/pkg/syncutil/LICENSE diff --git a/third_party/gvsync/README.md b/pkg/syncutil/README.md index fcc7e6f44..2183c4e20 100644 --- a/third_party/gvsync/README.md +++ b/pkg/syncutil/README.md @@ -1,3 +1,5 @@ +# Syncutil + This package provides additional synchronization primitives not provided by the Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package. +package from go1.10. diff --git a/third_party/gvsync/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go index 525c4beed..525c4beed 100644 --- a/third_party/gvsync/atomicptr_unsafe.go +++ b/pkg/syncutil/atomicptr_unsafe.go diff --git a/third_party/gvsync/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD index 447ecf96a..63f411a90 100644 --- a/third_party/gvsync/atomicptrtest/BUILD +++ b/pkg/syncutil/atomicptrtest/BUILD @@ -1,4 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -8,7 +9,7 @@ go_template_instance( out = "atomicptr_int_unsafe.go", package = "atomicptr", suffix = "Int", - template = "//third_party/gvsync:generic_atomicptr", + template = "//pkg/syncutil:generic_atomicptr", types = { "Value": "int", }, @@ -17,7 +18,7 @@ go_template_instance( go_library( name = "atomicptr", srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/third_party/gvsync/atomicptr", + importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", ) go_test( diff --git a/third_party/gvsync/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go index 8fdc5112e..8fdc5112e 100644 --- a/third_party/gvsync/atomicptrtest/atomicptr_test.go +++ b/pkg/syncutil/atomicptrtest/atomicptr_test.go diff --git a/third_party/gvsync/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go index 40c384b8b..ffaf7ecc7 100644 --- a/third_party/gvsync/downgradable_rwmutex_test.go +++ b/pkg/syncutil/downgradable_rwmutex_test.go @@ -9,7 +9,7 @@ // addition of downgradingWriter and the renaming of num_iterations to // numIterations to shut up Golint. -package gvsync +package syncutil import ( "fmt" diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go index 1f6007aa1..51e11555d 100644 --- a/third_party/gvsync/downgradable_rwmutex_unsafe.go +++ b/pkg/syncutil/downgradable_rwmutex_unsafe.go @@ -3,8 +3,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.12 -// +build !go1.14 +// +build go1.13 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -16,7 +16,7 @@ // - RUnlock -> Lock (via writerSem) // - DowngradeLock -> RLock (via readerSem) -package gvsync +package syncutil import ( "sync" @@ -27,6 +27,9 @@ import ( //go:linkname runtimeSemacquire sync.runtime_Semacquire func runtimeSemacquire(s *uint32) +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + // DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock // method. type DowngradableRWMutex struct { diff --git a/third_party/gvsync/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go index 84b69f215..348675baa 100644 --- a/third_party/gvsync/memmove_unsafe.go +++ b/pkg/syncutil/memmove_unsafe.go @@ -4,11 +4,11 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. -package gvsync +package syncutil import ( "unsafe" diff --git a/third_party/gvsync/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go index e3852db8c..0a0a9deda 100644 --- a/third_party/gvsync/norace_unsafe.go +++ b/pkg/syncutil/norace_unsafe.go @@ -5,7 +5,7 @@ // +build !race -package gvsync +package syncutil import ( "unsafe" diff --git a/third_party/gvsync/race_unsafe.go b/pkg/syncutil/race_unsafe.go index 13c02a830..206067ec1 100644 --- a/third_party/gvsync/race_unsafe.go +++ b/pkg/syncutil/race_unsafe.go @@ -5,7 +5,7 @@ // +build race -package gvsync +package syncutil import ( "runtime" diff --git a/third_party/gvsync/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go index 382eeed43..cb6d2eb22 100644 --- a/third_party/gvsync/seqatomic_unsafe.go +++ b/pkg/syncutil/seqatomic_unsafe.go @@ -13,7 +13,7 @@ import ( "strings" "unsafe" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // Value is a required type parameter. @@ -26,17 +26,17 @@ type Value struct{} // SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race // with any writer critical sections in sc. -func SeqAtomicLoad(sc *gvsync.SeqCount, ptr *Value) Value { +func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { // This function doesn't use SeqAtomicTryLoad because doing so is // measurably, significantly (~20%) slower; Go is awful at inlining. var val Value for { epoch := sc.BeginRead() - if gvsync.RaceEnabled { + if syncutil.RaceEnabled { // runtime.RaceDisable() doesn't actually stop the race detector, // so it can't help us here. Instead, call runtime.memmove // directly, which is not instrumented by the race detector. - gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { // This is ~40% faster for short reads than going through memmove. val = *ptr @@ -52,10 +52,10 @@ func SeqAtomicLoad(sc *gvsync.SeqCount, ptr *Value) Value { // in sc initiated by a call to sc.BeginRead() that returned epoch. If the read // would race with a writer critical section, SeqAtomicTryLoad returns // (unspecified, false). -func SeqAtomicTryLoad(sc *gvsync.SeqCount, epoch gvsync.SeqCountEpoch, ptr *Value) (Value, bool) { +func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { var val Value - if gvsync.RaceEnabled { - gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + if syncutil.RaceEnabled { + syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { val = *ptr } @@ -66,7 +66,7 @@ func init() { var val Value typ := reflect.TypeOf(val) name := typ.Name() - if ptrs := gvsync.PointersInType(typ, name); len(ptrs) != 0 { + if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) } } diff --git a/third_party/gvsync/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD index c858c20c4..ba18f3238 100644 --- a/third_party/gvsync/seqatomictest/BUILD +++ b/pkg/syncutil/seqatomictest/BUILD @@ -1,4 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -8,7 +9,7 @@ go_template_instance( out = "seqatomic_int_unsafe.go", package = "seqatomic", suffix = "Int", - template = "//third_party/gvsync:generic_seqatomic", + template = "//pkg/syncutil:generic_seqatomic", types = { "Value": "int", }, @@ -17,9 +18,9 @@ go_template_instance( go_library( name = "seqatomic", srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/third_party/gvsync/seqatomic", + importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", deps = [ - "//third_party/gvsync", + "//pkg/syncutil", ], ) @@ -29,6 +30,6 @@ go_test( srcs = ["seqatomic_test.go"], embed = [":seqatomic"], deps = [ - "//third_party/gvsync", + "//pkg/syncutil", ], ) diff --git a/third_party/gvsync/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go index a5447f589..b0db44999 100644 --- a/third_party/gvsync/seqatomictest/seqatomic_test.go +++ b/pkg/syncutil/seqatomictest/seqatomic_test.go @@ -19,11 +19,11 @@ import ( "testing" "time" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount const want = 1 data := want if got := SeqAtomicLoadInt(&seq, &data); got != want { @@ -32,7 +32,7 @@ func TestSeqAtomicLoadUncontended(t *testing.T) { } func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount var data int const want = 1 seq.BeginWrite() @@ -44,7 +44,7 @@ func TestSeqAtomicLoadAfterWrite(t *testing.T) { } func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount var data int const want = 1 seq.BeginWrite() @@ -59,7 +59,7 @@ func TestSeqAtomicLoadDuringWrite(t *testing.T) { } func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount const want = 1 data := want epoch := seq.BeginRead() @@ -69,7 +69,7 @@ func TestSeqAtomicTryLoadUncontended(t *testing.T) { } func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount var data int epoch := seq.BeginRead() seq.BeginWrite() @@ -80,7 +80,7 @@ func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { } func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount var data int epoch := seq.BeginRead() seq.BeginWrite() @@ -91,7 +91,7 @@ func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { } func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount const want = 42 data := want b.RunParallel(func(pb *testing.PB) { @@ -104,7 +104,7 @@ func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { } func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq gvsync.SeqCount + var seq syncutil.SeqCount const want = 42 data := want b.RunParallel(func(pb *testing.PB) { diff --git a/third_party/gvsync/seqcount.go b/pkg/syncutil/seqcount.go index 2c9c2c3d6..11d8dbfaa 100644 --- a/third_party/gvsync/seqcount.go +++ b/pkg/syncutil/seqcount.go @@ -3,7 +3,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package gvsync +package syncutil import ( "fmt" diff --git a/third_party/gvsync/seqcount_test.go b/pkg/syncutil/seqcount_test.go index 085e574b3..14d6aedea 100644 --- a/third_party/gvsync/seqcount_test.go +++ b/pkg/syncutil/seqcount_test.go @@ -3,7 +3,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package gvsync +package syncutil import ( "reflect" diff --git a/third_party/gvsync/gvsync.go b/pkg/syncutil/syncutil.go index 3bbef13c3..66e750d06 100644 --- a/third_party/gvsync/gvsync.go +++ b/pkg/syncutil/syncutil.go @@ -3,5 +3,5 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package gvsync provides synchronization primitives. -package gvsync +// Package syncutil provides synchronization primitives. +package syncutil diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 3c2b2b5ea..65d4d0cd8 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "packet_buffer.go", + "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", ], diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 4287464f3..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable { return Prependable{buf: v, usedIdx: 0} } +// NewEmptyPrependableFromView creates a new prependable buffer from a View. +func NewEmptyPrependableFromView(v View) Prependable { + return Prependable{buf: v, usedIdx: len(v)} +} + // View returns a View of the backing buffer that contains all prepended // data so far. func (p Prependable) View() View { @@ -72,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index 4cecfb989..b6fa6fc37 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 02137e1c9..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -639,6 +640,8 @@ func ICMPv4Code(want byte) TransportChecker { // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and // potentially additional ICMPv6 header fields. +// +// ICMPv6 will validate the checksum field before calling checkers. func ICMPv6(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -650,6 +653,10 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) + if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) + } + for _, f := range checkers { f(t, icmp) } diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 0c5c20cea..e648efa71 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -7,9 +7,7 @@ go_library( name = "jenkins", srcs = ["jenkins.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 07d09abed..f1d837196 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -19,6 +19,7 @@ go_library( "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", "ndp_options.go", + "ndp_router_advert.go", "tcp.go", "udp.go", ], @@ -36,16 +37,29 @@ go_test( name = "header_x_test", size = "small", srcs = [ + "checksum_test.go", + "ipv6_test.go", "ipversion_test.go", "tcp_test.go", ], - deps = [":header"], + deps = [ + ":header", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) go_test( name = "header_test", size = "small", - srcs = ["ndp_test.go"], + srcs = [ + "eth_test.go", + "ndp_test.go", + ], embed = [":header"], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 39a4d69be..9749c7f4d 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -23,11 +23,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -func calculateChecksum(buf []byte, initial uint32) uint16 { +func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + l := len(buf) - if l&1 != 0 { + odd = l&1 != 0 + if odd { l-- v += uint32(buf[l]) << 8 } @@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } - return ChecksumCombine(uint16(v), uint16(v>>16)) + return ChecksumCombine(uint16(v), uint16(v>>16)), odd } // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the @@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - return calculateChecksum(buf, uint32(initial)) + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s } // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in @@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - var odd bool + return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) +} + +// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the +// bytes in the given VectorizedView. +// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { + odd := false sum := initial for _, v := range vv.Views() { if len(v) == 0 { continue } - s := uint32(sum) - if odd { - s += uint32(v[0]) - v = v[1:] + + if off >= len(v) { + off -= len(v) + continue + } + v = v[off:] + + l := len(v) + if l > size { + l = size + } + v = v[:l] + + sum, odd = calculateChecksum(v, odd, uint32(sum)) + + size -= len(v) + if size == 0 { + break } - odd = len(v)&1 != 0 - sum = calculateChecksum(v, s) + off = 0 } return sum } diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go new file mode 100644 index 000000000..86b466c1c --- /dev/null +++ b/pkg/tcpip/header/checksum_test.go @@ -0,0 +1,109 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package header provides the implementation of the encoding and decoding of +// network protocol headers. +package header_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestChecksumVVWithOffset(t *testing.T) { + testCases := []struct { + name string + vv buffer.VectorisedView + off, size int + initial uint16 + want uint16 + }{ + { + name: "empty", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 0, + want: 0, + }, + { + name: "OneView", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 5, + want: 1294, + }, + { + name: "TwoViews", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 0, + size: 11, + want: 33819, + }, + { + name: "TwoViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 1, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 7, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithInitial", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), + }), + initial: 77, + off: 7, + size: 11, + want: 33896, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { + t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + } + v := tc.vv.ToView() + v.TrimFront(tc.off) + v.CapLength(tc.size) + if got, want := header.Checksum(v, tc.initial), tc.want; got != want { + t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + } + }) + } +} diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 4c3d3311f..f5d2c127f 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -48,8 +48,48 @@ const ( // EthernetAddressSize is the size, in bytes, of an ethernet address. EthernetAddressSize = 6 + + // unspecifiedEthernetAddress is the unspecified ethernet address + // (all bits set to 0). + unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + + // unicastMulticastFlagMask is the mask of the least significant bit in + // the first octet (in network byte order) of an ethernet address that + // determines whether the ethernet address is a unicast or multicast. If + // the masked bit is a 1, then the address is a multicast, unicast + // otherwise. + // + // See the IEEE Std 802-2001 document for more details. Specifically, + // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf: + // "A 48-bit universal address consists of two parts. The first 24 bits + // correspond to the OUI as assigned by the IEEE, expect that the + // assignee may set the LSB of the first octet to 1 for group addresses + // or set it to 0 for individual addresses." + unicastMulticastFlagMask = 1 + + // unicastMulticastFlagByteIdx is the byte that holds the + // unicast/multicast flag. See unicastMulticastFlagMask. + unicastMulticastFlagByteIdx = 0 +) + +const ( + // EthernetProtocolAll is a catch-all for all protocols carried inside + // an ethernet frame. It is mainly used to create packet sockets that + // capture all traffic. + EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003 + + // EthernetProtocolPUP is the PARC Universial Packet protocol ethertype. + EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200 ) +// Ethertypes holds the protocol numbers describing the payload of an ethernet +// frame. These types aren't necessarily supported by netstack, but can be used +// to catch all traffic of a type via packet endpoints. +var Ethertypes = []tcpip.NetworkProtocolNumber{ + EthernetProtocolAll, + EthernetProtocolPUP, +} + // SourceAddress returns the "MAC source" field of the ethernet frame header. func (b Ethernet) SourceAddress() tcpip.LinkAddress { return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize]) @@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr) copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } + +// IsValidUnicastEthernetAddress returns true if addr is a valid unicast +// ethernet address. +func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { + // Must be of the right length. + if len(addr) != EthernetAddressSize { + return false + } + + // Must not be unspecified. + if addr == unspecifiedEthernetAddress { + return false + } + + // Must not be a multicast. + if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { + return false + } + + // addr is a valid unicast ethernet address. + return true +} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go new file mode 100644 index 000000000..6634c90f5 --- /dev/null +++ b/pkg/tcpip/header/eth_test.go @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestIsValidUnicastEthernetAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.LinkAddress + expected bool + }{ + { + "Nil", + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty", + tcpip.LinkAddress(""), + false, + }, + { + "InvalidLength", + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Unspecified", + unspecifiedEthernetAddress, + false, + }, + { + "Multicast", + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Valid", + tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), + true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { + t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index e51c5098c..b4037b6c8 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -80,6 +80,13 @@ const ( // icmpv6SequenceOffset is the offset of the sequence field // in a ICMPv6 Echo Request/Reply message. icmpv6SequenceOffset = 6 + + // NDPHopLimit is the expected IP hop limit value of 255 for received + // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, + // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently + // drop the NDP packet. All outgoing NDP packets must use this value for + // its IP hop limit field. + NDPHopLimit = 255 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -125,7 +132,7 @@ func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } -// SetChecksum calculates and sets the ICMP checksum field. +// SetChecksum sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } @@ -190,7 +197,7 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } -// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { // Calculate the IPv6 pseudo-header upper-layer checksum. diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index b125bbea5..fc671e439 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -90,9 +90,23 @@ const ( // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 ) -// IPv6EmptySubnet is the empty IPv6 subnet. +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. var IPv6EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) if err != nil { @@ -101,6 +115,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined +// by RFC 4291 section 2.5.6. +// +// The prefix is fe80::/64 +var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{ + Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 64, +} + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { @@ -255,27 +278,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFF + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { - // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local - // header, FE80::. + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. // // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go new file mode 100644 index 000000000..42c5c6fc1 --- /dev/null +++ b/pkg/tcpip/header/ipv6_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +func TestEthernetAdddressToModifiedEUI64(t *testing.T) { + expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} + + if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) + } + + var buf [header.IIDSize]byte + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + if diff := cmp.Diff(expectedIID, buf); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) + } +} + +func TestLinkLocalAddr(t *testing.T) { + if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) + } +} diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go index 5c2b472c8..505c92668 100644 --- a/pkg/tcpip/header/ndp_neighbor_advert.go +++ b/pkg/tcpip/header/ndp_neighbor_advert.go @@ -18,6 +18,8 @@ import "gvisor.dev/gvisor/pkg/tcpip" // NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will // only contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.4 for more details. type NDPNeighborAdvert []byte const ( diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go index 1dcb0fbc6..3a1b8e139 100644 --- a/pkg/tcpip/header/ndp_neighbor_solicit.go +++ b/pkg/tcpip/header/ndp_neighbor_solicit.go @@ -18,6 +18,8 @@ import "gvisor.dev/gvisor/pkg/tcpip" // NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only // contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.3 for more details. type NDPNeighborSolicit []byte const ( diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index b28bde15b..06e0bace2 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -15,6 +15,11 @@ package header import ( + "encoding/binary" + "errors" + "math" + "time" + "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,15 +32,226 @@ const ( // Link Layer Option for an Ethernet address. ndpTargetEthernetLinkLayerAddressSize = 8 + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType = 3 + + // ndpPrefixInformationLength is the expected length, in bytes, of the + // body of an NDP Prefix Information option, as per RFC 4861 section + // 4.6.2 which specifies that the Length field is 4. Given this, the + // expected length, in bytes, is 30 becuase 4 * lengthByteUnits (8) - 2 + // (Type & Length) = 30. + ndpPrefixInformationLength = 30 + + // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix + // Length field within an NDPPrefixInformation. + ndpPrefixInformationPrefixLengthOffset = 0 + + // ndpPrefixInformationFlagsOffset is the offset of the flags byte + // within an NDPPrefixInformation. + ndpPrefixInformationFlagsOffset = 1 + + // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationOnLinkFlagMask = (1 << 7) + + // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the + // Autonomous Address-Configuration flag field in the flags byte within + // an NDPPrefixInformation. + ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + + // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationReserved1FlagsMask = 63 + + // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte + // Valid Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationValidLifetimeOffset = 2 + + // ndpPrefixInformationPreferredLifetimeOffset is the start of the + // 4-byte Preferred Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationPreferredLifetimeOffset = 6 + + // ndpPrefixInformationReserved2Offset is the start of the 4-byte + // Reserved2 field within an NDPPrefixInformation. + ndpPrefixInformationReserved2Offset = 10 + + // ndpPrefixInformationReserved2Length is the length of the Reserved2 + // field. + // + // It is 4 bytes. + ndpPrefixInformationReserved2Length = 4 + + // ndpPrefixInformationPrefixOffset is the start of the Prefix field + // within an NDPPrefixInformation. + ndpPrefixInformationPrefixOffset = 14 + + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType = 25 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS + // Server option's length field value when it contains at least one + // IPv6 address. + minNDPRecursiveDNSServerLength = 3 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 ) +var ( + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. + // + // This is a variable instead of a constant so that tests can change + // this value to a smaller value. It should only be modified by tests. + NDPInfiniteLifetime = time.Second * math.MaxUint32 +) + +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + // The NDPOptions this NDPOptionIterator is iterating over. + opts NDPOptions +} + +// Potential errors when iterating over an NDPOptions. +var ( + ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") + ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") +) + +// Next returns the next element in the backing NDPOptions, or true if we are +// done, or false if an error occured. +// +// The return can be read as option, done, error. Note, option should only be +// used if done is false and error is nil. +func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { + for { + // Do we still have elements to look at? + if len(i.opts) == 0 { + return nil, true, nil + } + + // Do we have enough bytes for an NDP option that has a Length + // field of at least 1? Note, 0 in the Length field is invalid. + if len(i.opts) < lengthByteUnits { + return nil, true, ErrNDPOptBufExhausted + } + + // Get the Type field. + t := i.opts[0] + + // Get the Length field. + l := i.opts[1] + + // This would indicate an erroneous NDP option as the Length + // field should never be 0. + if l == 0 { + return nil, true, ErrNDPOptZeroLength + } + + // How many bytes are in the option body? + numBytes := int(l) * lengthByteUnits + numBodyBytes := numBytes - 2 + + potentialBody := i.opts[2:] + + // This would indicate an erroenous NDPOptions buffer as we ran + // out of the buffer in the middle of an NDP option. + if left := len(potentialBody); left < numBodyBytes { + return nil, true, ErrNDPOptBufExhausted + } + + // Get only the options body, leaving the rest of the options + // buffer alone. + body := potentialBody[:numBodyBytes] + + // Update opts with the remaining options body. + i.opts = i.opts[numBytes:] + + switch t { + case NDPTargetLinkLayerAddressOptionType: + return NDPTargetLinkLayerAddressOption(body), false, nil + + case NDPPrefixInformationType: + // Make sure the length of a Prefix Information option + // body is ndpPrefixInformationLength, as per RFC 4861 + // section 4.6.2. + if numBodyBytes != ndpPrefixInformationLength { + return nil, true, ErrNDPOptMalformedBody + } + + return NDPPrefixInformation(body), false, nil + + case NDPRecursiveDNSServerOptionType: + // RFC 8106 section 5.3.1 outlines that the RDNSS option + // must have a minimum length of 3 so it contains at + // least one IPv6 address. + if l < minNDPRecursiveDNSServerLength { + return nil, true, ErrNDPInvalidLength + } + + opt := NDPRecursiveDNSServer(body) + if len(opt.Addresses()) == 0 { + return nil, true, ErrNDPOptMalformedBody + } + + return opt, false, nil + + default: + // We do not yet recognize the option, just skip for + // now. This is okay because RFC 4861 allows us to + // skip/ignore any unrecognized options. However, + // we MUST recognized all the options in RFC 4861. + // + // TODO(b/141487990): Handle all NDP options as defined + // by RFC 4861. + } + } +} + // NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6. type NDPOptions []byte +// Iter returns an iterator of NDPOption. +// +// If check is true, Iter will do an integrity check on the options by iterating +// over it and returning an error if detected. +// +// See NDPOptionIterator for more information. +func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { + it := NDPOptionIterator{opts: b} + + if check { + for it2 := it; true; { + if _, done, err := it2.Next(); err != nil || done { + return it, err + } + } + } + + return it, nil +} + // Serialize serializes the provided list of NDP options into o. // // Note, b must be of sufficient size to hold all the options in s. See @@ -75,15 +291,15 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { return done } -// ndpOption is the set of functions to be implemented by all NDP option types. -type ndpOption interface { - // Type returns the type of this ndpOption. +// NDPOption is the set of functions to be implemented by all NDP option types. +type NDPOption interface { + // Type returns the type of the receiver. Type() uint8 - // Length returns the length of the body of this ndpOption, in bytes. + // Length returns the length of the body of the receiver, in bytes. Length() int - // serializeInto serializes this ndpOption into the provided byte + // serializeInto serializes the receiver into the provided byte // buffer. // // Note, the caller MUST provide a byte buffer with size of at least @@ -92,15 +308,15 @@ type ndpOption interface { // buffer is not of sufficient size. // // serializeInto will return the number of bytes that was used to - // serialize this ndpOption. Implementers must only use the number of - // bytes required to serialize this ndpOption. Callers MAY provide a + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a // larger buffer than required to serialize into. serializeInto([]byte) int } // paddedLength returns the length of o, in bytes, with any padding bytes, if // required. -func paddedLength(o ndpOption) int { +func paddedLength(o NDPOption) int { l := o.Length() if l == 0 { @@ -139,7 +355,7 @@ func paddedLength(o ndpOption) int { } // NDPOptionsSerializer is a serializer for NDP options. -type NDPOptionsSerializer []ndpOption +type NDPOptionsSerializer []NDPOption // Length returns the total number of bytes required to serialize. func (b NDPOptionsSerializer) Length() int { @@ -154,19 +370,220 @@ func (b NDPOptionsSerializer) Length() int { // NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option // as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPTargetLinkLayerAddressOption tcpip.LinkAddress -// Type implements ndpOption.Type. +// Type implements NDPOption.Type. func (o NDPTargetLinkLayerAddressOption) Type() uint8 { return NDPTargetLinkLayerAddressOptionType } -// Length implements ndpOption.Length. +// Length implements NDPOption.Length. func (o NDPTargetLinkLayerAddressOption) Length() int { return len(o) } -// serializeInto implements ndpOption.serializeInto. +// serializeInto implements NDPOption.serializeInto. func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + +// NDPPrefixInformation is the NDP Prefix Information option as defined by +// RFC 4861 section 4.6.2. +// +// The length, in bytes, of a valid NDP Prefix Information option body MUST be +// ndpPrefixInformationLength bytes. +type NDPPrefixInformation []byte + +// Type implements NDPOption.Type. +func (o NDPPrefixInformation) Type() uint8 { + return NDPPrefixInformationType +} + +// Length implements NDPOption.Length. +func (o NDPPrefixInformation) Length() int { + return ndpPrefixInformationLength +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPPrefixInformation) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the Reserved1 field. + b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask + + // Zero out the Reserved2 field. + reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length] + for i := range reserved2 { + reserved2[i] = 0 + } + + return used +} + +// PrefixLength returns the value in the number of leading bits in the Prefix +// that are valid. +// +// Valid values are in the range [0, 128], but o may not always contain valid +// values. It is up to the caller to valdiate the Prefix Information option. +func (o NDPPrefixInformation) PrefixLength() uint8 { + return o[ndpPrefixInformationPrefixLengthOffset] +} + +// OnLinkFlag returns true of the prefix is considered on-link. On-link means +// that a forwarding node is not needed to send packets to other nodes on the +// same prefix. +// +// Note, when this function returns false, no statement is made about the +// on-link property of a prefix. That is, if OnLinkFlag returns false, the +// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any +// previously stored state for this prefix about its on-link status. +func (o NDPPrefixInformation) OnLinkFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0 +} + +// AutonomousAddressConfigurationFlag returns true if the prefix can be used for +// Stateless Address Auto-Configuration (as specified in RFC 4862). +func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0 +} + +// ValidLifetime returns the length of time that the prefix is valid for the +// purpose of on-link determination. This value is relative to the send time of +// the packet that the Prefix Information option was present in. +// +// Note, a value of 0 implies the prefix should not be considered as on-link, +// and a value of infinity/forever is represented by +// NDPInfiniteLifetime. +func (o NDPPrefixInformation) ValidLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) +} + +// PreferredLifetime returns the length of time that an address generated from +// the prefix via Stateless Address Auto-Configuration remains preferred. This +// value is relative to the send time of the packet that the Prefix Information +// option was present in. +// +// Note, a value of 0 implies that addresses generated from the prefix should +// no longer remain preferred, and a value of infinity is represented by +// NDPInfiniteLifetime. +// +// Also note that the value of this field MUST NOT exceed the Valid Lifetime +// field to avoid preferring addresses that are no longer valid, for the +// purpose of Stateless Address Auto-Configuration. +func (o NDPPrefixInformation) PreferredLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:])) +} + +// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix +// Length field (see NDPPrefixInformation.PrefixLength) contains the number +// of valid leading bits in the prefix. +// +// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field +// holds the link-local prefix (fe80::). +func (o NDPPrefixInformation) Prefix() tcpip.Address { + return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) +} + +// Subnet returns the Prefix field and Prefix Length field represented in a +// tcpip.Subnet. +func (o NDPPrefixInformation) Subnet() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: o.Prefix(), + PrefixLen: int(o.PrefixLength()), + } + return addrWithPrefix.Subnet() +} + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() uint8 { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, some of the addresses returned MAY be link-local addresses. +// +// Addresses may panic if o does not hold valid IPv6 addresses. +func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { + l := len(o) + if l < ndpRecursiveDNSServerAddressesOffset { + return nil + } + + l -= ndpRecursiveDNSServerAddressesOffset + if l%IPv6AddressSize != 0 { + return nil + } + + buf := o[ndpRecursiveDNSServerAddressesOffset:] + var addrs []tcpip.Address + for len(buf) > 0 { + addr := tcpip.Address(buf[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return nil + } + addrs = append(addrs, addr) + buf = buf[IPv6AddressSize:] + } + return addrs +} diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go new file mode 100644 index 000000000..bf7610863 --- /dev/null +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -0,0 +1,112 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "time" +) + +// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.2 for more details. +type NDPRouterAdvert []byte + +const ( + // NDPRAMinimumSize is the minimum size of a valid NDP Router + // Advertisement message (body of an ICMPv6 packet). + NDPRAMinimumSize = 12 + + // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field + // within an NDPRouterAdvert. + ndpRACurrHopLimitOffset = 0 + + // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags + // within an NDPRouterAdvert. + ndpRAFlagsOffset = 1 + + // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address + // Configuration flag within the bit-field/flags byte of an + // NDPRouterAdvert. + ndpRAManagedAddrConfFlagMask = (1 << 7) + + // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag + // within the bit-field/flags byte of an NDPRouterAdvert. + ndpRAOtherConfFlagMask = (1 << 6) + + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime + // field within an NDPRouterAdvert. + ndpRARouterLifetimeOffset = 2 + + // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time + // field within an NDPRouterAdvert. + ndpRAReachableTimeOffset = 4 + + // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer + // field within an NDPRouterAdvert. + ndpRARetransTimerOffset = 8 + + // ndpRAOptionsOffset is the start of the NDP options in an + // NDPRouterAdvert. + ndpRAOptionsOffset = 12 +) + +// CurrHopLimit returns the value of the Curr Hop Limit field. +func (b NDPRouterAdvert) CurrHopLimit() uint8 { + return b[ndpRACurrHopLimitOffset] +} + +// ManagedAddrConfFlag returns the value of the Managed Address Configuration +// flag. +func (b NDPRouterAdvert) ManagedAddrConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0 +} + +// OtherConfFlag returns the value of the Other Configuration flag. +func (b NDPRouterAdvert) OtherConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 +} + +// RouterLifetime returns the lifetime associated with the default router. A +// value of 0 means the source of the Router Advertisement is not a default +// router and SHOULD NOT appear on the default router list. Note, a value of 0 +// only means that the router should not be used as a default router, it does +// not apply to other information contained in the Router Advertisement. +func (b NDPRouterAdvert) RouterLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.2. + return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:])) +} + +// ReachableTime returns the time that a node assumes a neighbor is reachable +// after having received a reachability confirmation. A value of 0 means +// that it is unspecified by the source of the Router Advertisement message. +func (b NDPRouterAdvert) ReachableTime() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:])) +} + +// RetransTimer returns the time between retransmitted Neighbor Solicitation +// messages. A value of 0 means that it is unspecified by the source of the +// Router Advertisement message. +func (b NDPRouterAdvert) RetransTimer() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:])) +} + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterAdvert) Options() NDPOptions { + return NDPOptions(b[ndpRAOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index a431a6e61..2c439d70c 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -17,7 +17,9 @@ package header import ( "bytes" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -35,18 +37,18 @@ func TestNDPNeighborSolicit(t *testing.T) { ns := NDPNeighborSolicit(b) addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") if got := ns.TargetAddress(); got != addr { - t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr) + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) } // Test updating the Target Address. addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") ns.SetTargetAddress(addr2) if got := ns.TargetAddress(); got != addr2 { - t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2) + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) } // Make sure the address got updated in the backing buffer. if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2) + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) } } @@ -64,59 +66,129 @@ func TestNDPNeighborAdvert(t *testing.T) { na := NDPNeighborAdvert(b) addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") if got := na.TargetAddress(); got != addr { - t.Fatalf("got TargetAddress = %s, want %s", got, addr) + t.Errorf("got TargetAddress = %s, want %s", got, addr) } // Test getting the Router Flag. if got := na.RouterFlag(); !got { - t.Fatalf("got RouterFlag = false, want = true") + t.Errorf("got RouterFlag = false, want = true") } // Test getting the Solicited Flag. if got := na.SolicitedFlag(); got { - t.Fatalf("got SolicitedFlag = true, want = false") + t.Errorf("got SolicitedFlag = true, want = false") } // Test getting the Override Flag. if got := na.OverrideFlag(); !got { - t.Fatalf("got OverrideFlag = false, want = true") + t.Errorf("got OverrideFlag = false, want = true") } // Test updating the Target Address. addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") na.SetTargetAddress(addr2) if got := na.TargetAddress(); got != addr2 { - t.Fatalf("got TargetAddress = %s, want %s", got, addr2) + t.Errorf("got TargetAddress = %s, want %s", got, addr2) } // Make sure the address got updated in the backing buffer. if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2) + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) } // Test updating the Router Flag. na.SetRouterFlag(false) if got := na.RouterFlag(); got { - t.Fatalf("got RouterFlag = true, want = false") + t.Errorf("got RouterFlag = true, want = false") } // Test updating the Solicited Flag. na.SetSolicitedFlag(true) if got := na.SolicitedFlag(); !got { - t.Fatalf("got SolicitedFlag = false, want = true") + t.Errorf("got SolicitedFlag = false, want = true") } // Test updating the Override Flag. na.SetOverrideFlag(false) if got := na.OverrideFlag(); got { - t.Fatalf("got OverrideFlag = true, want = false") + t.Errorf("got OverrideFlag = true, want = false") } // Make sure flags got updated in the backing buffer. if got := b[ndpNAFlagsOffset]; got != 64 { - t.Fatalf("got flags byte = %d, want = 64") + t.Errorf("got flags byte = %d, want = 64") } } +func TestNDPRouterAdvert(t *testing.T) { + b := []byte{ + 64, 128, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + + ra := NDPRouterAdvert(b) + + if got := ra.CurrHopLimit(); got != 64 { + t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) + } + + if got := ra.ManagedAddrConfFlag(); !got { + t.Errorf("got ManagedAddrConfFlag = false, want = true") + } + + if got := ra.OtherConfFlag(); got { + t.Errorf("got OtherConfFlag = true, want = false") + } + + if got, want := ra.RouterLifetime(), time.Second*258; got != want { + t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { + t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { + t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + } +} + +// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "TLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "TLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tll := NDPTargetLinkLayerAddressOption(test.buf) + if got := tll.EthernetAddress(); got != test.expected { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) + } + +} + // TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a // NDPTargetLinkLayerAddressOption. func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { @@ -159,6 +231,555 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { if !bytes.Equal(test.buf, test.expectedBuf) { t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + tll := next.(NDPTargetLinkLayerAddressOption) + if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got tll.MACAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPPrefixInformationOption tests the field getters and serialization of a +// NDPPrefixInformation. +func TestNDPPrefixInformationOption(t *testing.T) { + b := []byte{ + 43, 127, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 5, 5, 5, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPPrefixInformation(b), + } + opts.Serialize(serializer) + expectedBuf := []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + if !bytes.Equal(targetBuf, expectedBuf) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + pi := next.(NDPPrefixInformation) + + if got := pi.Type(); got != 3 { + t.Errorf("got Type = %d, want = 3", got) + } + + if got := pi.Length(); got != 30 { + t.Errorf("got Length = %d, want = 30", got) + } + + if got := pi.PrefixLength(); got != 43 { + t.Errorf("got PrefixLength = %d, want = 43", got) + } + + if pi.OnLinkFlag() { + t.Error("got OnLinkFlag = true, want = false") + } + + if !pi.AutonomousAddressConfigurationFlag() { + t.Error("got AutonomousAddressConfigurationFlag = false, want = true") + } + + if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { + t.Errorf("got ValidLifetime = %d, want = %d", got, want) + } + + if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { + t.Errorf("got PreferredLifetime = %d, want = %d", got, want) + } + + if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { + t.Errorf("got Prefix = %s, want = %s", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 25, 3, 0, 0, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPRecursiveDNSServer(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Type(); got != 25 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16909320*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + want := []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + } + if got := opt.Addresses(); !cmp.Equal(got, want) { + t.Errorf("got Addresses = %v, want = %v", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + addrs []tcpip.Address + }{ + { + "Valid1Addr", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + }, + }, + { + "Valid2Addr", + []byte{ + 25, 5, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + }, + }, + { + "Valid3Addr", + []byte{ + 25, 7, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Iterator should get our option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { + t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } }) } } + +// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions +// the iterator was returned for is malformed. +func TestNDPOptionsIterCheck(t *testing.T) { + tests := []struct { + name string + buf []byte + expected error + }{ + { + "ZeroLengthField", + []byte{0, 0, 0, 0, 0, 0, 0, 0}, + ErrNDPOptZeroLength, + }, + { + "ValidTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5, 6}, + nil, + }, + { + "TooSmallTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5}, + ErrNDPOptBufExhausted, + }, + { + "ValidPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "TooSmallPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, + }, + ErrNDPOptBufExhausted, + }, + { + "InvalidPrefixInformationLength", + []byte{ + 3, 3, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + }, + ErrNDPOptMalformedBody, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformation", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up + // being the type for some recognized type, + // update 255 to some other unrecognized value. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "InvalidRecursiveDNSServerCutsOffAddress", + []byte{ + 25, 4, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7, + }, + ErrNDPOptMalformedBody, + }, + { + "InvalidRecursiveDNSServerInvalidLengthField", + []byte{ + 25, 2, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, + }, + ErrNDPInvalidLength, + }, + { + "RecursiveDNSServerTooSmall", + []byte{ + 25, 1, 0, 0, + 0, 0, 0, + }, + ErrNDPOptBufExhausted, + }, + { + "RecursiveDNSServerMulticast", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + ErrNDPOptMalformedBody, + }, + { + "RecursiveDNSServerUnspecified", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + ErrNDPOptMalformedBody, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + + if _, err := opts.Iter(true); err != test.expected { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) + } + + // test.buf may be malformed but we chose not to check + // the iterator so it must return true. + if _, err := opts.Iter(false); err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } + }) + } +} + +// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, +// this test does not actually check any of the option's getters, it simply +// checks the option Type and Body. We have other tests that tests the option +// field gettings given an option body and don't need to duplicate those tests +// here. +func TestNDPOptionsIter(t *testing.T) { + buf := []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up being the type + // for some recognized type, update 255 to some other + // unrecognized value. Note, this option should be skipped when + // iterating. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + opts := NDPOptions(buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Test the first (Taret Link-Layer) option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + + // Test the next (Prefix Information) option. + // Note, the unrecognized option should be skipped. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 97a794986..7dbc05754 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -6,7 +6,7 @@ go_library( name = "channel", srcs = ["channel.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 18adb2085..70188551f 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -25,10 +25,9 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Header buffer.View - Payload buffer.View - Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO + Pkt tcpip.PacketBuffer + Proto tcpip.NetworkProtocolNumber + GSO *stack.GSO } // Endpoint is link layer endpoint that stores outbound packets in a channel @@ -65,14 +64,14 @@ func (e *Endpoint) Drain() int { } } -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) +// InjectInbound injects an inbound packet. +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil)) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -96,7 +95,7 @@ func (e *Endpoint) MTU() uint32 { func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { caps := stack.LinkEndpointCapabilities(0) if e.GSO { - caps |= stack.CapabilityGSO + caps |= stack.CapabilityHardwareGSO } return caps } @@ -118,12 +117,55 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, + Pkt: pkt, + Proto: protocol, + GSO: gso, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + payloadView := pkts[0].Data.ToView() + n := 0 +packetLoop: + for _, pkt := range pkts { + off := pkt.DataOffset + size := pkt.DataSize + p := PacketInfo{ + Pkt: tcpip.PacketBuffer{ + Header: pkt.Header, + Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), + }, + Proto: protocol, + GSO: gso, + } + + select { + case e.C <- p: + n++ + default: + break packetLoop + } + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := PacketInfo{ + Pkt: tcpip.PacketBuffer{Data: vv}, + Proto: 0, + GSO: nil, } select { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 8fa9e3984..897c94821 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,9 +14,7 @@ go_library( "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f80ac3435..fa8a703d9 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -165,6 +165,9 @@ type Options struct { // disabled. GSOMaxSize uint32 + // SoftwareGSOEnabled indicates whether software GSO is enabled or not. + SoftwareGSOEnabled bool + // PacketDispatchMode specifies the type of inbound dispatcher to be // used for this endpoint. PacketDispatchMode PacketDispatchMode @@ -242,7 +245,11 @@ func New(opts *Options) (stack.LinkEndpoint, error) { } if isSocket { if opts.GSOMaxSize != 0 { - e.caps |= stack.CapabilityGSO + if opts.SoftwareGSOEnabled { + e.caps |= stack.CapabilitySoftwareGSO + } else { + e.caps |= stack.CapabilityHardwareGSO + } e.gsoMaxSize = opts.GSOMaxSize } } @@ -379,10 +386,11 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -397,17 +405,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen eth.Encode(ethHdr) } - if e.Capabilities()&stack.CapabilityGSO != 0 { + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { - vnetHdr.hdrLen = uint16(hdr.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -420,18 +428,151 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } } - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView()) + return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + } + + if pkt.Data.Size() == 0 { + return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View()) } - if payload.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], hdr.View()) + return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil) +} + +// WritePackets writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var ethHdrBuf []byte + // hdr + data + iovLen := 2 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } + + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ } - return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) + n := len(pkts) + + views := pkts[0].Data.Views() + /* + * Each bondary in views can add one more iovec. + * + * payload | | | | + * ----------------------------- + * packets | | | | | | | + * ----------------------------- + * iovecs | | | | | | | | | + */ + iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) + mmsgHdrs := make([]rawfile.MMsgHdr, n) + + iovecIdx := 0 + viewIdx := 0 + viewOff := 0 + off := 0 + nextOff := 0 + for i := range pkts { + // TODO(b/134618279): Different packets may have different data + // in the future. We should handle this. + if !viewsEqual(pkts[i].Data.Views(), views) { + panic("All packets in pkts should have the same Data.") + } + + prevIovecIdx := iovecIdx + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovec[iovecIdx] + packetSize := pkts[i].DataSize + hdr := &pkts[i].Header + + off = pkts[i].DataOffset + if off != nextOff { + // We stop in a different point last time. + size := packetSize + viewIdx = 0 + viewOff = 0 + for size > 0 { + if size >= len(views[viewIdx]) { + viewIdx++ + viewOff = 0 + size -= len(views[viewIdx]) + } else { + viewOff = size + size = 0 + } + } + } + nextOff = off + packetSize + + if ethHdrBuf != nil { + v := &iovec[iovecIdx] + v.Base = ðHdrBuf[0] + v.Len = uint64(len(ethHdrBuf)) + iovecIdx++ + } + + v := &iovec[iovecIdx] + hdrView := hdr.View() + v.Base = &hdrView[0] + v.Len = uint64(len(hdrView)) + iovecIdx++ + + for packetSize > 0 { + vec := &iovec[iovecIdx] + iovecIdx++ + + v := views[viewIdx] + vec.Base = &v[viewOff] + s := len(v) - viewOff + if s <= packetSize { + viewIdx++ + viewOff = 0 + } else { + s = packetSize + viewOff += s + } + vec.Len = uint64(s) + packetSize -= s + } + + mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + } + + packets := 0 + for packets < n { + sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] + } + return packets, nil +} + +// viewsEqual tests whether v1 and v2 refer to the same backing bytes. +func viewsEqual(vs1, vs2 []buffer.View) bool { + return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) } -// WriteRawPacket writes a raw packet directly to the file descriptor. -func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } @@ -468,9 +609,9 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher } -// Inject injects an inbound packet. -func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound injects an inbound packet. +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 04406bc9a..2066987eb 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents tcpip.PacketBuffer } type context struct { @@ -92,8 +92,8 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - c.ch <- packetInfo{remote, protocol, vv.ToView()} +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} } func TestNoEthernetProperties(t *testing.T) { @@ -168,7 +168,10 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -258,7 +261,10 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.VectorisedView{}, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -293,11 +299,12 @@ func TestDeliverPacket(t *testing.T) { b[i] = uint8(rand.Intn(256)) } + var hdr header.Ethernet if !eth { // So that it looks like an IPv4 packet. b[0] = 0x40 } else { - hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr = make(header.Ethernet, header.EthernetMinimumSize) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, @@ -315,14 +322,21 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: b, + raddr: raddr, + proto: proto, + contents: tcpip.PacketBuffer{ + Data: buffer.View(b).ToVectorisedView(), + LinkHeader: buffer.View(hdr), + }, } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents.Data will be a single + // view, so make pi do the same for the + // DeepEqual check. + pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 8bfeb97e4..62ed1e569 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) + eth = header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -189,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + Data: buffer.View(pkt).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 7ca217e5b..c67d684ce 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} d.views = make([]buffer.View, len(BufConfig)) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { iovLen++ } d.iovecs = make([]syscall.Iovec, iovLen) @@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { func (d *readVDispatcher) allocateViews(bufConfig []int) { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { if err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // Skip virtioNetHdr which is added before each packet, it // isn't used and it isn't in a view. n -= virtioNetHdrSize @@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[0]) + eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) - vv.TrimFront(d.e.hdrSize) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -194,7 +198,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { } d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // virtioNetHdr is prepended before each packet. iovLen++ } @@ -225,7 +229,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { for k := 0; k < len(d.views); k++ { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -261,7 +265,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } if n <= d.e.hdrSize { @@ -271,9 +275,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -291,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index 47a54845c..f35fcdff4 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -6,10 +6,11 @@ go_library( name = "loopback", srcs = ["loopback.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index b36629d2c..499cc608f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -23,6 +23,7 @@ package loopback import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,21 +71,45 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) - // Because we're immediately turning around and writing the packet back to the - // rx path, we intentionally don't preserve the remote and local link - // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) return nil } -// Wait implements stack.LinkEndpoint.Wait. -func (*endpoint) Wait() {} +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { + return tcpip.ErrBadAddress + } + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + vv.TrimFront(len(linkHeader)) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + Data: vv, + LinkHeader: buffer.View(linkHeader), + }) + + return nil +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1bab380b0..1ac7948b6 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -7,9 +7,7 @@ go_library( name = "muxed", srcs = ["injectable.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 7c946101d..445b22c17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -79,29 +79,47 @@ func (m *InjectableEndpoint) IsAttached() bool { return m.dispatcher != nil } -// Inject implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound implements stack.InjectableLinkEndpoint. +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +} + +// WritePackets writes outbound packets to the appropriate +// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if +// r.RemoteAddress has a route registered in this endpoint. +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + endpoint, ok := m.routes[r.RemoteAddress] + if !ok { + return 0, tcpip.ErrNoRoute + } + return endpoint.WritePackets(r, gso, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, protocol, pkt) } return tcpip.ErrNoRoute } -// WriteRawPacket writes outbound packets to the appropriate +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // WriteRawPacket doesn't get a route or network address, so there's + // nowhere to write this. + return tcpip.ErrNoRoute +} + +// InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { endpoint, ok := m.routes[dest] if !ok { return tcpip.ErrNoRoute } - return endpoint.WriteRawPacket(dest, packet) + return endpoint.InjectOutbound(dest, packet) } // Wait implements stack.LinkEndpoint.Wait. diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3086fec00..63b249837 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -31,7 +31,7 @@ import ( func TestInjectableEndpointRawDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - endpoint.WriteRawPacket(dstIP, []byte{0xFA}) + endpoint.InjectOutbound(dstIP, []byte{0xFA}) buf := make([]byte, ipv4.MaxTotalSize) bytesRead, err := sock.Read(buf) @@ -50,8 +50,10 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -68,8 +70,10 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewView(0).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 2e8bc772a..d8211e93d 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -13,8 +13,9 @@ go_library( "rawfile_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = [ - "//visibility:public", + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@org_golang_x_sys//unix:go_default_library", ], - deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index dda3b10a6..0b5a6cf49 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 7e286a3a6..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -22,6 +22,7 @@ import ( "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { return nil } +// NonBlockingSendMMsg sends multiple messages on a socket. +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { + n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) + if e != 0 { + return 0, TranslateErrno(e) + } + + return int(n), nil +} + // PollEvent represents the pollfd structure passed to a poll() system call. type PollEvent struct { FD int32 diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 0a5ea3dc4..a4f9cdd69 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -12,9 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = [ - "//:sandbox", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 330ed5e94..6b5bc542c 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -12,7 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index de1ce043d..8c9234d54 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -10,7 +10,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip/link/sharedmem/pipe", diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 9e71d4edf..080f9d667 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,9 +185,10 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -199,10 +200,30 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa } eth.Encode(ethHdr) - v := payload.ToView() + v := pkt.Data.ToView() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) + ok := e.tx.transmit(pkt.Header.View(), v) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + v := vv.ToView() + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(v, buffer.View{}) e.mu.Unlock() if !ok { @@ -253,8 +274,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { } // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView()) + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 0e9ba0846..89603c48f 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() { } type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView + addr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + vv buffer.VectorisedView + linkHeader buffer.View } type testContext struct { @@ -130,12 +131,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, proto: proto, - vv: vv.Clone(nil), + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() @@ -272,7 +273,10 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -341,7 +345,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -395,7 +401,10 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -410,7 +419,10 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -435,7 +447,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -455,7 +470,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -470,7 +488,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -493,7 +514,10 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -509,7 +533,10 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -534,7 +561,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -547,7 +577,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: uu, + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -555,7 +588,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 1756114e6..d6ae0368a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -9,9 +9,7 @@ go_library( "sniffer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index e401dce44..3392b7edd 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -49,6 +49,13 @@ var LogPackets uint32 = 1 // LogPacketsToFile must be accessed atomically. var LogPacketsToFile uint32 = 1 +var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{ + header.ICMPv4ProtocolNumber: header.IPv4MinimumSize, + header.ICMPv6ProtocolNumber: header.IPv6MinimumSize, + header.UDPProtocolNumber: header.UDPMinimumSize, + header.TCPProtocolNumber: header.TCPMinimumSize, +} + type endpoint struct { dispatcher stack.NetworkDispatcher lower stack.LinkEndpoint @@ -116,19 +123,19 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, pkt.Data.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() + vs := pkt.Data.Views() + length := pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { panic(err) } for _, v := range vs { @@ -147,7 +154,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local panic(err) } } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher @@ -193,22 +200,19 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket implements the stack.LinkEndpoint interface. It is called by -// higher-level protocols to write packets; it just logs the packet and forwards -// the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View(), gso) + logPacket("send", protocol, pkt.Header.View(), gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := hdr.View() - length := len(hdrBuf) + payload.Size() + hdrBuf := pkt.Header.View() + length := len(hdrBuf) + pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil { panic(err) } if len(hdrBuf) > length { @@ -218,26 +222,75 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen panic(err) } length -= len(hdrBuf) - if length > 0 { - for _, v := range payload.Views() { - if len(v) > length { - v = v[:length] - } - n, err := buf.Write(v) - if err != nil { - panic(err) - } - length -= n - if length == 0 { - break - } - } + logVectorisedView(pkt.Data, length, buf) + if _, err := e.file.Write(buf.Bytes()); err != nil { + panic(err) } + } +} + +// WritePacket implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.dumpPacket(gso, protocol, pkt) + return e.lower.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + view := pkts[0].Data.ToView() + for _, pkt := range pkts { + e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), + }) + } + return e.lower.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { + logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + length := vv.Size() + if length > int(e.maxPCAPLen) { + length = int(e.maxPCAPLen) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + panic(err) + } + logVectorisedView(vv, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } } - return e.lower.WritePacket(r, gso, hdr, payload, protocol) + return e.lower.WriteRawPacket(vv) +} + +func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { + if length <= 0 { + return + } + for _, v := range vv.Views() { + if len(v) > length { + v = v[:length] + } + n, err := buf.Write(v) + if err != nil { + panic(err) + } + length -= n + if length == 0 { + return + } + } } // Wait implements stack.LinkEndpoint.Wait. @@ -287,6 +340,13 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie return } + // We aren't guaranteed to have a transport header - it's possible for + // writes via raw endpoints to contain only network headers. + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) + return + } + // Figure out the transport layer info. transName := "unknown" srcPort := uint16(0) diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 92dce8fac..a71a493fc 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -6,7 +6,5 @@ go_library( name = "tun", srcs = ["tun_unsafe.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0746dc8ec..134837943 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -9,9 +9,7 @@ go_library( "waitable.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/gate", "//pkg/tcpip", diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 5a1791cb5..a8de38979 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -99,12 +99,36 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, protocol, pkt) + e.writeGate.Leave() + return err +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if !e.writeGate.Enter() { + return len(pkts), nil + } + + n, err := e.lower.WritePackets(r, gso, pkts, protocol) + e.writeGate.Leave() + return n, err +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WriteRawPacket(vv) e.writeGate.Leave() return err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index ae23c96b7..31b11a27a 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,18 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.writeCount++ + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += len(pkts) + return len(pkts), nil +} + +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { e.writeCount++ return nil } @@ -78,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -109,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index df0d3a8c0..e7617229b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -7,9 +7,7 @@ go_library( name = "arp", srcs = ["arp.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6b1e854dc..da8482509 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,7 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -58,7 +58,7 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -79,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return @@ -97,23 +102,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -130,12 +137,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ - nicid: nicid, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, }, nil @@ -160,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 88b57ec03..8e6048a21 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -102,19 +102,21 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi := <-c.linkEP.C + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.Header.View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 2cad0a0b6..acf1e022c 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -25,7 +25,7 @@ go_library( "reassembler_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", @@ -44,11 +44,3 @@ go_test( embed = [":fragmentation"], deps = ["//pkg/tcpip/buffer"], ) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f644a8b08..4144a7837 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -150,27 +150,36 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + return tcpip.ErrNotSupported +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, @@ -230,7 +239,10 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -270,7 +282,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -358,7 +372,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -421,13 +437,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -460,7 +480,10 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -500,7 +523,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -510,6 +535,7 @@ func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" cases := []struct { name string expectedCount int @@ -561,7 +587,7 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + SrcAddr: outerSrcAddr, DstAddr: localIpv6Addr, }) @@ -608,8 +634,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 58e537aad..aeddfcdd4 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,9 +10,7 @@ go_library( "ipv4.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 50b363dc4..32bf39e43 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -95,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } @@ -104,19 +112,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 5cd895ff0..e645cf62c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -47,7 +47,7 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -57,9 +57,9 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, @@ -89,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. @@ -117,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -137,71 +138,85 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payload.Size()) + length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be @@ -219,41 +234,71 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() return nil } +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("multiple packets in local loop") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + ip := header.IPv4(pkt.Data.First()) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -267,7 +312,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) @@ -280,35 +325,39 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, pkt.Clone()) } if loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { - if vv.Size() == 0 { + if pkt.Data.Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -316,10 +365,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and vv.size() causes a wrap - // around resulting in last being less than the offset. + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. if last < h.FragmentOffset() { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -327,7 +376,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } var ready bool var err error - vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -340,11 +389,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 85ab0e3bc..e900f1b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -47,10 +47,6 @@ func TestExcludeBroadcast(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: 1, @@ -117,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -136,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -177,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan tcpip.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -187,17 +183,11 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan packetInfo, size), + Ch: make(chan tcpip.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 @@ -212,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -296,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := tcpip.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []tcpip.PacketBuffer L: for { select { @@ -349,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -455,7 +446,7 @@ func TestInvalidFragments(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - const nicid tcpip.NICID = 42 + const nicID tcpip.NICID = 42 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ ipv4.NewProtocol(), @@ -465,10 +456,12 @@ func TestInvalidFragments(t *testing.T) { var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicid, sniffer.New(ep)) + s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f06622a8b..e4e273460 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,9 +10,7 @@ go_library( "ipv6.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7b638e9d0..1c3410618 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -21,21 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -const ( - // ndpHopLimit is the expected IP hop limit value of 255 for received - // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, - // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently - // drop the NDP packet. All outgoing NDP packets must use this value for - // its IP hop limit field. - ndpHopLimit = 255 -) - // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -61,19 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return @@ -81,16 +72,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V h := header.ICMPv6(v) iph := header.IPv6(netHeader) + // Validate ICMPv6 checksum before processing the packet. + // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. + // This copy is used as extra payload during the checksum calculation. + payload := pkt.Data + payload.RemoveFirst() + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. switch h.Type() { case header.ICMPv6NeighborSolicit, header.ICMPv6NeighborAdvert, header.ICMPv6RouterSolicit, header.ICMPv6RouterAdvert, header.ICMPv6RedirectMsg: - if iph.HopLimit() != ndpHopLimit { + if iph.HopLimit() != header.NDPHopLimit { + received.Invalid.Increment() + return + } + + if h.Code() != 0 { received.Invalid.Increment() return } @@ -104,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -114,10 +123,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: @@ -171,7 +180,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // rxNICID so the packet is processed as defined in RFC 4861, // as per RFC 4862 section 5.4.3. - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } @@ -180,9 +189,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), } hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) @@ -200,7 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // TODO(tamird/ghanan): there exists an explicit NDP option that is // used to update the neighbor table with link addresses for a @@ -209,7 +218,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // // Furthermore, the entirety of NDP handling here seems to be // contradicted by RFC 4861. - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // @@ -217,7 +226,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ndpHopLimit, TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return } @@ -255,19 +266,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } - // At this point we know that the targetAddress is not tentaive + // At this point we know that the targetAddress is not tentative // on rxNICID. However, targetAddr may still be assigned to // rxNICID but not tentative (it could be permanent). Such a // scenario is beyond the scope of RFC 4862. As such, we simply // ignore such a scenario for now and proceed as normal. // - // TODO(b/140896005): Handle the scenario described above - // (inform the netstack integration that a duplicate address was - // was detected) + // TODO(b/143147598): Handle the scenario described above. Also + // inform the netstack integration that a duplicate address was + // detected outside of DAD. - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: @@ -276,13 +287,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { sent.Dropped.Increment() return } @@ -294,7 +308,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -306,8 +320,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.RouterSolicit.Increment() case header.ICMPv6RouterAdvert: + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + p := h.NDPPayload() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if len(p) < header.NDPRAMinimumSize { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + ra := header.NDPRouterAdvert(p) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + received.RouterAdvert.Increment() + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() @@ -359,13 +416,15 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: ndpHopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index dd3c4d7c4..335f634d5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -30,7 +30,7 @@ import ( ) const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) @@ -55,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { return nil } @@ -65,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -131,7 +131,7 @@ func TestICMPCounts(t *testing.T) { {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, @@ -143,11 +143,13 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: ndpHopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -274,20 +276,22 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi := <-args.src.C { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -359,3 +363,537 @@ func TestLinkResolution(t *testing.T) { routeICMPv6Packet(t, args, nil) } } + +func TestICMPChecksumValidationSimple(t *testing.T) { + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + "RouterSolicit", + header.ICMPv6RouterSolicit, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + "RouterAdvert", + header.ICMPv6RouterAdvert, + header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + "NeighborSolicit", + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborSolicitMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + "NeighborAdvert", + header.ICMPv6NeighborAdvert, + header.ICMPv6NeighborAdvertSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + "RedirectMsg", + header.ICMPv6RedirectMsg, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index cd1e34085..e13f1fabf 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -43,7 +43,7 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv6 endpoint ID. @@ -97,9 +97,8 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - length := uint16(hdr.UsedLength() + payload.Size()) +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { + length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, @@ -109,14 +108,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -124,36 +135,58 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("not implemented") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + hdr := &pkts[i].Header + size := pkts[i].DataSize + ip := e.addIPHeader(r, hdr, size, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { - // TODO(b/119580726): Support IPv6 header-included packets. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. @@ -188,9 +221,9 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index deaa9b7f3..1cbfa7278 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stats := s.Stats().ICMP.V6PacketsReceived @@ -111,7 +113,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := s.Stats().UDP.PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index e30791fe3..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -97,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -109,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) { {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }}, {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { @@ -150,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) { // Receive the NDP packet with an invalid hop limit // value. - handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) // Invalid count should have increased. if got := invalid.Value(); got != 1 { @@ -164,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) { } // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, ndpHopLimit, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) // Rx count of NDP packet of type typ.typ should have // increased. @@ -179,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) { }) } } + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + } + }) + } +} diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go new file mode 100644 index 000000000..ab24372e7 --- /dev/null +++ b/pkg/tcpip/packet_buffer.go @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // DataOffset is used for GSO output. It is the offset into the Data + // field where the payload of this packet starts. + DataOffset int + + // DataSize is used for GSO output. It is the size of this packet's + // payload. + DataSize int + + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. +func (pk PacketBuffer) Clone() PacketBuffer { + pk.Data = pk.Data.Clone(nil) + return pk +} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go new file mode 100644 index 000000000..ad3cc24fa --- /dev/null +++ b/pkg/tcpip/packet_buffer_state.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 11efb4e44..e156b01f6 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,7 @@ go_library( name = "ports", srcs = ["ports.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 30cea8996..6c5e19e8f 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -41,6 +41,30 @@ type portDescriptor struct { port uint16 } +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool +} + +func (f Flags) bits() reuseFlag { + var rf reuseFlag + if f.MostRecent { + rf |= mostRecentFlag + } + if f.LoadBalanced { + rf |= loadBalancedFlag + } + return rf +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -54,9 +78,59 @@ type PortManager struct { hint uint32 } +type reuseFlag int + +const ( + mostRecentFlag reuseFlag = 1 << iota + loadBalancedFlag + nextFlag + + flagMask = nextFlag - 1 +) + type portNode struct { - reuse bool - refs int + // refs stores the count for each possible flag combination. + refs [nextFlag]int +} + +func (p portNode) totalRefs() int { + var total int + for _, r := range p.refs { + total += r + } + return total +} + +// flagRefs returns the number of references with all specified flags. +func (p portNode) flagRefs(flags reuseFlag) int { + var total int + for i, r := range p.refs { + if reuseFlag(i)&flags == flags { + total += r + } + } + return total +} + +// allRefsHave returns if all references have all specified flags. +func (p portNode) allRefsHave(flags reuseFlag) bool { + for i, r := range p.refs { + if reuseFlag(i)&flags == flags && r > 0 { + return false + } + } + return true +} + +// intersectionRefs returns the set of flags shared by all references. +func (p portNode) intersectionRefs() reuseFlag { + intersection := flagMask + for i, r := range p.refs { + if r > 0 { + intersection &= reuseFlag(i) + } + } + return intersection } // deviceNode is never empty. When it has no elements, it is removed from the @@ -66,30 +140,44 @@ type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all portNodes. If binding to a specific device, check // against the unspecified device and the provided device. -func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { +// +// If either of the port reuse flags is enabled on any of the nodes, all nodes +// sharing a port must share at least one reuse flag. This matches Linux's +// behavior. +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { + flagBits := flags.bits() if bindToDevice == 0 { // Trying to binding all devices. - if !reuse { + if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } + intersection := flagMask for _, p := range d { - if !p.reuse { - // Can't bind because the (addr,port) was previously bound without reuse. + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { + // Can't bind because the (addr,port) was + // previously bound without reuse. return false } } return true } + intersection := flagMask + if p, ok := d[0]; ok { - if !reuse || !p.reuse { + intersection = p.intersectionRefs() + if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - if !reuse || !p.reuse { + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { return false } } @@ -103,12 +191,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -117,14 +205,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -190,17 +278,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice) { return false } } @@ -212,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -227,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } + flagBits := flags.bits() // Reserve port on all network protocols. for _, network := range networks { @@ -250,12 +339,9 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - if n, ok := d[bindToDevice]; ok { - n.refs++ - d[bindToDevice] = n - } else { - d[bindToDevice] = portNode{reuse: reuse, refs: 1} - } + n := d[bindToDevice] + n.refs[flagBits]++ + d[bindToDevice] = n } return true @@ -263,10 +349,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() + flagBits := flags.bits() + for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -278,9 +366,9 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n.refs-- + n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == 0 { + if n.refs == [nextFlag]int{} { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 19f4833fc..d6969d050 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -33,7 +33,7 @@ type portReserveTestAction struct { port uint16 ip tcpip.Address want *tcpip.Error - reuse bool + flags Flags release bool device tcpip.NICID } @@ -50,7 +50,7 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, }, }, { @@ -61,7 +61,7 @@ func TestPortReservation(t *testing.T) { /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -71,36 +71,36 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 00, ip: fakeIPAddress, want: nil}, {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to ip with reuseport", actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to inaddr any with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, }, }, { tname: "bind twice with device fails", @@ -125,88 +125,152 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { - tname: "bind with reuse", + tname: "bind with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { - tname: "binding with reuse and device", + tname: "binding with reuseport and device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, }, }, { - tname: "mixing reuse and not reuse by binding to device", + tname: "mixing reuseport and not reuseport by binding to device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: nil}, }, }, { - tname: "can't bind to 0 after mixing reuse and not reuse", + tname: "can't bind to 0 after mixing reuseport and not reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, }, }, { - tname: "bind twice with reuse once", + tname: "bind twice with reuseport once", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "release an unreserved device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, // Release all. - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, + }, + }, { + tname: "bind with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind twice with reuseaddr once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, }, } { @@ -216,12 +280,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index a57752a7c..d7496fde6 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_connect", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index dad8ef399..875561566 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_echo", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 29b7d761c..b31ddba2f 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -6,7 +6,5 @@ go_library( name = "seqnum", srcs = ["seqnum.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bfc03e90b..69077669a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -31,9 +31,7 @@ go_library( "transport_demuxer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/ilist", "//pkg/rand", @@ -73,6 +71,7 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) @@ -86,11 +85,3 @@ go_test( "//pkg/tcpip", ], ) - -filegroup( - name = "autogen", - srcs = [ - "linkaddrentry_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index bed60d7b1..90664ba8a 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -38,6 +38,34 @@ const ( // Default = 1s (from RFC 4861 section 10). defaultRetransmitTimer = time.Second + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + // + // Default = true. + defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements, as a host. + // + // Default = true. + defaultDiscoverDefaultRouters = true + + // defaultDiscoverOnLinkPrefixes is the default configuration for + // whether or not to discover on-link prefixes from incoming Router + // Advertisements' Prefix Information option, as a host. + // + // Default = true. + defaultDiscoverOnLinkPrefixes = true + + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -49,8 +77,117 @@ const ( // // Min = 1ms. minimumRetransmitTimer = time.Millisecond + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + // + // Max = 10. + MaxDiscoveredDefaultRouters = 10 + + // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered + // on-link prefixes. The stack should stop discovering new on-link + // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + // + // Max = 10. + MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 ) +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour +) + +// NDPDispatcher is the interface integrators of netstack must implement to +// receive and handle NDP related events. +type NDPDispatcher interface { + // OnDuplicateAddressDetectionStatus will be called when the DAD process + // for an address (addr) on a NIC (with ID nicID) completes. resolved + // will be set to true if DAD completed successfully (no duplicate addr + // detected); false otherwise (addr was detected to be a duplicate on + // the link the NIC is a part of, or it was stopped for some other + // reason, such as the address being removed). If an error occured + // during DAD, err will be set and resolved must be ignored. + // + // This function is permitted to block indefinitely without interfering + // with the stack's operation. + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered will be called when a new default router is + // discovered. Implementations must return true if the newly discovered + // router should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + + // OnDefaultRouterInvalidated will be called when a discovered default + // router that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + + // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + + // OnOnLinkPrefixInvalidated will be called when a discovered on-link + // prefix that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) +} + // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { // The number of Neighbor Solicitation messages to send when doing @@ -64,6 +201,31 @@ type NDPConfigurations struct { // // Must be greater than 0.5s. RetransmitTimer time.Duration + + // HandleRAs determines whether or not Router Advertisements will be + // processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers will + // be discovered from Router Advertisements. This configuration is + // ignored if HandleRAs is false. + DiscoverDefaultRouters bool + + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes + // will be discovered from Router Advertisements' Prefix Information + // option. This configuration is ignored if HandleRAs is false. + DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -72,6 +234,10 @@ func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ DupAddrDetectTransmits: defaultDupAddrDetectTransmits, RetransmitTimer: defaultRetransmitTimer, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -88,8 +254,24 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // The NIC this ndpState is for. + nic *NIC + + // configs is the per-interface NDP configurations. + configs NDPConfigurations + // The DAD state to send the next NS message, or resolve the address. dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState + + // The on-link prefixes discovered through Router Advertisements' Prefix + // Information option. + onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The addresses generated by SLAAC. + autoGenAddresses map[tcpip.Address]autoGenAddressState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -105,13 +287,84 @@ type dadState struct { done *bool } +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement (RA). +type defaultRouterState struct { + invalidationTimer *time.Timer + + // Used to inform the timer not to invalidate the default router (R) in + // a race condition (T1 is a goroutine that handles an RA from R and T2 + // is the goroutine that handles R's invalidation timer firing): + // T1: Receive a new RA from R + // T1: Obtain the NIC's lock before processing the RA + // T2: R's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends R's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates R immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate R, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// onLinkPrefixState holds data associated with an on-link prefix discovered by +// a Router Advertisement's Prefix Information option (PI) when the NDP +// configurations was configured to do so. +type onLinkPrefixState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition (T1 is a goroutine that handles a PI for P and T2 + // is the goroutine that handles P's invalidation timer firing): + // T1: Receive a new PI for P + // T1: Obtain the NIC's lock before processing the PI + // T2: P's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends P's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates P immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate P, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// autoGenAddressState holds data associated with an address generated via +// SLAAC. +type autoGenAddressState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the SLAAC address (A) in + // a race condition (T1 is a goroutine that handles a PI for A and T2 + // is the goroutine that handles A's invalidation timer firing): + // T1: Receive a new PI for A + // T1: Obtain the NIC's lock before processing the PI + // T2: A's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends A's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates A immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate A, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool + + // Nonzero only when the address is not valid forever (invalidationTimer + // is not nil). + validUntil time.Time +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently // tentative. // -// The NIC that ndp belongs to (n) MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { return tcpip.ErrAddressFamilyNotSupported @@ -127,13 +380,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, // reference count would have been increased without doing the // work that would have been done for an address that was brand // new. See NIC.addPermanentAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, n.ID())) + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } - remaining := n.stack.ndpConfigs.DupAddrDetectTransmits + remaining := ndp.configs.DupAddrDetectTransmits { - done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref) + done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) if err != nil { return err } @@ -146,42 +399,59 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, var done bool var timer *time.Timer - timer = time.AfterFunc(n.stack.ndpConfigs.RetransmitTimer, func() { - n.mu.Lock() - defer n.mu.Unlock() + timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { + var d bool + var err *tcpip.Error - if done { - // If we reach this point, it means that the DAD timer - // fired after another goroutine already obtained the - // NIC lock and stopped DAD before it this function - // obtained the NIC lock. Simply return here and do - // nothing further. - return - } + // doDadIteration does a single iteration of the DAD loop. + // + // Returns true if the integrator needs to be informed of DAD + // completing. + doDadIteration := func() bool { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() - ref, ok := n.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we are - // still performing DAD on it. If the endpoint does not - // exist, but we are doing DAD on it, then we started - // DAD at some point, but forgot to stop it when the - // endpoint was deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, n.ID())) - } + if done { + // If we reach this point, it means that the DAD + // timer fired after another goroutine already + // obtained the NIC lock and stopped DAD before + // this function obtained the NIC lock. Simply + // return here and do nothing further. + return false + } - if done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref); err != nil || done { - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, n.ID(), err) + ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] + if !ok { + // This should never happen. + // We should have an endpoint for addr since we + // are still performing DAD on it. If the + // endpoint does not exist, but we are doing DAD + // on it, then we started DAD at some point, but + // forgot to stop it when the endpoint was + // deleted. + panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) } - ndp.stopDuplicateAddressDetection(addr) - return - } + d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil || d { + delete(ndp.dad, addr) - timer.Reset(n.stack.ndpConfigs.RetransmitTimer) - remaining-- + if err != nil { + log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) + } + // Let the integrator know DAD has completed. + return true + } + + remaining-- + timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + return false + } + + if doDadIteration() && ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + } }) ndp.dad[addr] = dadState{ @@ -204,11 +474,11 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, // The NIC that ndp belongs to (n) MUST be locked. // // Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { +func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { if ref.getKind() != permanentTentative { // The endpoint should still be marked as tentative // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, n.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) } if remaining == 0 { @@ -219,17 +489,17 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem // Send a new NS. snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := n.endpoints[NetworkEndpointID{snmc}] + snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] if !ok { // This should never happen as if we have the // address, we should have the solicited-node // address. - panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", n.ID(), snmc, addr)) + panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) } // Use the unspecified address as the source address when performing // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, n.linkEP.LinkAddress(), snmcRef, false, false) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) @@ -239,7 +509,9 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: DefaultTOS}); err != nil { + if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return false, err } @@ -275,5 +547,611 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) - return + // Let the integrator know DAD did not resolve. + if ndp.nic.stack.ndpDisp != nil { + go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + } +} + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the NIC configured to handle RAs at all? + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a stack-wide configuration, so we check + // stack's forwarding flag to determine if the NIC is a routing + // interface. + if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + return + } + + // Is the NIC configured to discover default routers? + if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation timer. + timer := rtr.invalidationTimer + + // We should ALWAYS have an invalidation timer for a + // discovered router. + if timer == nil { + panic("ndphandlera: RA invalidation timer should not be nil") + } + + if !timer.Stop() { + // If we reach this point, then we know the + // timer fired after we already took the NIC + // lock. Inform the timer not to invalidate the + // router when it obtains the lock as we just + // got a new RA that refreshes its lifetime to a + // non-zero value. See + // defaultRouterState.doNotInvalidate for more + // details. + *rtr.doNotInvalidate = true + } + + timer.Reset(rl) + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + + // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter + // Discovery. + + // We know the options is valid as far as wire format is concerned since + // we got the Router Advertisement, as documented by this fn. Given this + // we do not check the iterator for errors on calls to Next. + it, _ := ra.Options().Iter(false) + for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) + + case header.NDPPrefixInformation: + prefix := opt.Subnet() + + // Is the prefix a link-local? + if header.IsV6LinkLocalAddress(prefix.ID()) { + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). + continue + } + + // Is the Prefix Length 0? + if prefix.Prefix() == 0 { + // ...Yes, skip as this is an invalid prefix + // as all IPv6 addresses cannot be on-link. + continue + } + + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) + } + + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) + } + } + + // TODO(b/141556115): Do (MTU) Parameter Discovery. + } +} + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationTimer.Stop() + rtr.invalidationTimer = nil + *rtr.doNotInvalidate = true + rtr.doNotInvalidate = nil + + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered a default router. + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition. See defaultRouterState.doNotInvalidate for more + // details. + var doNotInvalidate bool + + ndp.defaultRouters[ip] = defaultRouterState{ + invalidationTimer: time.AfterFunc(rl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if doNotInvalidate { + doNotInvalidate = false + return + } + + ndp.invalidateDefaultRouter(ip) + }), + doNotInvalidate: &doNotInvalidate, + } +} + +// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 +// address with prefix prefix with lifetime l. +// +// The prefix identified by prefix MUST NOT already be known. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered an on-link prefix. + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + // Informed by the integrator to not remember the prefix, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition. See onLinkPrefixState.doNotInvalidate for more + // details. + var doNotInvalidate bool + var timer *time.Timer + + // Only create a timer if the lifetime is not infinite. + if l < header.NDPInfiniteLifetime { + timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + } + + ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ + invalidationTimer: timer, + doNotInvalidate: &doNotInvalidate, + } +} + +// invalidateOnLinkPrefix invalidates a discovered on-link prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { + s, ok := ndp.onLinkPrefixes[prefix] + + // Is the on-link prefix still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + if s.invalidationTimer != nil { + s.invalidationTimer.Stop() + s.invalidationTimer = nil + *s.doNotInvalidate = true + } + + s.doNotInvalidate = nil + + delete(ndp.onLinkPrefixes, prefix) + + // Let the integrator know a discovered on-link prefix is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + } +} + +// prefixInvalidationCallback returns a new on-link prefix invalidation timer +// for prefix that fires after vl. +// +// doNotInvalidate is used to signal the timer when it fires at the same time +// that a prefix's valid lifetime gets refreshed. See +// onLinkPrefixState.doNotInvalidate for more details. +func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateOnLinkPrefix(prefix) + }) +} + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // Update the invalidation timer. + timer := prefixState.invalidationTimer + + if timer == nil && vl >= header.NDPInfiniteLifetime { + // Had infinite valid lifetime before and + // continues to have an invalid lifetime. Do + // nothing further. + return + } + + if timer != nil && !timer.Stop() { + // If we reach this point, then we know the timer alread fired + // after we took the NIC lock. Inform the timer to not + // invalidate the prefix once it obtains the lock as we just + // got a new PI that refreshes its lifetime to a non-zero value. + // See onLinkPrefixState.doNotInvalidate for more details. + *prefixState.doNotInvalidate = true + } + + if vl >= header.NDPInfiniteLifetime { + // Prefix is now valid forever so we don't need + // an invalidation timer. + prefixState.invalidationTimer = nil + ndp.onLinkPrefixes[prefix] = prefixState + return + } + + if timer != nil { + // We already have a timer so just reset it to + // expire after the new valid lifetime. + timer.Reset(vl) + return + } + + // We do not have a timer so just create a new one. + prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already have an auto-generated address for prefix. + for _, ref := range ndp.nic.endpoints { + if ref.protocol != header.IPv6ProtocolNumber { + continue + } + + if ref.configType != slaac { + continue + } + + addr := ref.ep.ID().LocalAddress + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + if refAddrWithPrefix.Subnet() != prefix { + continue + } + + // + // At this point, we know we are refreshing a SLAAC generated + // IPv6 address with the prefix, prefix. Do the work as outlined + // by RFC 4862 section 5.5.3.e. + // + + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) + } + + // TODO(b/143713887): Handle deprecating auto-generated address + // after the preferred lifetime. + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the + // address generated by SLAAC is as follows: + // + // 1) If the received Valid Lifetime is greater than 2 hours or + // greater than RemainingLifetime, set the valid lifetime of + // the address to the advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, + // ignore the advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 + // hours. + + // Handle the infinite valid lifetime separately as we do not + // keep a timer in this case. + if vl >= header.NDPInfiniteLifetime { + if addrState.invalidationTimer != nil { + // Valid lifetime was finite before, but now it + // is valid forever. + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer = nil + addrState.validUntil = time.Time{} + ndp.autoGenAddresses[addr] = addrState + } + + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, + // assume the remaining time to be the maximum possible value. + if addrState.invalidationTimer == nil { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + ndp.autoGenAddresses[addr] = addrState + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if addrState.invalidationTimer == nil { + addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) + } else { + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer.Reset(effectiveVl) + } + + addrState.validUntil = time.Now().Add(effectiveVl) + ndp.autoGenAddresses[addr] = addrState + return + } + + // We do not already have an address within the prefix, prefix. Do the + // work as outlined by RFC 4862 section 5.5.3.d if n is configured + // to auto-generated global addresses by SLAAC. + + // Are we configured to auto-generate new global addresses? + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } + + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. + addrBytes := make([]byte, header.IPv6AddressSize) + copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.Address(addrBytes) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + } + + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr) { + return + } + + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + // Informed by the integrator not to add the address. + return + } + + if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + }, FirstPrimaryEndpoint, permanent, slaac); err != nil { + panic(err) + } + + // Setup the timers to deprecate and invalidate this newly generated + // address. + + // TODO(b/143713887): Handle deprecating auto-generated addresses + // after the preferred lifetime. + + var doNotInvalidate bool + var vTimer *time.Timer + if vl < header.NDPInfiniteLifetime { + vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + } + + ndp.autoGenAddresses[addr] = autoGenAddressState{ + invalidationTimer: vTimer, + doNotInvalidate: &doNotInvalidate, + validUntil: time.Now().Add(vl), + } +} + +// invalidateAutoGenAddress invalidates an auto-generated address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { + if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { + return + } + + ndp.nic.removePermanentAddressLocked(addr) +} + +// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated +// address's resources from ndp. If the stack has an NDP dispatcher, it will +// be notified that addr has been invalidated. +// +// Returns true if ndp had resources for addr to cleanup. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { + state, ok := ndp.autoGenAddresses[addr] + + if !ok { + return false + } + + if state.invalidationTimer != nil { + state.invalidationTimer.Stop() + state.invalidationTimer = nil + *state.doNotInvalidate = true + } + + state.doNotInvalidate = nil + + delete(ndp.autoGenAddresses, addr) + + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } + + return true +} + +// autoGenAddrInvalidationTimer returns a new invalidation timer for an +// auto-generated address that fires after vl. +// +// doNotInvalidate is used to inform the timer when it fires at the same time +// that an auto-generated address's valid lifetime gets refreshed. See +// autoGenAddrState.doNotInvalidate for more details. +func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateAutoGenAddress(addr) + }) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index cc789d70f..666f86c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,9 +15,12 @@ package stack_test import ( + "encoding/binary" + "fmt" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -29,12 +32,46 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = "\x01\x02\x03\x04\x05\x06" - linkAddr2 = "\x01\x02\x03\x04\x05\x07" + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + linkAddr1 = "\x02\x02\x03\x04\x05\x06" + linkAddr2 = "\x02\x02\x03\x04\x05\x07" + linkAddr3 = "\x02\x02\x03\x04\x05\x08" + defaultTimeout = 100 * time.Millisecond ) +var ( + llAddr1 = header.LinkLocalAddr(linkAddr1) + llAddr2 = header.LinkLocalAddr(linkAddr2) + llAddr3 = header.LinkLocalAddr(linkAddr3) +) + +// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent +// tcpip.Subnet, and an address where the lower half of the address is composed +// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. +func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { + prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixBytes), + PrefixLen: 64, + } + + subnet := prefix.Subnet() + + var addr tcpip.AddressWithPrefix + if header.IsValidUnicastEthernetAddress(linkAddr) { + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + } + + return prefix, subnet, addr +} + // TestDADDisabled tests that an address successfully resolves immediately // when DAD is not enabled (the default for an empty stack.Options). func TestDADDisabled(t *testing.T) { @@ -42,7 +79,7 @@ func TestDADDisabled(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -68,6 +105,160 @@ func TestDADDisabled(t *testing.T) { } } +// ndpDADEvent is a set of parameters that was passed to +// ndpDispatcher.OnDuplicateAddressDetectionStatus. +type ndpDADEvent struct { + nicID tcpip.NICID + addr tcpip.Address + resolved bool + err *tcpip.Error +} + +type ndpRouterEvent struct { + nicID tcpip.NICID + addr tcpip.Address + // true if router was discovered, false if invalidated. + discovered bool +} + +type ndpPrefixEvent struct { + nicID tcpip.NICID + prefix tcpip.Subnet + // true if prefix was discovered, false if invalidated. + discovered bool +} + +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + +var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) + +// ndpDispatcher implements NDPDispatcher so tests can know when various NDP +// related events happen for test purposes. +type ndpDispatcher struct { + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent +} + +// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { + if n.dadC != nil { + n.dadC <- ndpDADEvent{ + nicID, + addr, + resolved, + err, + } + } +} + +// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + true, + } + } + + return n.rememberRouter +} + +// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + false, + } + } +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + true, + } + } + + return n.rememberPrefix +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + false, + } + } +} + +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } + } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -88,9 +279,17 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits @@ -107,8 +306,7 @@ func TestDADResolve(t *testing.T) { stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - // Should have sent an NDP NS almost immediately. - time.Sleep(100 * time.Millisecond) + // Should have sent an NDP NS immediately. if got := stat.Value(); got != 1 { t.Fatalf("got NeighborSolicit = %d, want = 1", got) @@ -124,16 +322,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the remaining time - 500ms, to make sure - // the address is still not resolved. Note, we subtract - // 600ms because we already waited for 100ms earlier, - // so our remaining time is 100ms less than the expected - // time. - // (X - 100ms) - 500ms = X - 600ms - // - // TODO(b/140896005): Use events from the netstack to - // be signalled before DAD resolves. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - 600*time.Millisecond) + // Wait for the remaining time - some delta (500ms), to + // make sure the address is still not resolved. + const delta = 500 * time.Millisecond + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -142,13 +334,30 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the remaining time + 250ms, at which point - // the address should be resolved. Note, the remaining - // time is 500ms. See above comments. - // - // TODO(b/140896005): Use events from the netstack to - // know immediately when DAD completes. - time.Sleep(750 * time.Millisecond) + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) @@ -172,7 +381,8 @@ func TestDADResolve(t *testing.T) { } // Check NDP packet. - checker.IPv6(t, p.Header.ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1))) } @@ -250,13 +460,18 @@ func TestDADFail(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.DefaultNDPConfigurations(), + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = time.Second * 2 - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -279,15 +494,37 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { t.Fatalf("got stat = %d, want = 1", got) } - // Wait 3 seconds to make sure that DAD did not resolve - time.Sleep(3 * time.Second) + // Wait for DAD to fail and make sure the address did + // not get resolved. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -302,13 +539,20 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, } - opts.NDPConfigs.RetransmitTimer = time.Second - opts.NDPConfigs.DupAddrDetectTransmits = 2 - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -332,11 +576,27 @@ func TestDADStop(t *testing.T) { t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) } - // Wait for the time to normally resolve - // DupAddrDetectTransmits(2) * RetransmitTimer(1s) = 2s. - // An extra 250ms is added to make sure that if DAD was still running - // it resolves and the check below fails. - time.Sleep(2*time.Second + 250*time.Millisecond) + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -350,3 +610,1448 @@ func TestDADStop(t *testing.T) { t.Fatalf("got NeighborSolicit = %d, want <= 1", got) } } + +// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NDP configurations using an invalid NICID. +func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) + } +} + +// TestSetNDPConfigurations tests that we can update and use per-interface NDP +// configurations without affecting the default NDP configurations or other +// interfaces' configurations. +func TestSetNDPConfigurations(t *testing.T) { + tests := []struct { + name string + dupAddrDetectTransmits uint8 + retransmitTimer time.Duration + expectedRetransmitTimer time.Duration + }{ + { + "OK", + 1, + time.Second, + time.Second, + }, + { + "Invalid Retransmit Timer", + 1, + 0, + time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + }) + + // This NIC(1)'s NDP configurations will be updated to + // be different from the default. + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Created before updating NIC(1)'s NDP configurations + // but updating NIC(1)'s NDP configurations should not + // affect other existing NICs. + if err := s.CreateNIC(2, e); err != nil { + t.Fatalf("CreateNIC(2) = %s", err) + } + + // Update the NDP configurations on NIC(1) to use DAD. + configs := stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + RetransmitTimer: test.retransmitTimer, + } + if err := s.SetNDPConfigurations(1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + } + + // Created after updating NIC(1)'s NDP configurations + // but the stack's default NDP configurations should not + // have been updated. + if err := s.CreateNIC(3, e); err != nil { + t.Fatalf("CreateNIC(3) = %s", err) + } + + // Add addresses for each NIC. + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + } + if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + } + + // Address should not be considered bound to NIC(1) yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Should get the address on NIC(2) and NIC(3) + // immediately since we should not have performed DAD on + // it as the stack was configured to not do DAD by + // default and we only updated the NDP configurations on + // NIC(1). + addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + } + if addr.Address != addr2 { + t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + } + addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + } + if addr.Address != addr3 { + t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + } + + // Sleep until right (500ms before) before resolution to + // make sure the address didn't resolve on NIC(1) yet. + const delta = 500 * time.Millisecond + time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + } + }) + } +} + +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(0) + ra := header.NDPRouterAdvert(pkt.NDPPayload()) + opts := ra.Options() + opts.Serialize(optSer) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + iph.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} +} + +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) +} + +// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix +// Information option. +// +// Note, raBufWithPI does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { + flags := uint8(0) + if onLink { + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 + } + + // A valid header.NDPPrefixInformation must be 30 bytes. + buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. + buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. + buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. + copy(buf[14:], prefix.Address) + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ + header.NDPPrefixInformation(buf[:]), + }) +} + +// TestNoRouterDiscovery tests that router discovery will not be performed if +// configured not to. +func TestNoRouterDiscovery(t *testing.T) { + // Being configured to discover routers means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // router discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// discovered flag set to discovered. +func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { + return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered router when the dispatcher asks it not to. +func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA for a router we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr2, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the router in the first place. + select { + case <-ndpDisp.routerC: + t.Fatal("should not have received any router events") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +func TestRouterDiscovery(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } + + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from another router (lladdr3) with non-zero lifetime. + l3Lifetime := time.Duration(6) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + expectRouterEvent(llAddr3, true) + + // Rx an RA from lladdr2 with lesser lifetime. + l2Lifetime := time.Duration(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } + + // Wait for lladdr2's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) +} + +// TestRouterDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +func TestRouterDiscoveryMaxRouters(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA from 2 more than the max number of discovered routers. + for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + linkAddr := []byte{2, 2, 3, 4, 5, 0} + linkAddr[5] = byte(i) + llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) + + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + + if i <= stack.MaxDiscoveredDefaultRouters { + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + + } else { + select { + case <-ndpDisp.routerC: + t.Fatal("should not have discovered a new router after we already discovered the max number of routers") + default: + } + } + } +} + +// TestNoPrefixDiscovery tests that prefix discovery will not be performed if +// configured not to. +func TestNoPrefixDiscovery(t *testing.T) { + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + + // Being configured to discover prefixes means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // prefix discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) + + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for prefix on nic with ID 1, and the +// discovered flag set to discovered. +func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered on-link prefix when the dispatcher asks it not to. +func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + + prefix, subnet, _ := prefixSubnetAddr(0, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have received any prefix events") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +func TestPrefixDiscovery(t *testing.T) { + t.Parallel() + + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) + + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) + + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) + + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation timer plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) +} + +func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { + // Update the infinite lifetime value to a smaller value so we can test + // that when we receive a PI with such a lifetime value, we do not + // invalidate the prefix. + const testInfiniteLifetimeSeconds = 2 + const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime + defer func() { + header.NDPInfiniteLifetime = saved + }() + + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + subnet := prefix.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix in an NDP Prefix Information option (PI) + // with infinite valid lifetime which should not get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + expectPrefixEvent(subnet, true) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with finite lifetime. + // The prefix should get invalidated after 1s. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(testInfiniteLifetime): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive an RA with finite lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + expectPrefixEvent(subnet, true) + + // Receive an RA with prefix with an infinite lifetime. + // The prefix should not be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with a prefix with a lifetime value greater than the + // set infinite lifetime value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): + } + + // Receive an RA with 0 lifetime. + // The prefix should get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) + expectPrefixEvent(subnet, false) +} + +// TestPrefixDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + + // Receive an RA with 2 more than the max number of discovered on-link + // prefixes. + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} + prefixAddr[7] = byte(i) + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixAddr[:]), + PrefixLen: 64, + } + prefixes[i] = prefix.Subnet() + buf := [30]byte{} + buf[0] = uint8(prefix.PrefixLen) + buf[1] = 128 + binary.BigEndian.PutUint32(buf[2:], 10) + copy(buf[14:], prefix.Address) + + optSer[i] = header.NDPPrefixInformation(buf[:]) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < stack.MaxDiscoveredOnLinkPrefixes { + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } else { + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") + default: + } + } + } +} + +// Checks to see if list contains an IPv6 address, item. +func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + for _, i := range list { + if i == protocolAddress { + return true + } + } + + return false +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. +func TestAutoGenAddr(t *testing.T) { + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } +} + +// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an +// auto-generated address only gets updated when required to, as specified in +// RFC 4862 section 5.5.3.e. +func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { + const infiniteVL = 4294967295 + const newMinVL = 5 + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + ovl uint32 + nvl uint32 + evl uint32 + }{ + // Should update the VL to the minimum VL for updating if the + // new VL is less than newMinVL but was originally greater than + // it. + { + "LargeVLToVLLessThanMinVLForUpdate", + 9999, + 1, + newMinVL, + }, + { + "LargeVLTo0", + 9999, + 0, + newMinVL, + }, + { + "InfiniteVLToVLLessThanMinVLForUpdate", + infiniteVL, + 1, + newMinVL, + }, + { + "InfiniteVLTo0", + infiniteVL, + 0, + newMinVL, + }, + + // Should not update VL if original VL was less than newMinVL + // and the new VL is also less than newMinVL. + { + "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", + newMinVL - 1, + newMinVL - 3, + newMinVL - 1, + }, + + // Should take the new VL if the new VL is greater than the + // remaining time or is greater than newMinVL. + { + "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", + newMinVL + 5, + newMinVL + 3, + newMinVL + 3, + }, + { + "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", + newMinVL - 3, + newMinVL - 1, + newMinVL - 1, + }, + { + "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", + newMinVL - 3, + newMinVL + 1, + newMinVL + 1, + }, + } + + const delta = 500 * time.Millisecond + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + + // + // Validate that the VL for the address got set + // to test.evl. + // + + // Make sure we do not get any invalidation + // events until atleast 500ms (delta) before + // test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): + } + + // Wait for another second (2x delta), but now + // we expect the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + +// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed +// by the user, its resources will be cleaned up and an invalidation event will +// be sent to the integrator. +func TestAutoGenAddrRemoval(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive a PI to auto-generate an address. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr, newAddr) + + // Removing the address should result in an invalidation event + // immediately. + if err := s.RemoveAddress(1, addr.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // Wait for the original valid lifetime to make sure the original timer + // got stopped/cleaned up. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that +// is already assigned to the NIC, the static address remains. +func TestAutoGenAddrStaticConflict(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Add the address as a static address before SLAAC tries to add it. + if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive a PI where the generated address will be the same as the one + // that we already have assigned statically. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") + default: + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Should not get an invalidation event after the PI's invalidation + // time. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } +} + +// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event +// to the integrator when an RA is received with the NDP Recursive DNS Server +// option with at least one valid address. +func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. + select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43f4ad91e..e8401c673 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -19,7 +19,6 @@ import ( "sync" "sync/atomic" - "gvisor.dev/gvisor/pkg/ilist" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -37,13 +36,20 @@ type NIC struct { mu sync.RWMutex spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint stats NICStats + // ndp is the NDP related state for NIC. + // + // Note, read and write operations on ndp require that the NIC is + // appropriately locked. ndp ndpState } @@ -78,16 +84,26 @@ const ( NeverPrimaryEndpoint ) +// newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { - return &NIC{ + // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For + // example, make sure that the link address it provides is a valid + // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + + nic := &NIC{ stack: stack, id: id, name: name, linkEP: ep, loopback: loopback, - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), + packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -99,9 +115,24 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - dad: make(map[tcpip.Address]dadState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), }, } + nic.ndp.nic = nic + + // Register supported packet endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.packetEPs[netProto] = []PacketEndpoint{} + } + for _, netProto := range stack.networkProtocols { + nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + } + + return nic } // enable enables the NIC. enable will attach the link to its LinkEndpoint and @@ -126,11 +157,50 @@ func (n *NIC) enable() *tcpip.Error { // when we perform Duplicate Address Detection, or Router Advertisement // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 // section 4.2 for more information. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + // + // Also auto-generate an IPv6 link-local address based on the NIC's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] + if !ok { + return nil } - return nil + n.mu.Lock() + defer n.mu.Unlock() + + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + if !n.stack.autoGenIPv6LinkLocal { + return nil + } + + l2addr := n.linkEP.LinkAddress() + + // Only attempt to generate the link-local address if we have a + // valid MAC address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } + + addr := header.LinkLocalAddr(l2addr) + + _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + }, + }, CanBePrimaryEndpoint) + + return err } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -166,18 +236,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN n.mu.RLock() defer n.mu.RUnlock() - list := n.primary[protocol] - if list == nil { - return nil - } - - for e := list.Front(); e != nil; e = e.Next() { - r := e.(*referencedNetworkEndpoint) - // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set. - switch r.ep.ID().LocalAddress { - case header.IPv4Broadcast, header.IPv4Any: - continue - } + for _, r := range n.primary[protocol] { if r.isValidForOutgoing() && r.tryIncRef() { return r } @@ -186,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative +} + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) } @@ -277,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) + }, peb, temporary, static) n.mu.Unlock() return ref @@ -291,9 +364,31 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent. + // Promote the endpoint to become permanent and respect + // the new peb. if ref.tryIncRef() { ref.setKind(permanent) + + refs := n.primary[ref.protocol] + for i, r := range refs { + if r == ref { + switch peb { + case CanBePrimaryEndpoint: + return ref, nil + case FirstPrimaryEndpoint: + if i == 0 { + return ref, nil + } + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + case NeverPrimaryEndpoint: + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + return ref, nil + } + } + } + + n.insertPrimaryEndpointLocked(ref, peb) + return ref, nil } // tryIncRef failing means the endpoint is scheduled to be removed once @@ -304,10 +399,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent) + return n.addAddressLocked(protocolAddress, peb, permanent, static) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -337,11 +432,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, } // Set up cache if link address resolution exists for this protocol. @@ -362,22 +458,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar n.endpoints[id] = ref - l, ok := n.primary[protocolAddress.Protocol] - if !ok { - l = &ilist.List{} - n.primary[protocolAddress.Protocol] = l - } - - switch peb { - case CanBePrimaryEndpoint: - l.PushBack(ref) - case FirstPrimaryEndpoint: - l.PushFront(ref) - } + n.insertPrimaryEndpointLocked(ref, peb) // If we are adding a tentative IPv6 address, start DAD. if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(n, protocolAddress.AddressWithPrefix.Address, ref); err != nil { + if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -430,8 +515,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for proto, list := range n.primary { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) + for _, ref := range list { // Don't include tentative, expired or tempory endpoints // to avoid confusion and prevent the caller from using // those. @@ -496,6 +580,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { return append(sns, n.addressRanges...) } +// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required +// by peb. +// +// n MUST be locked. +func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { + switch peb { + case CanBePrimaryEndpoint: + n.primary[r.protocol] = append(n.primary[r.protocol], r) + case FirstPrimaryEndpoint: + n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + } +} + func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() @@ -513,9 +610,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { } delete(n.endpoints, id) - wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() - if wasInList { - n.primary[r.protocol].Remove(r) + refs := n.primary[r.protocol] + for i, ref := range refs { + if ref == r { + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + break + } } r.ep.Close() @@ -540,9 +640,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) - // If we are removing a tentative IPv6 unicast address, stop DAD. - if isIPv6Unicast && kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + if isIPv6Unicast { + // If we are removing a tentative IPv6 unicast address, stop + // DAD. + if kind == permanentTentative { + n.ndp.stopDuplicateAddressDetection(addr) + } + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + if r.configType == slaac { + n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + } } r.setKind(permanentExpired) @@ -585,6 +694,11 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // exists yet. Otherwise it just increments its count. n MUST be locked before // joinGroupLocked is called. func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + // TODO(b/143102137): When implementing MLD, make sure MLD packets are + // not sent unless a valid link-local address is available for use on n + // as an MLD packet's source address must be a link-local address as + // outlined in RFC 3810 section 5. + id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -635,10 +749,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -648,9 +762,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -658,19 +772,39 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr return } + // If no local link layer address is provided, assume it was sent + // directly to this NIC. + if local == "" { + local = n.linkEP.LinkAddress() + } + + // Are any packet sockets listening for this network protocol? + n.mu.RLock() + packetEPs := n.packetEPs[protocol] + // Check whether there are packet sockets listening for every protocol. + // If we received a packet with protocol EthernetProtocolAll, then the + // previous for loop will have handled it. + if protocol != header.EthernetProtocolAll { + packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + } + n.mu.RUnlock() + for _, ep := range packetEPs { + ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + } + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(vv.First()) + src, dst := netProto.ParseAddresses(pkt.Data.First()) if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return } @@ -698,31 +832,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) - vv.RemoveFirst() + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } } return } - n.stack.stats.IP.InvalidAddressesReceived.Increment() + // If a packet socket handled the packet, don't treat it as invalid. + if len(packetEPs) == 0 { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + } } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -734,41 +871,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) + n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(vv.First()) < transProto.MinimumPacketSize() { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -779,17 +916,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } @@ -838,6 +975,26 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return n.removePermanentAddressLocked(addr) } +// setNDPConfigs sets the NDP configurations for n. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNDPConfigs(c NDPConfigurations) { + c.validate() + + n.mu.Lock() + n.ndp.configs = c + n.mu.Unlock() +} + +// handleNDPRA handles an NDP Router Advertisement message that arrived on n. +func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.handleRA(ip, ra) +} + type networkEndpointKind int32 const ( @@ -857,7 +1014,7 @@ const ( // removing the permanent address from the NIC. permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. If its address is re-added before the endpoint is removed, its type @@ -873,8 +1030,50 @@ const ( temporary ) +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return tcpip.ErrNotSupported + } + n.packetEPs[netProto] = append(eps, ep) + + return nil +} + +func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return + } + + for i, epOther := range eps { + if epOther == ep { + n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + return + } + } +} + +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A slaac configured endpoint is an IPv6 endpoint that was + // added by SLAAC as per RFC 4862 section 5.5.3. + slaac +) + type referencedNetworkEndpoint struct { - ilist.Entry ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -889,6 +1088,10 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change after the endpoint is added to a NIC. + configType networkEndpointConfigType } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 9d6157f22..61fd46d66 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -60,24 +60,64 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // this transport endpoint. It sets pkt.TransportHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw // transport protocol endpoints. RawTransportEndpoints receive the entire -// packet - including the link, network, and transport headers - as delivered -// to netstack. +// packet - including the network and transport headers - as delivered to +// netstack. type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) +} + +// PacketEndpoint is the interface that needs to be implemented by packet +// transport protocol endpoints. These endpoints receive link layer headers in +// addition to whatever they contain (usually network and transport layer +// headers and a payload). +type PacketEndpoint interface { + // HandlePacket is called by the stack when new packets arrive that + // match the endpoint. + // + // Implementers should treat packet as immutable and should copy it + // before before modification. + // + // linkHeader may have a length of 0, in which case the PacketEndpoint + // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of pkt. + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -107,7 +147,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -125,13 +167,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. + // + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -182,12 +232,17 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error + // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have + // already been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + + // WritePackets writes packets to the given destination address and + // protocol. pkts must not be zero length. + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -199,8 +254,10 @@ type NetworkEndpoint interface { NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. - HandlePacket(r *Route, vv buffer.VectorisedView) + // this network endpoint. It sets pkt.NetworkHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -225,7 +282,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -242,9 +299,15 @@ type NetworkProtocol interface { // packets to the appropriate network endpoint after it has been handled by // the data link layer. type NetworkDispatcher interface { - // DeliverNetworkPacket finds the appropriate network protocol - // endpoint and hands the packet over for further processing. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // DeliverNetworkPacket finds the appropriate network protocol endpoint + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverNetworkPacket takes ownership of pkt. + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -266,12 +329,18 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityGSO + CapabilityHardwareGSO + + // CapabilitySoftwareGSO indicates the link endpoint supports of sending + // multiple packets using a single call (LinkEndpoint.WritePackets). + CapabilitySoftwareGSO ) // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. +// out through the implementer's data link endpoint. When a link header exists, +// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a @@ -293,13 +362,27 @@ type LinkEndpoint interface { // link endpoint. LinkAddress() tcpip.LinkAddress - // WritePacket writes a packet with the given protocol through the given - // route. + // WritePacket writes a packet with the given protocol through the + // given route. It sets pkt.LinkHeader if a link layer header exists. + // pkt.NetworkHeader and pkt.TransportHeader must have already been + // set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error + + // WritePackets writes packets with the given protocol through the + // given route. pkts must not be zero length. + // + // Right now, WritePackets is used only when the software segmentation + // offload is enabled. If it will be used for something else, it may + // require to change syscall filters. + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + + // WriteRawPacket writes a packet directly to the link. The packet + // should already have an ethernet header. + WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. @@ -324,13 +407,14 @@ type LinkEndpoint interface { type InjectableLinkEndpoint interface { LinkEndpoint - // Inject injects an inbound packet. - Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // InjectInbound injects an inbound packet. + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) - // WriteRawPacket writes a fully formed outbound packet directly to the link. + // InjectOutbound writes a fully formed outbound packet directly to the + // link. // // dest is used by endpoints with multiple raw destinations. - WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -359,10 +443,10 @@ type LinkAddressResolver interface { type LinkAddressCache interface { // CheckLocalAddress determines if the given local address exists, and if it // does not exist. - CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver @@ -373,17 +457,22 @@ type LinkAddressCache interface { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// UnassociatedEndpointFactory produces endpoints for writing packets not -// associated with a particular transport protocol. Such endpoints can be used -// to write arbitrary packets that include the IP header. -type UnassociatedEndpointFactory interface { - NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) +// RawFactory produces endpoints for writing various types of raw packets. +type RawFactory interface { + // NewUnassociatedEndpoint produces endpoints for writing packets not + // associated with a particular transport protocol. Such endpoints can + // be used to write arbitrary packets that include the network header. + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + + // NewPacketEndpoint produces endpoints for reading and writing packets + // that include network and (when cooked is false) link layer headers. + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } // GSOType is the type of GSO segments. @@ -394,8 +483,14 @@ type GSOType int // Types of gso segments. const ( GSONone GSOType = iota + + // Hardware GSO types: GSOTCPv4 GSOTCPv6 + + // GSOSW is used for software GSO segments which have to be sent by + // endpoint.WritePackets. + GSOSW ) // GSO contains generic segmentation offload properties. @@ -423,3 +518,7 @@ type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 } + +// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. +// This isn't a hard limit, because it is never set into packet headers. +const SoftwareGSOMaxSize = (1 << 16) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e72373964..34307ae07 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -47,8 +46,8 @@ type Route struct { // starts. ref *referencedNetworkEndpoint - // loop controls where WritePacket should send packets. - loop PacketLooping + // Loop controls where WritePacket should send packets. + Loop PacketLooping } // makeRoute initializes a new route. It takes ownership of the provided @@ -69,7 +68,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip LocalLinkAddress: localLinkAddr, RemoteAddress: remoteAddr, ref: ref, - loop: loop, + Loop: loop, } } @@ -154,34 +153,54 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop) + err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } return err } +// WritePackets writes the set of packets through the given route. +func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { + if !r.ref.isValidForOutgoing() { + return 0, tcpip.ErrInvalidEndpointState + } + + n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + } + r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + payloadSize := 0 + for i := 0; i < n; i++ { + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) + payloadSize += pkts[i].DataSize + } + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a199bc1cc..0e88643a4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -50,7 +51,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -344,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -351,10 +359,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver - // unassociatedFactory creates unassociated endpoints. If nil, raw - // endpoints are disabled. It is set during Stack creation and is - // immutable. - unassociatedFactory UnassociatedEndpointFactory + // rawFactory creates raw endpoints. If nil, raw endpoints are + // disabled. It is set during Stack creation and is immutable. + rawFactory RawFactory demux *transportDemuxer @@ -362,9 +369,10 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -394,14 +402,31 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // portSeed is a one-time random value initialized at stack startup + // seed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // // TODO(gvisor.dev/issue/940): S/R this field. - portSeed uint32 + seed uint32 - // ndpConfigs is the NDP configurations used by interfaces. + // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations + + // autoGenIPv6LinkLocal determines whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. @@ -425,16 +450,35 @@ type Options struct { // stack (false). HandleLocal bool - // UnassociatedFactory produces unassociated endpoints raw endpoints. - // Raw endpoints are enabled only if this is non-nil. - UnassociatedFactory UnassociatedEndpointFactory + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID - // NDPConfigs is the NDP configurations used by interfaces. + // NDPConfigs is the default NDP configurations used by interfaces. // // By default, NDPConfigs will have a zero value for its // DupAddrDetectTransmits field, implying that DAD will not be performed // before assigning an address to a NIC. NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determins whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // Note, setting this to true does not mean that a link-local address + // will be assigned right away, or at all. If Duplicate Address + // Detection is enabled, an address will only be assigned if it + // successfully resolves. If it fails, no further attempt will be made + // to auto-generate an IPv6 link-local address. + // + // The generated link-local address will follow RFC 4291 Appendix A + // guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // RawFactory produces raw endpoints. Raw endpoints are enabled only if + // this is non-nil. + RawFactory RawFactory } // TransportEndpointInfo holds useful information about a transport endpoint @@ -481,22 +525,30 @@ func New(opts Options) *Stack { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - icmpRateLimiter: NewICMPRateLimiter(), - portSeed: generateRandUint32(), - ndpConfigs: opts.NDPConfigs, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + ndpConfigs: opts.NDPConfigs, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, + ndpDisp: opts.NDPDisp, } // Add specified network protocols. @@ -514,8 +566,8 @@ func New(opts Options) *Stack { } } - // Add the factory for unassociated endpoints, if present. - s.unassociatedFactory = opts.UnassociatedFactory + // Add the factory for raw endpoints, if present. + s.rawFactory = opts.RawFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -523,6 +575,11 @@ func New(opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value @@ -584,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -650,12 +707,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if s.unassociatedFactory == nil { + if s.rawFactory == nil { return nil, tcpip.ErrNotPermitted } if !associated { - return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue) } t, ok := s.transportProtocols[transport] @@ -666,6 +723,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return t.proto.NewRawEndpoint(s, network, waiterQueue) } +// NewPacketEndpoint creates a new packet endpoint listening for the given +// netProto. +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if s.rawFactory == nil { + return nil, tcpip.ErrNotPermitted + } + + return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) +} + // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { @@ -988,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. -func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { +func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() // If a NIC is specified, we try to find the address there only. - if nicid != 0 { - nic := s.nics[nicid] + if nicID != 0 { + nic := s.nics[nicID] if nic == nil { return 0 } @@ -1053,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { } // AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.add(fullAddr, linkAddr) // TODO: provide a way for a transport endpoint to receive a signal // that AddLinkAddress for a particular address has been called. } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicid] + nic := s.nics[nicID] if nic == nil { s.mu.RUnlock() return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // RemoveWaker implements LinkAddressCache.RemoveWaker. -func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { +func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { s.mu.RLock() defer s.mu.RUnlock() - if nic := s.nics[nicid]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + if nic := s.nics[nicID]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.removeWaker(fullAddr, waker) } } @@ -1100,6 +1167,31 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() + + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1121,6 +1213,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { @@ -1135,6 +1290,109 @@ func (s *Stack) Resume() { } } +// RegisterPacketEndpoint registers ep with the stack, causing it to receive +// all traffic of the specified netProto on the given NIC. If nicID is 0, it +// receives traffic from every NIC. +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + // If no NIC is specified, capture on all devices. + if nicID == 0 { + // Register with each NIC. + for _, nic := range s.nics { + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + s.unregisterPacketEndpointLocked(0, netProto, ep) + return err + } + } + return nil + } + + // Capture on a specific device. + nic, ok := s.nics[nicID] + if !ok { + return tcpip.ErrUnknownNICID + } + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + return err + } + + return nil +} + +// UnregisterPacketEndpoint unregisters ep for packets of the specified +// netProto from the specified NIC. If nicID is 0, ep is unregistered from all +// NICs. +func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + s.mu.Lock() + defer s.mu.Unlock() + s.unregisterPacketEndpointLocked(nicID, netProto, ep) +} + +func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + // If no NIC is specified, unregister on all devices. + if nicID == 0 { + // Unregister with each NIC. + for _, nic := range s.nics { + nic.unregisterPacketEndpoint(netProto, ep) + } + return + } + + // Unregister in a single device. + nic, ok := s.nics[nicID] + if !ok { + return + } + nic.unregisterPacketEndpoint(netProto, ep) +} + +// WritePacket writes data directly to the specified NIC. It adds an ethernet +// header based on the arguments. +func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + // Add our own fake ethernet header. + ethFields := header.EthernetFields{ + SrcAddr: nic.linkEP.LinkAddress(), + DstAddr: dst, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + vv := buffer.View(fakeHeader).ToVectorisedView() + vv.Append(payload) + + if err := nic.linkEP.WriteRawPacket(vv); err != nil { + return err + } + + return nil +} + +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + if err := nic.linkEP.WriteRawPacket(payload); err != nil { + return err + } + + return nil +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1286,12 +1544,47 @@ func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tc return nic.dupTentativeAddrDetected(addr) } -// PortSeed returns a 32 bit value that can be used as a seed value for port -// picking. +// SetNDPConfigurations sets the per-interface NDP configurations on the NIC +// with ID id to c. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.setNDPConfigs(c) + + return nil +} + +// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement +// message that it needs to handle. +func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.handleNDPRA(ip, ra) + + return nil +} + +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. // // NOTE: The seed is generated once during stack initialization only. -func (s *Stack) PortSeed() uint32 { - return s.portSeed +func (s *Stack) Seed() uint32 { + return s.seed } func generateRandUint32() uint32 { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 10fd1065f..8fc034ca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,11 +24,14 @@ import ( "sort" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -55,7 +58,7 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int proto *fakeNetworkProtocol @@ -68,7 +71,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid + return f.nicID } func (f *fakeNetworkEndpoint) PrefixLen() int { @@ -83,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b := pkt.Data.First() + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() + nb := pkt.Data.First() if len(nb) < fakeNetHeaderLen { return } - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -119,32 +122,38 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) + b := pkt.Header.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + f.HandlePacket(r, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) } if loop&stack.PacketOut == 0 { return nil } - return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -189,9 +198,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, @@ -251,7 +260,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -261,7 +272,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -271,7 +284,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -280,7 +295,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -290,7 +307,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -310,7 +329,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -365,7 +387,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -660,11 +684,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -697,7 +721,7 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) @@ -709,7 +733,7 @@ func TestEndpointExpiration(t *testing.T) { }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -726,13 +750,13 @@ func TestEndpointExpiration(t *testing.T) { buf[0] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -740,7 +764,7 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. //----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -755,20 +779,20 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -783,10 +807,10 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -804,10 +828,10 @@ func TestEndpointExpiration(t *testing.T) { // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -823,10 +847,10 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) testSend(t, r, ep, nil) @@ -834,17 +858,17 @@ func TestEndpointExpiration(t *testing.T) { // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -1637,12 +1661,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1650,7 +1674,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1659,17 +1683,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1686,24 +1710,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1714,7 +1738,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1724,17 +1748,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1753,7 +1777,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1761,7 +1785,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1787,7 +1811,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1847,7 +1873,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) select { case <-ep2.C: @@ -1864,3 +1892,297 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } + +// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses +// (or lack there-of if disabled (default)). Note, DAD will be disabled in +// these tests. +func TestNICAutoGenAddr(t *testing.T) { + tests := []struct { + name string + autoGen bool + linkAddr tcpip.LinkAddress + shouldGen bool + }{ + { + "Disabled", + false, + linkAddr1, + false, + }, + { + "Enabled", + true, + linkAddr1, + true, + }, + { + "Nil MAC", + true, + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty MAC", + true, + tcpip.LinkAddress(""), + false, + }, + { + "Invalid MAC", + true, + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Multicast MAC", + true, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Unspecified MAC", + true, + tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + + if test.shouldGen { + // Should have auto-generated an address and + // resolved immediately (DAD is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 +// link-local addresses will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + +// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected +// when an address's kind gets "promoted" to permanent from permanentExpired. +func TestNewPEBOnPromotionToPermanent(t *testing.T) { + pebs := []stack.PrimaryEndpointBehavior{ + stack.NeverPrimaryEndpoint, + stack.CanBePrimaryEndpoint, + stack.FirstPrimaryEndpoint, + } + + for _, pi := range pebs { + for _, ps := range pebs { + t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + // Add a permanent address with initial + // PrimaryEndpointBehavior (peb), pi. If pi is + // NeverPrimaryEndpoint, the address should not + // be returned by a call to GetMainNICAddress; + // else, it should. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + addr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if pi == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatalf("NewSubnet failed:", err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // Take a route through the address so its ref + // count gets incremented and does not actually + // get deleted when RemoveAddress is called + // below. This is because we want to test that a + // new peb is respected when an address gets + // "promoted" to permanent from a + // permanentExpired kind. + r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + defer r.Release() + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + + // + // At this point, the address should still be + // known by the NIC, but have its + // kind = permanentExpired. + // + + // Add some other address with peb set to + // FirstPrimaryEndpoint. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + + } + + // Add back the address we removed earlier and + // make sure the new peb was respected. + // (The address should just be promoted now). + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + var primaryAddrs []tcpip.Address + for _, pa := range s.NICInfo()[1].ProtocolAddresses { + primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) + } + var expectedList []tcpip.Address + switch ps { + case stack.FirstPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x01", + "\x03", + } + case stack.CanBePrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + "\x01", + } + case stack.NeverPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + } + } + if !cmp.Equal(primaryAddrs, expectedList) { + t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) + } + + // Once we remove the other address, if the new + // peb, ps, was NeverPrimaryEndpoint, no address + // should be returned by a call to + // GetMainNICAddress; else, our original address + // should be returned. + if err := s.RemoveAddress(1, "\x03"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + addr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if ps == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else { + if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + } + }) + } + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 92267ce4d..67c21be42 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,10 +17,10 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -41,6 +41,31 @@ type transportEndpoints struct { rawEndpoints []RawTransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + eps.mu.Lock() + defer eps.mu.Unlock() + epsByNic, ok := eps.endpoints[id] + if !ok { + return + } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + type endpointsByNic struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint @@ -48,9 +73,19 @@ type endpointsByNic struct { seed uint32 } +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] @@ -64,18 +99,17 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { - mpep.handlePacketAll(r, id, vv) + mpep.handlePacketAll(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } - // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() defer epsByNic.mu.RUnlock() @@ -91,7 +125,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -127,21 +161,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T return len(epsByNic.endpoints) == 0 } -// unregisterEndpoint unregisters the endpoint with the given id such that it -// won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - eps.mu.Lock() - defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] - if !ok { - return - } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { - return - } - delete(eps.endpoints, id) -} - // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -183,14 +202,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. reuse bool } +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps +} + // reciprocalScale scales a value into range [0, n). // // This is similar to val % n, but faster. @@ -224,22 +256,40 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpointsArr[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy except for - // the final one. + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, pkt) break } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. } +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } +} + +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } +} + // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { @@ -257,6 +307,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } return nil } @@ -330,9 +389,9 @@ var loopbackSubnet = func() tcpip.Subnet { }() // deliverPacket attempts to find one or more matching transport endpoints, and -// then, if matches are found, delivers the packet to them. Returns true if it -// found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +// then, if matches are found, delivers the packet to them. Returns true if +// the packet no longer needs to be handled. +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -341,13 +400,38 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RLock() // Determine which transport endpoint or endpoints to deliver this packet to. - // If the packet is a broadcast or multicast, then find all matching - // transport endpoints. + // If the packet is a UDP broadcast or multicast, then find all matching + // transport endpoints. If the packet is a TCP packet with a non-unicast + // source or destination address, then do nothing further and instruct + // the caller to do the same. var destEps []*endpointsByNic - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, vv, id) - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { - destEps = append(destEps, ep) + switch protocol { + case header.UDPProtocolNumber: + if isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, id) + break + } + + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } + + case header.TCPProtocolNumber: + if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single + // source and a single destination; the addresses must + // be unicast. + eps.mu.RUnlock() + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true + } + + fallthrough + + default: + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } } eps.mu.RUnlock() @@ -361,17 +445,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // Deliver the packet. - for _, ep := range destEps { - ep.handlePacket(r, id, vv) + // HandlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEps[:len(destEps)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEps[len(destEps)-1].handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -385,7 +471,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -395,7 +481,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -403,7 +489,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() // Fail if we didn't find one. @@ -412,12 +498,12 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // Deliver the packet. - ep.handleControlPacket(n, id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -445,14 +531,44 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv bu if ep, ok := eps.endpoints[nid]; ok { matchedEPs = append(matchedEPs, ep) } - return matchedEPs } +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } + } + + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + // findEndpointLocked returns the endpoint that most closely matches the given // id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { return matchedEPs[0] } return nil @@ -465,7 +581,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return nil + return tcpip.ErrNotSupported } eps.mu.Lock() @@ -496,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN func isMulticastOrBroadcast(addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) } + +func isUnicast(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 210233dc0..3b28b06d0 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -79,17 +79,17 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) linkEPs := make(map[string]*channel.Endpoint) for i, linkEpName := range linkEpNames { channelEP := channel.New(256, mtu, "") - nicid := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + nicID := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } linkEPs[linkEpName] = channelEP - if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -156,7 +156,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..748ce4ea5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint @@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -82,7 +83,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(v).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -144,6 +148,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -192,7 +200,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -209,7 +217,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -218,15 +226,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -251,7 +259,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -266,7 +274,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -337,7 +345,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -346,7 +356,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -355,7 +367,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -408,7 +422,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -417,7 +433,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -426,7 +444,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -579,7 +599,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -598,10 +620,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 678a94616..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -231,6 +231,13 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// Equal returns true if s equals o. +// +// Needed to use cmp.Equal on Subnet as its fields are unexported. +func (s Subnet) Equal(o Subnet) bool { + return s == o +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 @@ -255,7 +262,7 @@ type FullAddress struct { // This may not be used by all endpoint types. NIC NICID - // Addr is the network address. + // Addr is the network or link layer address. Addr Address // Port is the transport port. @@ -301,7 +308,7 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packed used to create + // Timestamp is the time (in ns) that the last packet used to create // the read data was received. Timestamp int64 @@ -310,6 +317,18 @@ type ControlMessages struct { // Inq is the number of bytes ready to be received. Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS int8 + + // HasTClass indicates whether Tclass is valid/set. + HasTClass bool + + // Tclass is the IPv6 traffic class of the associated packet. + TClass int32 } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -489,6 +508,11 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -501,11 +525,6 @@ type ErrorOption struct{} // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int -// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be -// sent out immediately by the transport protocol. For TCP, it determines if the -// Nagle algorithm is on or off. -type DelayOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -557,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string @@ -579,6 +603,16 @@ type MaxSegOption int // A zero value indicates the default. type TTLOption uint8 +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -673,6 +707,11 @@ func (s *StatCounter) Increment() { s.IncrementBy(1) } +// Decrement minuses one to the counter. +func (s *StatCounter) Decrement() { + s.IncrementBy(^uint64(0)) +} + // Value returns the current value of the counter. func (s *StatCounter) Value() uint64 { return atomic.LoadUint64(&s.count) @@ -881,6 +920,23 @@ type TCPStats struct { // successfully via Listen. PassiveConnectionOpenings *StatCounter + // CurrentEstablished is the number of TCP connections for which the + // current state is either ESTABLISHED or CLOSE-WAIT. + CurrentEstablished *StatCounter + + // EstablishedResets is the number of times TCP connections have made + // a direct transition to the CLOSED state from either the + // ESTABLISHED state or the CLOSE-WAIT state. + EstablishedResets *StatCounter + + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter @@ -1308,8 +1364,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index a52262e87..48764b978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9254c3dea..d8c5b5058 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -38,11 +38,3 @@ go_library( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "icmp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3187b336b..9c40931b5 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -58,6 +55,7 @@ type endpoint struct { // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -274,13 +278,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } toCopy := *to @@ -291,7 +295,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -425,7 +429,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: data.ToVectorisedView(), + TransportHeader: buffer.View(icmpv4), + }) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -445,13 +453,17 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - icmpv6.SetChecksum(0) - icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + dataVV := data.ToVectorisedView() + icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: dataVV, + TransportHeader: buffer.View(icmpv6), + }) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -479,7 +491,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC localPort := uint16(0) switch e.state { case stateBound, stateConnected: @@ -488,11 +500,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -503,7 +515,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -520,14 +532,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id e.route = r.Clone() - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.state = stateConnected @@ -578,18 +590,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -714,18 +726,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) + h := header.ICMPv4(pkt.Data.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) + h := header.ICMPv6(pkt.Data.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -753,19 +765,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + packet.data = pkt.Data - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() @@ -776,7 +788,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't @@ -798,3 +810,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bfb16f7c3..9ce500e80 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD new file mode 100644 index 000000000..44b58ff6b --- /dev/null +++ b/pkg/tcpip/transport/packet/BUILD @@ -0,0 +1,38 @@ +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "packet", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "packet", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "packet_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/iptables", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go new file mode 100644 index 000000000..0010b5e5f --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packet provides the implementation of packet sockets (see +// packet(7)). Packet sockets allow applications to: +// +// * manually write and inspect link, network, and transport headers +// * receive all traffic of a given network protocol, or all protocols +// +// Packet sockets are similar to raw sockets, but provide even more power to +// users, letting them effectively talk directly to the network device. +// +// Packet sockets skip the input and output iptables chains. +package packet + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal +// to have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + cooked bool + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +// NewEndpoint returns a new packet endpoint. +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + }, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { + return nil, err + } + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (ep *endpoint) ModerateRecvBuf(copied int) {} + +// IPTables implements tcpip.Endpoint.IPTables. +func (ep *endpoint) IPTables() (iptables.IPTables, error) { + return ep.stack.IPTables(), nil +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + ep.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // TODO(b/129292371): Implement. + return 0, nil, tcpip.ErrInvalidOptionValue +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be +// disconnected, and this function always returns tpcip.ErrNotSupported. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be +// connected, and this function always returnes tcpip.ErrNotSupported. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used +// with Shutdown, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with +// Listen, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with +// Accept, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + // TODO(gvisor.dev/issue/173): Add Bind support. + + // "By default, all packets of the specified protocol type are passed + // to a packet socket. To get packets only from a specific interface + // use bind(2) specifying an address in a struct sockaddr_ll to bind + // the packet socket to an interface. Fields used for binding are + // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." + // - packet(7). + + return tcpip.ErrNotSupported +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be +// used with SetSockOpt, and this function always returns +// tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// HandlePacket implements stack.PacketEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if ep.rcvClosed { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + var packet packet + // TODO(b/129292371): Return network protocol. + if len(pkt.LinkHeader) > 0 { + // Get info directly from the ethernet header. + hdr := header.Ethernet(pkt.LinkHeader) + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(hdr.SourceAddress()), + } + } else { + // Guess the would-be ethernet header. + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(localAddr), + } + } + + if ep.cooked { + // Cooked packets can simply be queued. + packet.data = pkt.Data + } else { + // Raw packets need their ethernet headers prepended before + // queueing. + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + } + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(&packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + ep.stats.PacketsReceived.Increment() + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} + +// Info returns a copy of the endpoint info. +func (ep *endpoint) Info() tcpip.EndpointInfo { + ep.mu.RLock() + // Make a copy of the endpoint info. + ret := ep.TransportEndpointInfo + ep.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (ep *endpoint) Stats() tcpip.EndpointStats { + return &ep.stats +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go new file mode 100644 index 000000000..9b88f17e4 --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -0,0 +1,72 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. + if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index fba598d51..00991ac8e 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -4,14 +4,14 @@ load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) go_template_instance( - name = "packet_list", - out = "packet_list.go", + name = "raw_packet_list", + out = "raw_packet_list.go", package = "raw", - prefix = "packet", + prefix = "rawPacket", template = "//pkg/ilist:generic_list", types = { - "Element": "*packet", - "Linker": "*packet", + "Element": "*rawPacket", + "Linker": "*rawPacket", }, ) @@ -20,8 +20,8 @@ go_library( srcs = [ "endpoint.go", "endpoint_state.go", - "packet_list.go", "protocol.go", + "raw_packet_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], @@ -34,14 +34,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/packet", "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b4c660859..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -17,8 +17,7 @@ // // * manually write and inspect transport layer headers and payloads // * receive all traffic of a given transport protocol (e.g. ICMP or UDP) -// * optionally write and inspect network layer and link layer headers for -// packets +// * optionally write and inspect network layer headers of packets // // Raw sockets don't have any notion of ports, and incoming packets are // demultiplexed solely by protocol number. Thus, a raw UDP endpoint will @@ -38,15 +37,11 @@ import ( ) // +stateify savable -type packet struct { - packetEntry +type rawPacket struct { + rawPacketEntry // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -72,7 +67,7 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + rcvList rawPacketList rcvBufSizeMax int `state:".(int)"` rcvBufSize int rcvClosed bool @@ -90,7 +85,6 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL and AF_PACKET. func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } @@ -187,17 +181,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess return buffer.View{}, tcpip.ControlMessages{}, err } - packet := e.rcvList.Front() - e.rcvList.Remove(packet) - e.rcvBufSize -= packet.data.Size() + pkt := e.rcvList.Front() + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() e.rcvMu.Unlock() if addr != nil { - *addr = packet.senderAddr + *addr = pkt.senderAddr } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil + return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil } // Write implements tcpip.Endpoint.Write. @@ -344,13 +338,18 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -561,7 +560,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -602,16 +601,17 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - packet := &packet{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) + networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) + combinedVV := networkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() e.rcvList.PushBack(packet) @@ -643,3 +643,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index a6c7cc43a..33bfb56cd 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -20,15 +20,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// saveData saves packet.data field. -func (p *packet) saveData() buffer.VectorisedView { +// saveData saves rawPacket.data field. +func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } -// loadData loads packet.data field. -func (p *packet) loadData(data buffer.VectorisedView) { +// loadData loads rawPacket.data field. +func (p *rawPacket) loadData(data buffer.VectorisedView) { // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point @@ -86,7 +86,9 @@ func (ep *endpoint) Resume(s *stack.Stack) { } } - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { - panic(err) + if ep.associated { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + panic(err) + } } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index a2512d666..f30aa2a4a 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -17,13 +17,19 @@ package raw import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/packet" "gvisor.dev/gvisor/pkg/waiter" ) -// EndpointFactory implements stack.UnassociatedEndpointFactory. +// EndpointFactory implements stack.RawFactory. type EndpointFactory struct{} -// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } + +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index aed70e06f..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -28,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", @@ -44,6 +45,7 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", @@ -51,6 +53,7 @@ go_library( "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", @@ -60,17 +63,9 @@ go_library( ], ) -filegroup( - name = "autogen", - srcs = [ - "tcp_segment_list.go", - ], - visibility = ["//:sandbox"], -) - go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 844959fa0..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -241,8 +242,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In case of Forwarder listenEP will be nil and hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() return nil, err } @@ -268,8 +276,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -288,7 +296,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -297,7 +305,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.state = StateEstablished + ep.isConnectNotified = true ep.mu.Unlock() // Update the receive window scaling. We can't do it before the @@ -349,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. @@ -359,6 +375,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -366,6 +383,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -399,8 +421,19 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - switch s.flags { - case header.TCPFlagSyn: + if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + return + } + + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { // Only handle the syn if the following conditions hold @@ -433,19 +466,18 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -459,6 +491,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } if !synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. @@ -519,7 +559,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished + n.isConnectNotified = true // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -530,6 +574,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5ea036bea..cdd69f360 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -78,9 +80,6 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 - // amss is the maximum segment size advertised by us to the peer. - amss uint16 - // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -142,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -194,6 +218,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK + // was acceptable then signal the user "error: connection reset", drop + // the segment, enter CLOSED state, delete TCB, and return." + h.ep.mu.Lock() + h.ep.workerCleanup = true + h.ep.mu.Unlock() + // Although the RFC above calls out ECONNRESET, Linux actually returns + // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused } return nil @@ -228,6 +260,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -275,6 +312,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } + // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a + // sequence number outside of the window causes an ACK with the proper seq + // number and "After sending the acknowledgment, drop the unacceptable + // segment and return." + if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + return nil + } + if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole @@ -319,6 +365,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -445,7 +495,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -607,19 +657,14 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -// sendTCP sends a TCP segment with the provided options via the provided -// network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { optLen := len(opts) - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff - } - + hdr := &pkt.Header + packetSize := pkt.DataSize + off := pkt.DataOffset // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + pkt.TransportHeader = buffer.View(tcp) tcp.Encode(&header.TCPFields{ SrcPort: id.LocalPort, DstPort: id.RemotePort, @@ -631,7 +676,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise }) copy(tcp[header.TCPMinimumSize:], opts) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { @@ -641,14 +686,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVV(data, xsum) + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } +} + +func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + mss := int(gso.MSS) + n := (data.Size() + mss - 1) / mss + + // Allocate one big slice for all the headers. + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + buf := make([]byte, n*hdrSize) + pkts := make([]tcpip.PacketBuffer, n) + for i := range pkts { + pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) + } + + size := data.Size() + off := 0 + for i := 0; i < n; i++ { + packetSize := mss + if packetSize > size { + packetSize = size + } + size -= packetSize + pkts[i].DataOffset = off + pkts[i].DataSize = packetSize + pkts[i].Data = data + buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso) + off += packetSize + seq = seq.Add(seqnum.Size(packetSize)) + } + if ttl == 0 { + ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + if err != nil { + r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) + } + r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent)) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + } + + pkt := tcpip.PacketBuffer{ + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + DataOffset: 0, + DataSize: data.Size(), + Data: data, + } + buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso) + if ttl == 0 { ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -754,10 +864,26 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. + if e.state == StateEstablished || e.state == StateCloseWait { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } e.state = StateError e.HardError = err - if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { + resetSeqNum = e.snd.sndNxt + } + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) } } @@ -771,6 +897,99 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + if e.state == StateClose { + return + } + e.cleanupLocked() + e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil { + replyWithReset(s) + s.decRef() + return + } + ep.(*endpoint).enqueueSegment(s) +} + +func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + e.mu.Lock() + switch e.state { + // In case of a RST in CLOSE-WAIT linux moves + // the socket to closed state with an error set + // to indicate EPIPE. + // + // Technically this seems to be at odds w/ RFC. + // As per https://tools.ietf.org/html/rfc793#section-2.7 + // page 69 the behavior for a segment arriving + // w/ RST bit set in CLOSE-WAIT is inlined below. + // + // ESTABLISHED + // FIN-WAIT-1 + // FIN-WAIT-2 + // CLOSE-WAIT + + // If the RST bit is set then, any outstanding RECEIVEs and + // SEND should receive "reset" responses. All segment queues + // should be flushed. Users should also receive an unsolicited + // general "connection reset" signal. Enter the CLOSED state, + // delete the TCB, and return. + case StateCloseWait: + e.transitionToStateCloseLocked() + e.HardError = tcpip.ErrAborted + e.mu.Unlock() + return false, nil + default: + e.mu.Unlock() + return false, tcpip.ErrConnectionReset + } + } + return true, nil +} + // handleSegments pulls segments from the queue and processes them. It returns // no error if the protocol loop should continue, an error otherwise. func (e *endpoint) handleSegments() *tcpip.Error { @@ -788,14 +1007,34 @@ func (e *endpoint) handleSegments() *tcpip.Error { } if s.flagIsSet(header.TCPFlagRst) { - if e.rcv.acceptable(s.sequenceNumber, 0) { - // RFC 793, page 37 states that "in all states - // except SYN-SENT, all reset (RST) segments are - // validated by checking their SEQ-fields." So - // we only process it if it's acceptable. - s.decRef() - return tcpip.ErrConnectionReset + if ok, err := e.handleReset(s); !ok { + return err } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -804,7 +1043,33 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return nil + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -830,14 +1095,27 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -854,7 +1132,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -862,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -869,6 +1147,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -903,7 +1182,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -932,21 +1210,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -954,7 +1217,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -985,13 +1247,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.transitionToStateCloseLocked() + e.mu.Unlock() + return nil }, }, { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1028,17 +1297,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1054,12 +1324,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != StateError { + if e.state != StateClose && e.state != StateError { + // Only block the worker if the endpoint + // is not in closed state or error state. close(e.drainDone) <-e.undrain } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1086,15 +1364,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1110,15 +1389,168 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.state = StateClose + e.stack.Stats().TCP.CurrentEstablished.Decrement() + e.transitionToStateCloseLocked() } + // Lock released below. epilogue() + // epilogue removes the endpoint from the transport-demuxer and + // unlocks e.mu. Now that no new segments can get enqueued to this + // endpoint, try to re-match the segment to a different endpoint + // as the current endpoint is closed. + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } + + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a1b784b49..fe629aa40 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -121,6 +122,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -287,6 +293,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -319,6 +326,11 @@ type endpoint struct { state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` @@ -330,6 +342,11 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -411,7 +428,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -458,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -502,6 +525,36 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS } // StopWork halts packet processing. Only to be used in tests. @@ -550,6 +603,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption @@ -572,6 +626,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptInt(tcpip.DelayOption, 1) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -659,6 +723,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -671,14 +742,18 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -704,9 +779,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -732,16 +805,19 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -752,7 +828,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -1133,16 +1211,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.sndBufMu.Unlock() return nil - default: - return nil - } -} - -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 - switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { atomic.StoreUint32(&e.delay, 0) @@ -1154,6 +1222,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { case tcpip.CorkOption: if v == 0 { atomic.StoreUint32(&e.cork, 0) @@ -1184,9 +1262,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -1206,7 +1284,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil @@ -1262,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1319,6 +1403,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1345,6 +1451,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: e.sndBufMu.Lock() v := e.sndBufSize @@ -1357,8 +1464,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -1379,13 +1494,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1496,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 @@ -1530,6 +1644,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1602,7 +1722,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1611,11 +1731,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1634,7 +1754,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -1649,7 +1769,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) if err != nil { return err } @@ -1664,7 +1784,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1678,15 +1798,18 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // reusePort is false below because connect cannot reuse a port even if // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: + // Port picking successful. Save the details of + // the selected port. e.ID = id + e.boundBindToDevice = e.bindToDevice return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1702,14 +1825,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1729,6 +1852,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -1749,9 +1873,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1766,6 +1889,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1784,14 +1908,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1799,11 +1920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } @@ -1844,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } + if e.state == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() @@ -1851,7 +1990,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { return err } @@ -1895,12 +2034,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -1908,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.bindLocked(addr) +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. @@ -1932,26 +2070,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = flags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.bindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -2001,8 +2146,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -2025,6 +2170,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() @@ -2037,7 +2186,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2327,11 +2476,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { return s } -func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityGSO == 0 { - return - } - +func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} switch e.route.NetProto { case header.IPv4ProtocolNumber: @@ -2349,6 +2494,18 @@ func (e *endpoint) initGSO() { e.gso = gso } +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { @@ -2371,6 +2528,23 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index eae17237e..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -217,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index db40785d3..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,12 +55,23 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP // protocol. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. +type DelayEnabled bool + // SendBufferSizeOption allows the default, min and max send buffer sizes for // TCP endpoints to be queried or configured. type SendBufferSizeOption struct { @@ -84,11 +96,14 @@ const ( type protocol struct { mu sync.Mutex sackEnabled bool + delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, waiterQueue), nil } @@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { @@ -147,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. @@ -165,6 +193,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue @@ -202,6 +236,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -216,6 +268,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *DelayEnabled: + p.mu.Lock() + *v = DelayEnabled(p.delayEnabled) + p.mu.Unlock() + return nil + case *SendBufferSizeOption: p.mu.Lock() *v = p.sendBufferSize @@ -246,6 +304,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -258,5 +328,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -49,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -204,15 +209,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } r.ep.mu.Unlock() } @@ -253,25 +263,110 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { @@ -288,7 +383,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +410,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } @@ -99,8 +100,14 @@ func (s *segment) clone() *segment { return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -134,6 +137,10 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + closed bool writeNext *segment writeList segmentList @@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < MinRTO { + s.rto = MinRTO } } @@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool { s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + s.ep.mu.RLock() + uto := s.ep.userTimeout + s.ep.mu.RUnlock() + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := MaxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + if remaining <= 0 || s.rto >= MaxRTO { return false } @@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // // Retransmit timeouts: @@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6d022a266..e8fe4dab5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) { if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestConnectIncrementActiveConnection(t *testing.T) { @@ -206,17 +220,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -256,7 +271,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -264,6 +279,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -285,6 +307,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -296,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -376,6 +403,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -396,12 +430,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -413,15 +459,128 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. +func TestClosingWithEnqueuedSegments(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Get the ACK for the FIN we sent. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() + + // Get the FIN + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Pause the endpoint`s protocolMainLoop. + ep.(interface{ StopWork() }).StopWork() + + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 792, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Unpause endpoint`s protocolMainLoop. + ep.(interface{ ResumeWork() }).ResumeWork() + + // Wait for the protocolMainLoop to resume and update state. + time.Sleep(10 * time.Millisecond) + + // Expect the endpoint to be closed. + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + + // Check if the endpoint was moved to CLOSED and netstack a reset in + // response to the ACK packet that we sent after last-ACK. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), + ), + ) +} + func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -474,6 +633,352 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +func TestSendRstOnListenerRxSynAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxSynAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv4. +func TestTCPAckBeforeAcceptV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv6. +func TestTCPAckBeforeAcceptV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint not processing +// any receive when it is on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + // Wait for the endpoint state to be propagated. + time.Sleep(10 * time.Millisecond) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + c.CheckNoPacket("Packet received when listening socket was shutdown") +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -635,6 +1140,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestRstOnSynSent(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) + + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + // Ensure that we've reached SynSent state + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Send a packet with a proper ACK and a RST flag to cause the socket + // to Error and close out + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -930,8 +1500,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1623,7 +2192,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -1671,7 +2240,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -1704,7 +2273,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOpt(tcpip.DelayOption(0)) + c.EP.SetSockOptInt(tcpip.DelayOption, 0) // Check that data is received. second := c.GetPacket() @@ -1741,7 +2310,7 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, } @@ -2211,6 +2780,13 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -2905,6 +3481,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -2912,12 +3495,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2955,10 +3538,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -2975,7 +3557,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -2986,7 +3568,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -2996,11 +3578,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -3773,8 +4355,9 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) @@ -3864,19 +4447,43 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -3890,7 +4497,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki RcvWnd: 30000, }) - // Receive the SYN-ACK reply.w + // Receive the SYN-ACK reply. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss = seqnum.Value(tcp.SequenceNumber()) @@ -3923,6 +4530,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki return irs, iss } +func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + // Send a SYN request. + irs = seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetV6Packet() + tcp := header.TCP(header.IPv6(b).Payload()) + iss = seqnum.Value(tcp.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + + if synCookieInUse { + // When cookies are in use window scaling is disabled. + tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ + WS: -1, + MSS: c.MSSWithoutOptionsV6(), + })) + } + + checker.IPv6(t, b, checker.TCP(tcpCheckers...)) + + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + return irs, iss +} + // TestListenBacklogFull tests that netstack does not complete handshakes if the // listen backlog for the endpoint is full. func TestListenBacklogFull(t *testing.T) { @@ -4025,6 +4676,210 @@ func TestListenBacklogFull(t *testing.T) { } } +// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a +// non unicast IPv4 address are not accepted. +func TestListenNoAcceptNonUnicastV4(t *testing.T) { + multicastAddr := tcpip.Address("\xe0\x00\x01\x02") + otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv4Any, + context.StackAddr, + }, + { + "SourceBroadcast", + header.IPv4Broadcast, + context.StackAddr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackAddr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackAddr, + }, + { + "DestUnspecified", + context.TestAddr, + header.IPv4Any, + }, + { + "DestBroadcast", + context.TestAddr, + header.IPv4Broadcast, + }, + { + "DestOurMulticast", + context.TestAddr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestAddr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestAddr, context.StackAddr) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + +// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a +// non unicast IPv6 address are not accepted. +func TestListenNoAcceptNonUnicastV6(t *testing.T) { + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv6Any, + context.StackV6Addr, + }, + { + "SourceAllNodes", + header.IPv6AllNodesMulticastAddress, + context.StackV6Addr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackV6Addr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackV6Addr, + }, + { + "DestUnspecified", + context.TestV6Addr, + header.IPv6Any, + }, + { + "DestAllNodes", + context.TestV6Addr, + header.IPv6AllNodesMulticastAddress, + }, + { + "DestOurMulticast", + context.TestV6Addr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestV6Addr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestV6Addr, context.StackV6Addr) + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + func TestListenSynRcvdQueueFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4056,7 +4911,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { SrcPort: context.TestPort, DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: irs, RcvWnd: 30000, }) @@ -4167,7 +5022,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4207,6 +5063,125 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSynRcvdBadSeqNumber(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now send a packet with an out-of-window sequence number + largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: largeSeqnum, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Should receive an ACK with the expected SEQ number + b = c.GetPacket() + tcpCheckers = []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.AckNum(uint32(irs) + 1), + checker.SeqNum(uint32(iss + 1)), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now that the socket replied appropriately with the ACK, + // complete the connection to test that the large SEQ num + // did not change the state from SYN-RCVD. + + // Send ACK to move to ESTABLISHED state. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + newEP, _, err := c.EP.Accept() + + if err != nil && err != tcpip.ErrWouldBlock { + t.Fatalf("Accept failed: %s", err) + } + + if err == tcpip.ErrWouldBlock { + // Try to accept the connections in the backlog. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + pkt := c.GetPacket() + tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + if string(tcpHdr.Payload()) != data { + t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) + } +} + func TestPassiveConnectionAttemptIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4381,6 +5356,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -4668,3 +5646,918 @@ func TestReceiveBufferAutoTuning(t *testing.T) { payloadSize *= 2 } } + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, 0) // Delay is disabled by default. + + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption int + }{ + {delayEnabled: false, wantDelayOption: 0}, + {delayEnabled: true, wantDelayOption: 1}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) + } + gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } +} + +func TestTCPCloseWithData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + + c.EP.Close() + // Check the FIN. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.AckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Check the RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + userTimeout := 50 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for a little over the minimum retransmit timeout of 200ms for + // the retransmitTimer to fire and close the connection. + time.Sleep(tcp.MinRTO + 10*time.Millisecond) + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 10 * time.Millisecond + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // Set userTimeout to be the duration for 3 keepalive probes. + userTimeout := 30 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Now receive 2 keepalives, but don't ACK them. The connection should + // be reset when the 3rd one should be sent due to userTimeout being + // 30ms and each keepalive probe should be sent 10ms apart as set above after + // the connection has been idle for 10ms. + for i := 0; i < 2; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + + // The connection should be terminated after 30ms. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index 19b0d31c5..b33ec2087 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -8,7 +8,7 @@ go_library( srcs = ["context.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ - "//:sandbox", + "//visibility:public", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ef823e4ae..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -231,14 +231,15 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -259,14 +260,15 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -302,11 +304,19 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) +} + +// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, +// payload and source and destination IPv4 addresses. +func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -319,8 +329,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, + SrcAddr: src, + DstAddr: dst, }) ip.SetChecksum(^ip.CalculateChecksum()) @@ -337,7 +347,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -350,13 +360,26 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: s, + }) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegment(payload, h), + }) +} + +// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload +// & TCPheaders) in an IPv4 packet via the link layer endpoint using the +// provided source and destination IPv4 addresses. +func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegmentWithAddrs(payload, h, src, dst), + }) } // SendAck sends an ACK packet. @@ -462,14 +485,15 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -484,6 +508,13 @@ func (c *Context) GetV6Packet() []byte { // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of // the context. func (c *Context) SendV6Packet(payload []byte, h *Headers) { + c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) +} + +// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer +// endpoint of the context using the provided source and destination IPv6 +// addresses. +func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -494,8 +525,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { PayloadLength: uint16(header.TCPMinimumSize + len(payload)), NextHeader: uint8(tcp.ProtocolNumber), HopLimit: 65, - SrcAddr: TestV6Addr, - DstAddr: StackV6Addr, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. @@ -511,14 +542,16 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // CreateConnected creates a connected TCP endpoint. @@ -1059,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) { func (c *Context) MSSWithoutOptions() uint16 { return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) } + +// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no +// options are in use for IPv6 packets. +func (c *Context) MSSWithoutOptionsV6() uint16 { + return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) +} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c9460aa0d..97e4d5825 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", @@ -59,11 +60,3 @@ go_test( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "udp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6e87245b7..1ac4705af 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -31,9 +32,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } // EndpointState represents the state of a UDP endpoint. @@ -80,6 +78,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -106,6 +105,11 @@ type endpoint struct { bindToDevice tcpip.NICID broadcast bool + // Values used to reserve a port or register a transport endpoint. + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 @@ -140,7 +144,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, - TransProto: header.TCPProtocolNumber, + TransProto: header.UDPProtocolNumber, }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership @@ -160,9 +164,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -171,8 +181,10 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } for _, mem := range e.multicastMemberships { @@ -278,7 +290,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -286,20 +298,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr } if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicid == 0 { - nicid = e.multicastNICID + if nicID == 0 { + nicID = e.multicastNICID } - if localAddr == "" && nicid == 0 { + if localAddr == "" && nicID == 0 { localAddr = e.multicastAddr } } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, err } - return r, nicid, nil + return r, nicID, nil } // Write writes data to the endpoint's peer. This method does not block @@ -378,13 +390,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -396,7 +408,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - r, _, err := e.connectRoute(nicid, *to, netProto) + r, _, err := e.connectRoute(nicID, *to, netProto) if err != nil { return 0, nil, err } @@ -618,9 +630,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -728,6 +740,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.MulticastLoopOption(v) return nil + case *tcpip.ReuseAddressOption: + *o = 0 + return nil + case *tcpip.ReusePortOption: e.mu.RLock() v := e.reusePort @@ -809,7 +825,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: hdr, + Data: data, + TransportHeader: buffer.View(udp), + }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -859,7 +879,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { if e.state != StateConnected { return nil } - id := stack.TransportEndpointID{} + var ( + id stack.TransportEndpointID + btd tcpip.NICID + ) // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -867,7 +890,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } @@ -875,13 +898,15 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundPortFlags = ports.Flags{} } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.ID = id + e.boundBindToDevice = btd e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -903,7 +928,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC var localPort uint16 switch e.state { case StateInitial: @@ -913,16 +938,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - r, nicid, err := e.connectRoute(nicid, addr, netProto) + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err } @@ -950,20 +975,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) } e.ID = id + e.boundBindToDevice = btd e.route = r.Clone() e.dstPort = addr.Port - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos e.state = StateConnected @@ -1018,20 +1044,27 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + // FIXME(b/129164367): Support SO_REUSEADDR. + MostRecent: false, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) if err != nil { - return id, err + return id, e.bindToDevice, err } + e.boundPortFlags = flags id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.boundPortFlags = ports.Flags{} } - return id, err + return id, e.bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1057,11 +1090,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } } - nicid := addr.NIC + nicID := addr.NIC if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { // A local unicast address was specified, verify that it's valid. - nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicid == 0 { + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { return tcpip.ErrBadLocalAddress } } @@ -1070,13 +1103,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id - e.RegisterNICID = nicid + e.boundBindToDevice = btd + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -1111,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + return tcpip.FullAddress{ NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, + Addr: addr, Port: e.ID.LocalPort, }, nil } @@ -1154,17 +1193,17 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + pkt.Data.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() @@ -1188,18 +1227,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() @@ -1210,7 +1249,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. @@ -1234,6 +1273,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index b227e353b..43fb047ed 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -109,7 +109,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d399ec722..fc706ede2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt tcpip.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index de026880f..259c3072a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -116,13 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) + payload := newNetHeader.ToVectorisedView() + payload.Append(pkt.Data.ToView().ToVectorisedView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -130,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -151,12 +159,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -164,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b724d788c..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -356,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw if p.Proto != flow.netProto() { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) h := flow.header4Tuple(outgoing) checkers := append( @@ -397,7 +397,8 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) @@ -431,7 +432,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } // injectV4Packet creates a V4 test packet with the given payload and header @@ -441,7 +446,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) @@ -471,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } func newPayload() []byte { @@ -1442,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1516,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 1f7efb064..0427bc41f 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -34,11 +34,3 @@ go_test( ], embed = [":waiter"], ) - -filegroup( - name = "autogen", - srcs = [ - "waiter_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/runsc/BUILD b/runsc/BUILD index e4e8e64a3..e5587421d 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -76,16 +76,24 @@ pkg_tar( genrule( name = "deb-version", + # Note that runsc must appear in the srcs parameter and not the tools + # parameter, otherwise it will not be stamped. This is reasonable, as tools + # may be encoded differently in the build graph (cached more aggressively + # because they are assumes to be hermetic). + srcs = [":runsc"], outs = ["version.txt"], cmd = "$(location :runsc) -version | grep 'runsc version' | sed 's/^[^0-9]*//' > $@", stamp = 1, - tools = [":runsc"], ) pkg_deb( name = "runsc-debian", architecture = "amd64", data = ":debian-data", + # Note that the description_file will be flatten (all newlines removed), + # and therefore it is kept to a simple one-line description. The expected + # format for debian packages is "short summary\nLonger explanation of + # tool." and this is impossible with the flattening. description_file = "debian/description", homepage = "https://gvisor.dev/", maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 6fe2b57de..6226b63f8 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "compat.go", "compat_amd64.go", + "compat_arm64.go", "config.go", "controller.go", "debug.go", @@ -15,6 +16,8 @@ go_library( "fs.go", "limits.go", "loader.go", + "loader_amd64.go", + "loader_arm64.go", "network.go", "pprof.go", "strace.go", @@ -60,6 +63,7 @@ go_library( "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netlink/route", + "//pkg/sentry/socket/netlink/uevent", "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix", "//pkg/sentry/state", @@ -107,7 +111,6 @@ go_test( "//pkg/control/server", "//pkg/log", "//pkg/p9", - "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 07e35ab10..352e710d2 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -21,10 +21,8 @@ import ( "syscall" "github.com/golang/protobuf/proto" - "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/arch" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/strace" @@ -53,9 +51,9 @@ type compatEmitter struct { } func newCompatEmitter(logFD int) (*compatEmitter, error) { - nameMap, ok := strace.Lookup(abi.Linux, arch.AMD64) + nameMap, ok := getSyscallNameMap() if !ok { - return nil, fmt.Errorf("amd64 Linux syscall table not found") + return nil, fmt.Errorf("Linux syscall table not found") } c := &compatEmitter{ @@ -86,16 +84,16 @@ func (c *compatEmitter) Emit(msg proto.Message) (bool, error) { } func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) { - regs := us.Registers.GetArch().(*rpb.Registers_Amd64).Amd64 + regs := us.Registers c.mu.Lock() defer c.mu.Unlock() - sysnr := regs.OrigRax + sysnr := syscallNum(regs) tr := c.trackers[sysnr] if tr == nil { switch sysnr { - case syscall.SYS_PRCTL, syscall.SYS_ARCH_PRCTL: + case syscall.SYS_PRCTL: // args: cmd, ... tr = newArgsTracker(0) @@ -112,10 +110,14 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) { tr = newArgsTracker(2) default: - tr = &onceTracker{} + tr = newArchArgsTracker(sysnr) + if tr == nil { + tr = &onceTracker{} + } } c.trackers[sysnr] = tr } + if tr.shouldReport(regs) { c.sink.Infof("Unsupported syscall: %s, regs: %+v", c.nameMap.Name(uintptr(sysnr)), regs) tr.onReported(regs) @@ -139,10 +141,10 @@ func (c *compatEmitter) Close() error { // the syscall and arguments. type syscallTracker interface { // shouldReport returns true is the syscall should be reported. - shouldReport(regs *rpb.AMD64Registers) bool + shouldReport(regs *rpb.Registers) bool // onReported marks the syscall as reported. - onReported(regs *rpb.AMD64Registers) + onReported(regs *rpb.Registers) } // onceTracker reports only a single time, used for most syscalls. @@ -150,10 +152,45 @@ type onceTracker struct { reported bool } -func (o *onceTracker) shouldReport(_ *rpb.AMD64Registers) bool { +func (o *onceTracker) shouldReport(_ *rpb.Registers) bool { return !o.reported } -func (o *onceTracker) onReported(_ *rpb.AMD64Registers) { +func (o *onceTracker) onReported(_ *rpb.Registers) { o.reported = true } + +// argsTracker reports only once for each different combination of arguments. +// It's used for generic syscalls like ioctl to report once per 'cmd'. +type argsTracker struct { + // argsIdx is the syscall arguments to use as unique ID. + argsIdx []int + reported map[string]struct{} + count int +} + +func newArgsTracker(argIdx ...int) *argsTracker { + return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})} +} + +// key returns the command based on the syscall argument index. +func (a *argsTracker) key(regs *rpb.Registers) string { + var rv string + for _, idx := range a.argsIdx { + rv += fmt.Sprintf("%d|", argVal(idx, regs)) + } + return rv +} + +func (a *argsTracker) shouldReport(regs *rpb.Registers) bool { + if a.count >= reportLimit { + return false + } + _, ok := a.reported[a.key(regs)] + return !ok +} + +func (a *argsTracker) onReported(regs *rpb.Registers) { + a.count++ + a.reported[a.key(regs)] = struct{}{} +} diff --git a/runsc/boot/compat_amd64.go b/runsc/boot/compat_amd64.go index 43cd0db94..42b0ca8b0 100644 --- a/runsc/boot/compat_amd64.go +++ b/runsc/boot/compat_amd64.go @@ -16,62 +16,81 @@ package boot import ( "fmt" + "syscall" + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/sentry/arch" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + "gvisor.dev/gvisor/pkg/sentry/strace" ) // reportLimit is the max number of events that should be reported per tracker. const reportLimit = 100 -// argsTracker reports only once for each different combination of arguments. -// It's used for generic syscalls like ioctl to report once per 'cmd'. -type argsTracker struct { - // argsIdx is the syscall arguments to use as unique ID. - argsIdx []int - reported map[string]struct{} - count int +// newRegs create a empty Registers instance. +func newRegs() *rpb.Registers { + return &rpb.Registers{ + Arch: &rpb.Registers_Amd64{ + Amd64: &rpb.AMD64Registers{}, + }, + } } -func newArgsTracker(argIdx ...int) *argsTracker { - return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})} -} +func argVal(argIdx int, regs *rpb.Registers) uint32 { + amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64 -// cmd returns the command based on the syscall argument index. -func (a *argsTracker) key(regs *rpb.AMD64Registers) string { - var rv string - for _, idx := range a.argsIdx { - rv += fmt.Sprintf("%d|", argVal(idx, regs)) + switch argIdx { + case 0: + return uint32(amd64Regs.Rdi) + case 1: + return uint32(amd64Regs.Rsi) + case 2: + return uint32(amd64Regs.Rdx) + case 3: + return uint32(amd64Regs.R10) + case 4: + return uint32(amd64Regs.R8) + case 5: + return uint32(amd64Regs.R9) } - return rv + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) } -func argVal(argIdx int, regs *rpb.AMD64Registers) uint32 { +func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) { + amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64 + switch argIdx { case 0: - return uint32(regs.Rdi) + amd64Regs.Rdi = argVal case 1: - return uint32(regs.Rsi) + amd64Regs.Rsi = argVal case 2: - return uint32(regs.Rdx) + amd64Regs.Rdx = argVal case 3: - return uint32(regs.R10) + amd64Regs.R10 = argVal case 4: - return uint32(regs.R8) + amd64Regs.R8 = argVal case 5: - return uint32(regs.R9) + amd64Regs.R9 = argVal + default: + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) } - panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) } -func (a *argsTracker) shouldReport(regs *rpb.AMD64Registers) bool { - if a.count >= reportLimit { - return false - } - _, ok := a.reported[a.key(regs)] - return !ok +func getSyscallNameMap() (strace.SyscallMap, bool) { + return strace.Lookup(abi.Linux, arch.AMD64) +} + +func syscallNum(regs *rpb.Registers) uint64 { + amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64 + return amd64Regs.OrigRax } -func (a *argsTracker) onReported(regs *rpb.AMD64Registers) { - a.count++ - a.reported[a.key(regs)] = struct{}{} +func newArchArgsTracker(sysnr uint64) syscallTracker { + switch sysnr { + case syscall.SYS_ARCH_PRCTL: + // args: cmd, ... + return newArgsTracker(0) + } + return nil } diff --git a/runsc/boot/compat_arm64.go b/runsc/boot/compat_arm64.go new file mode 100644 index 000000000..f784cd237 --- /dev/null +++ b/runsc/boot/compat_arm64.go @@ -0,0 +1,91 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package boot + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/sentry/arch" + rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + "gvisor.dev/gvisor/pkg/sentry/strace" +) + +// reportLimit is the max number of events that should be reported per tracker. +const reportLimit = 100 + +// newRegs create a empty Registers instance. +func newRegs() *rpb.Registers { + return &rpb.Registers{ + Arch: &rpb.Registers_Arm64{ + Arm64: &rpb.ARM64Registers{}, + }, + } +} + +func argVal(argIdx int, regs *rpb.Registers) uint32 { + arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64 + + switch argIdx { + case 0: + return uint32(arm64Regs.R0) + case 1: + return uint32(arm64Regs.R1) + case 2: + return uint32(arm64Regs.R2) + case 3: + return uint32(arm64Regs.R3) + case 4: + return uint32(arm64Regs.R4) + case 5: + return uint32(arm64Regs.R5) + } + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) +} + +func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) { + arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64 + + switch argIdx { + case 0: + arm64Regs.R0 = argVal + case 1: + arm64Regs.R1 = argVal + case 2: + arm64Regs.R2 = argVal + case 3: + arm64Regs.R3 = argVal + case 4: + arm64Regs.R4 = argVal + case 5: + arm64Regs.R5 = argVal + default: + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) + } +} + +func getSyscallNameMap() (strace.SyscallMap, bool) { + return strace.Lookup(abi.Linux, arch.ARM64) +} + +func syscallNum(regs *rpb.Registers) uint64 { + arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64 + return arm64Regs.R8 +} + +func newArchArgsTracker(sysnr uint64) syscallTracker { + // currently, no arch specific syscalls need to be handled here. + return nil +} diff --git a/runsc/boot/compat_test.go b/runsc/boot/compat_test.go index 388298d8d..839c5303b 100644 --- a/runsc/boot/compat_test.go +++ b/runsc/boot/compat_test.go @@ -16,8 +16,6 @@ package boot import ( "testing" - - rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" ) func TestOnceTracker(t *testing.T) { @@ -35,31 +33,34 @@ func TestOnceTracker(t *testing.T) { func TestArgsTracker(t *testing.T) { for _, tc := range []struct { - name string - idx []int - rdi1 uint64 - rdi2 uint64 - rsi1 uint64 - rsi2 uint64 - want bool + name string + idx []int + arg1_1 uint64 + arg1_2 uint64 + arg2_1 uint64 + arg2_2 uint64 + want bool }{ - {name: "same rdi", idx: []int{0}, rdi1: 123, rdi2: 123, want: false}, - {name: "same rsi", idx: []int{1}, rsi1: 123, rsi2: 123, want: false}, - {name: "diff rdi", idx: []int{0}, rdi1: 123, rdi2: 321, want: true}, - {name: "diff rsi", idx: []int{1}, rsi1: 123, rsi2: 321, want: true}, - {name: "cmd is uint32", idx: []int{0}, rsi1: 0xdead00000123, rsi2: 0xbeef00000123, want: false}, - {name: "same 2 args", idx: []int{0, 1}, rsi1: 123, rdi1: 321, rsi2: 123, rdi2: 321, want: false}, - {name: "diff 2 args", idx: []int{0, 1}, rsi1: 123, rdi1: 321, rsi2: 789, rdi2: 987, want: true}, + {name: "same arg1", idx: []int{0}, arg1_1: 123, arg1_2: 123, want: false}, + {name: "same arg2", idx: []int{1}, arg2_1: 123, arg2_2: 123, want: false}, + {name: "diff arg1", idx: []int{0}, arg1_1: 123, arg1_2: 321, want: true}, + {name: "diff arg2", idx: []int{1}, arg2_1: 123, arg2_2: 321, want: true}, + {name: "cmd is uint32", idx: []int{0}, arg2_1: 0xdead00000123, arg2_2: 0xbeef00000123, want: false}, + {name: "same 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 123, arg1_2: 321, want: false}, + {name: "diff 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 789, arg1_2: 987, want: true}, } { t.Run(tc.name, func(t *testing.T) { c := newArgsTracker(tc.idx...) - regs := &rpb.AMD64Registers{Rdi: tc.rdi1, Rsi: tc.rsi1} + regs := newRegs() + setArgVal(0, tc.arg1_1, regs) + setArgVal(1, tc.arg2_1, regs) if !c.shouldReport(regs) { t.Error("first call to shouldReport, got: false, want: true") } c.onReported(regs) - regs.Rdi, regs.Rsi = tc.rdi2, tc.rsi2 + setArgVal(0, tc.arg1_2, regs) + setArgVal(1, tc.arg2_2, regs) if got := c.shouldReport(regs); tc.want != got { t.Errorf("second call to shouldReport, got: %t, want: %t", got, tc.want) } @@ -70,7 +71,9 @@ func TestArgsTracker(t *testing.T) { func TestArgsTrackerLimit(t *testing.T) { c := newArgsTracker(0, 1) for i := 0; i < reportLimit; i++ { - regs := &rpb.AMD64Registers{Rdi: 123, Rsi: uint64(i)} + regs := newRegs() + setArgVal(0, 123, regs) + setArgVal(1, uint64(i), regs) if !c.shouldReport(regs) { t.Error("shouldReport before limit was reached, got: false, want: true") } @@ -78,7 +81,9 @@ func TestArgsTrackerLimit(t *testing.T) { } // Should hit the count limit now. - regs := &rpb.AMD64Registers{Rdi: 123, Rsi: 123456} + regs := newRegs() + setArgVal(0, 123, regs) + setArgVal(1, 123456, regs) if c.shouldReport(regs) { t.Error("shouldReport after limit was reached, got: true, want: false") } diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 01a29e8d5..a878bc2ce 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -178,8 +178,11 @@ type Config struct { // capabilities. EnableRaw bool - // GSO indicates that generic segmentation offload is enabled. - GSO bool + // HardwareGSO indicates that hardware segmentation offload is enabled. + HardwareGSO bool + + // SoftwareGSO indicates that software segmentation offload is enabled. + SoftwareGSO bool // LogPackets indicates that all network packets should be logged. LogPackets bool @@ -247,6 +250,12 @@ type Config struct { // multiple tests are run in parallel, since there is no way to pass // parameters to the runtime from docker. TestOnlyTestNameEnv string + + // CPUNumFromQuota sets CPU number count to available CPU quota, using + // least integer value greater than or equal to quota. + // + // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2. + CPUNumFromQuota bool } // ToFlags returns a slice of flags that correspond to the given Config. @@ -275,8 +284,13 @@ func (c *Config) ToFlags() []string { "--rootless=" + strconv.FormatBool(c.Rootless), "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr), "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode), + "--gso=" + strconv.FormatBool(c.HardwareGSO), + "--software-gso=" + strconv.FormatBool(c.SoftwareGSO), "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead), } + if c.CPUNumFromQuota { + f = append(f, "--cpu-num-from-quota") + } // Only include these if set since it is never to be used by users. if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { f = append(f, "--TESTONLY-unsafe-nonroot=true") diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 5f644b57e..9c9e94864 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -51,7 +51,7 @@ const ( ContainerEvent = "containerManager.Event" // ContainerExecuteAsync is the URPC endpoint for executing a command in a - // container.. + // container. ContainerExecuteAsync = "containerManager.ExecuteAsync" // ContainerPause pauses the container. @@ -152,7 +152,9 @@ func newController(fd int, l *Loader) (*controller, error) { srv.Register(&debug{}) srv.Register(&control.Logging{}) if l.conf.ProfileEnable { - srv.Register(&control.Profile{}) + srv.Register(&control.Profile{ + Kernel: l.k, + }) } return &controller{ @@ -161,7 +163,7 @@ func newController(fd int, l *Loader) (*controller, error) { }, nil } -// containerManager manages sandboes containers. +// containerManager manages sandbox containers. type containerManager struct { // startChan is used to signal when the root container process should // be started. @@ -380,7 +382,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { } // Since we have a new kernel we also must make a new watchdog. - dog := watchdog.New(k, watchdog.DefaultTimeout, cm.l.conf.WatchdogAction) + dogOpts := watchdog.DefaultOpts + dogOpts.TaskTimeoutAction = cm.l.conf.WatchdogAction + dog := watchdog.New(k, dogOpts) // Change the loader fields to reflect the changes made when restoring. cm.l.k = k diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD index f5509b6b7..3a9dcfc04 100644 --- a/runsc/boot/filter/BUILD +++ b/runsc/boot/filter/BUILD @@ -6,6 +6,8 @@ go_library( name = "filter", srcs = [ "config.go", + "config_amd64.go", + "config_arm64.go", "extra_filters.go", "extra_filters_msan.go", "extra_filters_race.go", diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index efbf1fd4a..4fb9adca6 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -26,10 +26,6 @@ import ( // allowedSyscalls is the set of syscalls executed by the Sentry to the host OS. var allowedSyscalls = seccomp.SyscallRules{ - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_GET_FS)}, - {seccomp.AllowValue(linux.ARCH_SET_FS)}, - }, syscall.SYS_CLOCK_GETTIME: {}, syscall.SYS_CLONE: []seccomp.Rule{ { @@ -42,9 +38,15 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.CLONE_THREAD), }, }, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, - syscall.SYS_DUP2: {}, + syscall.SYS_CLOSE: {}, + syscall.SYS_DUP: {}, + syscall.SYS_DUP3: []seccomp.Rule{ + { + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowValue(0), + }, + }, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ @@ -132,11 +134,6 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.AllowValue(syscall.SOL_SOCKET), seccomp.AllowValue(syscall.SO_SNDBUF), }, - { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_REUSEADDR), - }, }, syscall.SYS_GETTID: {}, syscall.SYS_GETTIMEOFDAY: {}, @@ -243,6 +240,15 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.AllowValue(0), }, }, + unix.SYS_SENDMMSG: []seccomp.Rule{ + { + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.MSG_DONTWAIT), + seccomp.AllowValue(0), + }, + }, syscall.SYS_RESTART_SYSCALL: {}, syscall.SYS_RT_SIGACTION: {}, syscall.SYS_RT_SIGPROCMASK: {}, @@ -306,6 +312,26 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IP), + seccomp.AllowValue(syscall.IP_TOS), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IP), + seccomp.AllowValue(syscall.IP_RECVTOS), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IPV6), + seccomp.AllowValue(syscall.IPV6_TCLASS), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IPV6), + seccomp.AllowValue(syscall.IPV6_RECVTCLASS), + }, + { + seccomp.AllowAny{}, seccomp.AllowValue(syscall.SOL_IPV6), seccomp.AllowValue(syscall.IPV6_V6ONLY), }, @@ -407,6 +433,34 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.AllowAny{}, seccomp.AllowValue(4), }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IP), + seccomp.AllowValue(syscall.IP_TOS), + seccomp.AllowAny{}, + seccomp.AllowValue(4), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IP), + seccomp.AllowValue(syscall.IP_RECVTOS), + seccomp.AllowAny{}, + seccomp.AllowValue(4), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IPV6), + seccomp.AllowValue(syscall.IPV6_TCLASS), + seccomp.AllowAny{}, + seccomp.AllowValue(4), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IPV6), + seccomp.AllowValue(syscall.IPV6_RECVTCLASS), + seccomp.AllowAny{}, + seccomp.AllowValue(4), + }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ { diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go new file mode 100644 index 000000000..5335ff82c --- /dev/null +++ b/runsc/boot/filter/config_amd64.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" +) + +func init() { + allowedSyscalls[syscall.SYS_ARCH_PRCTL] = append(allowedSyscalls[syscall.SYS_ARCH_PRCTL], + seccomp.Rule{seccomp.AllowValue(linux.ARCH_GET_FS)}, + seccomp.Rule{seccomp.AllowValue(linux.ARCH_SET_FS)}, + ) +} diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go new file mode 100644 index 000000000..7fa9bbda3 --- /dev/null +++ b/runsc/boot/filter/config_arm64.go @@ -0,0 +1,21 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package filter + +// Reserve for future customization. +func init() { +} diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 76036c147..421ccd255 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -16,7 +16,6 @@ package boot import ( "fmt" - "path" "path/filepath" "sort" "strconv" @@ -52,7 +51,7 @@ const ( rootDevice = "9pfs-/" // MountPrefix is the annotation prefix for mount hints. - MountPrefix = "gvisor.dev/spec/mount" + MountPrefix = "dev.gvisor.spec.mount." // Filesystems that runsc supports. bind = "bind" @@ -465,6 +464,13 @@ func (m *mountHint) checkCompatible(mount specs.Mount) error { return nil } +func (m *mountHint) fileAccessType() FileAccessType { + if m.share == container { + return FileAccessExclusive + } + return FileAccessShared +} + func filterUnsupportedOptions(mount specs.Mount) []string { rv := make([]string, 0, len(mount.Options)) for _, o := range mount.Options { @@ -483,14 +489,15 @@ type podMountHints struct { func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { mnts := make(map[string]*mountHint) for k, v := range spec.Annotations { - // Look for 'gvisor.dev/spec/mount' annotations and parse them. + // Look for 'dev.gvisor.spec.mount' annotations and parse them. if strings.HasPrefix(k, MountPrefix) { - parts := strings.Split(k, "/") - if len(parts) != 5 { + // Remove the prefix and split the rest. + parts := strings.Split(k[len(MountPrefix):], ".") + if len(parts) != 2 { return nil, fmt.Errorf("invalid mount annotation: %s=%s", k, v) } - name := parts[3] - if len(name) == 0 || path.Clean(name) != name { + name := parts[0] + if len(name) == 0 { return nil, fmt.Errorf("invalid mount name: %s", name) } mnt := mnts[name] @@ -498,7 +505,7 @@ func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { mnt = &mountHint{name: name} mnts[name] = mnt } - if err := mnt.setField(parts[4], v); err != nil { + if err := mnt.setField(parts[1], v); err != nil { return nil, err } } @@ -568,6 +575,11 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin func (c *containerMounter) processHints(conf *Config) error { ctx := c.k.SupervisorContext() for _, hint := range c.hints.mounts { + // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a + // common gofer to mount all shared volumes. + if hint.mount.Type != tmpfs { + continue + } log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type) inode, err := c.mountSharedMaster(ctx, conf, hint) if err != nil { @@ -764,8 +776,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( case bind: fd := c.fds.remove() fsName = "9p" - // Non-root bind mounts are always shared. - opts = p9MountOptions(fd, FileAccessShared) + opts = p9MountOptions(fd, c.getMountAccessType(m)) // If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly @@ -778,6 +789,14 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( return fsName, opts, useOverlay, nil } +func (c *containerMounter) getMountAccessType(mount specs.Mount) FileAccessType { + if hint := c.hints.findMount(mount); hint != nil { + return hint.fileAccessType() + } + // Non-root bind mounts are always shared if no hints were provided. + return FileAccessShared +} + // mountSubmount mounts volumes inside the container's root. Because mounts may // be readonly, a lower ramfs overlay is added to create the mount point dir. // Another overlay is added with tmpfs on top if Config.Overlay is true. @@ -837,7 +856,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns return fmt.Errorf("mount %q error: %v", m.Destination, err) } - log.Infof("Mounted %q to %q type %s", m.Source, m.Destination, m.Type) + log.Infof("Mounted %q to %q type: %s, internal-options: %q", m.Source, m.Destination, m.Type, opts) return nil } diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index 49ab34b33..912037075 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -15,7 +15,6 @@ package boot import ( - "path" "reflect" "strings" "testing" @@ -26,19 +25,19 @@ import ( func TestPodMountHintsHappy(t *testing.T) { spec := &specs.Spec{ Annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", - path.Join(MountPrefix, "mount2", "source"): "bar", - path.Join(MountPrefix, "mount2", "type"): "bind", - path.Join(MountPrefix, "mount2", "share"): "container", - path.Join(MountPrefix, "mount2", "options"): "rw,private", + MountPrefix + "mount2.source": "bar", + MountPrefix + "mount2.type": "bind", + MountPrefix + "mount2.share": "container", + MountPrefix + "mount2.options": "rw,private", }, } podHints, err := newPodMountHints(spec) if err != nil { - t.Errorf("newPodMountHints failed: %v", err) + t.Fatalf("newPodMountHints failed: %v", err) } // Check that fields were set correctly. @@ -86,95 +85,95 @@ func TestPodMountHintsErrors(t *testing.T) { { name: "too short", annotations: map[string]string{ - path.Join(MountPrefix, "mount1"): "foo", + MountPrefix + "mount1": "foo", }, error: "invalid mount annotation", }, { name: "no name", annotations: map[string]string{ - MountPrefix + "//source": "foo", + MountPrefix + ".source": "foo", }, error: "invalid mount name", }, { name: "missing source", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", }, error: "source field", }, { name: "missing type", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.share": "pod", }, error: "type field", }, { name: "missing share", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", }, error: "share field", }, { name: "invalid field name", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "invalid"): "foo", + MountPrefix + "mount1.invalid": "foo", }, error: "invalid mount annotation", }, { name: "invalid source", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", }, error: "source cannot be empty", }, { name: "invalid type", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "invalid-type", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "invalid-type", + MountPrefix + "mount1.share": "pod", }, error: "invalid type", }, { name: "invalid share", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "invalid-share", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "invalid-share", }, error: "invalid share", }, { name: "invalid options", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", - path.Join(MountPrefix, "mount1", "options"): "invalid-option", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", + MountPrefix + "mount1.options": "invalid-option", }, error: "unknown mount option", }, { name: "duplicate source", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", - path.Join(MountPrefix, "mount2", "source"): "foo", - path.Join(MountPrefix, "mount2", "type"): "bind", - path.Join(MountPrefix, "mount2", "share"): "container", + MountPrefix + "mount2.source": "foo", + MountPrefix + "mount2.type": "bind", + MountPrefix + "mount2.share": "container", }, error: "have the same mount source", }, @@ -191,3 +190,61 @@ func TestPodMountHintsErrors(t *testing.T) { }) } } + +func TestGetMountAccessType(t *testing.T) { + const source = "foo" + for _, tst := range []struct { + name string + annotations map[string]string + want FileAccessType + }{ + { + name: "container=exclusive", + annotations: map[string]string{ + MountPrefix + "mount1.source": source, + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "container", + }, + want: FileAccessExclusive, + }, + { + name: "pod=shared", + annotations: map[string]string{ + MountPrefix + "mount1.source": source, + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "pod", + }, + want: FileAccessShared, + }, + { + name: "shared=shared", + annotations: map[string]string{ + MountPrefix + "mount1.source": source, + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "shared", + }, + want: FileAccessShared, + }, + { + name: "default=shared", + annotations: map[string]string{ + MountPrefix + "mount1.source": source + "mismatch", + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "container", + }, + want: FileAccessShared, + }, + } { + t.Run(tst.name, func(t *testing.T) { + spec := &specs.Spec{Annotations: tst.annotations} + podHints, err := newPodMountHints(spec) + if err != nil { + t.Fatalf("newPodMountHints failed: %v", err) + } + mounter := containerMounter{hints: podHints} + if got := mounter.getMountAccessType(specs.Mount{Source: source}); got != tst.want { + t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got) + } + }) + } +} diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index c8e5e86ee..bc1d0c1bb 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -43,7 +43,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/sighandling" - slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/watchdog" @@ -65,6 +64,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/hostinet" _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink" _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route" + _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" _ "gvisor.dev/gvisor/pkg/sentry/socket/unix" ) @@ -93,10 +93,6 @@ type Loader struct { // spec is the base configuration for the root container. spec *specs.Spec - // startSignalForwarding enables forwarding of signals to the sandboxed - // container. It should be called after the init process is loaded. - startSignalForwarding func() func() - // stopSignalForwarding disables forwarding of signals to the sandboxed // container. It should be called when a sandbox is destroyed. stopSignalForwarding func() @@ -146,9 +142,6 @@ type execProcess struct { func init() { // Initialize the random number generator. mrand.Seed(gtime.Now().UnixNano()) - - // Register the global syscall table. - kernel.RegisterSyscallTable(slinux.AMD64) } // Args are the arguments for New(). @@ -232,7 +225,7 @@ func New(args Args) (*Loader, error) { // this point. Netns is configured before Run() is called. Netstack is // configured using a control uRPC message. Host network is configured inside // Run(). - networkStack, err := newEmptyNetworkStack(args.Conf, k) + networkStack, err := newEmptyNetworkStack(args.Conf, k, k) if err != nil { return nil, fmt.Errorf("creating network: %v", err) } @@ -300,7 +293,9 @@ func New(args Args) (*Loader, error) { } // Create a watchdog. - dog := watchdog.New(k, watchdog.DefaultTimeout, args.Conf.WatchdogAction) + dogOpts := watchdog.DefaultOpts + dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction + dog := watchdog.New(k, dogOpts) procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace()) if err != nil { @@ -337,29 +332,6 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("ignore child stop signals failed: %v", err) } - // Handle signals by forwarding them to the root container process - // (except for panic signal, which should cause a panic). - l.startSignalForwarding = sighandling.PrepareHandler(func(sig linux.Signal) { - // Panic signal should cause a panic. - if args.Conf.PanicSignal != -1 && sig == linux.Signal(args.Conf.PanicSignal) { - panic("Signal-induced panic") - } - - // Otherwise forward to root container. - deliveryMode := DeliverToProcess - if args.Console { - // Since we are running with a console, we should - // forward the signal to the foreground process group - // so that job control signals like ^C can be handled - // properly. - deliveryMode = DeliverToForegroundProcessGroup - } - log.Infof("Received external signal %d, mode: %v", sig, deliveryMode) - if err := l.signal(args.ID, 0, int32(sig), deliveryMode); err != nil { - log.Warningf("error sending signal %v to container %q: %v", sig, args.ID, err) - } - }) - // Create the control server using the provided FD. // // This must be done *after* we have initialized the kernel since the @@ -567,8 +539,27 @@ func (l *Loader) run() error { ep.tty.InitForegroundProcessGroup(ep.tg.ProcessGroup()) } - // Start signal forwarding only after an init process is created. - l.stopSignalForwarding = l.startSignalForwarding() + // Handle signals by forwarding them to the root container process + // (except for panic signal, which should cause a panic). + l.stopSignalForwarding = sighandling.StartSignalForwarding(func(sig linux.Signal) { + // Panic signal should cause a panic. + if l.conf.PanicSignal != -1 && sig == linux.Signal(l.conf.PanicSignal) { + panic("Signal-induced panic") + } + + // Otherwise forward to root container. + deliveryMode := DeliverToProcess + if l.console { + // Since we are running with a console, we should forward the signal to + // the foreground process group so that job control signals like ^C can + // be handled properly. + deliveryMode = DeliverToForegroundProcessGroup + } + log.Infof("Received external signal %d, mode: %v", sig, deliveryMode) + if err := l.signal(l.sandboxID, 0, int32(sig), deliveryMode); err != nil { + log.Warningf("error sending signal %v to container %q: %v", sig, l.sandboxID, err) + } + }) log.Infof("Process should have started...") l.watchdog.Start() @@ -905,7 +896,7 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { +func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { switch conf.Network { case NetworkHost: return hostinet.NewStack(), nil @@ -922,7 +913,8 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { HandleLocal: true, // Enable raw sockets for users with sufficient // privileges. - UnassociatedFactory: raw.EndpointFactory{}, + RawFactory: raw.EndpointFactory{}, + UniqueID: uniqueID, })} // Enable SACK Recovery. diff --git a/runsc/boot/loader_amd64.go b/runsc/boot/loader_amd64.go new file mode 100644 index 000000000..b9669f2ac --- /dev/null +++ b/runsc/boot/loader_amd64.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package boot + +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" +) + +func init() { + // Register the global syscall table. + kernel.RegisterSyscallTable(linux.AMD64) +} diff --git a/runsc/boot/loader_arm64.go b/runsc/boot/loader_arm64.go new file mode 100644 index 000000000..cf64d28c8 --- /dev/null +++ b/runsc/boot/loader_arm64.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package boot + +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" +) + +func init() { + // Register the global syscall table. + kernel.RegisterSyscallTable(linux.ARM64) +} diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 32cba5ac1..dd4926bb9 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -50,12 +50,13 @@ type DefaultRoute struct { // FDBasedLink configures an fd-based link. type FDBasedLink struct { - Name string - MTU int - Addresses []net.IP - Routes []Route - GSOMaxSize uint32 - LinkAddress net.HardwareAddr + Name string + MTU int + Addresses []net.IP + Routes []Route + GSOMaxSize uint32 + SoftwareGSOEnabled bool + LinkAddress net.HardwareAddr // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -79,7 +80,8 @@ type CreateLinksAndRoutesArgs struct { LoopbackLinks []LoopbackLink FDBasedLinks []FDBasedLink - DefaultGateway DefaultRoute + Defaultv4Gateway DefaultRoute + Defaultv6Gateway DefaultRoute } // Empty returns true if route hasn't been set. @@ -121,10 +123,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct nicID++ nicids[link.Name] = nicID - ep := loopback.New() + linkEP := loopback.New() log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses) - if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil { return err } @@ -156,13 +158,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } mac := tcpip.LinkAddress(link.LinkAddress) - ep, err := fdbased.New(&fdbased.Options{ + linkEP, err := fdbased.New(&fdbased.Options{ FDs: FDs, MTU: uint32(link.MTU), EthernetHeader: true, Address: mac, PacketDispatchMode: fdbased.RecvMMsg, GSOMaxSize: link.GSOMaxSize, + SoftwareGSOEnabled: link.SoftwareGSOEnabled, RXChecksumOffload: true, }) if err != nil { @@ -170,7 +173,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) - if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil { return err } @@ -184,12 +187,24 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } } - if !args.DefaultGateway.Route.Empty() { - nicID, ok := nicids[args.DefaultGateway.Name] + if !args.Defaultv4Gateway.Route.Empty() { + nicID, ok := nicids[args.Defaultv4Gateway.Name] if !ok { - return fmt.Errorf("invalid interface name %q for default route", args.DefaultGateway.Name) + return fmt.Errorf("invalid interface name %q for default route", args.Defaultv4Gateway.Name) } - route, err := args.DefaultGateway.Route.toTcpipRoute(nicID) + route, err := args.Defaultv4Gateway.Route.toTcpipRoute(nicID) + if err != nil { + return err + } + routes = append(routes, route) + } + + if !args.Defaultv6Gateway.Route.Empty() { + nicID, ok := nicids[args.Defaultv6Gateway.Name] + if !ok { + return fmt.Errorf("invalid interface name %q for default route", args.Defaultv6Gateway.Name) + } + route, err := args.Defaultv6Gateway.Route.toTcpipRoute(nicID) if err != nil { return err } @@ -206,11 +221,11 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error { if loopback { if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil { - return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err) + return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, ep, err) } } else { if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { - return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err) + return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, ep, err) } } diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index ab3a25b9b..653ca5f52 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -101,6 +101,14 @@ func getValue(path, name string) (string, error) { return string(out), nil } +func getInt(path, name string) (int, error) { + s, err := getValue(path, name) + if err != nil { + return 0, err + } + return strconv.Atoi(strings.TrimSpace(s)) +} + // fillFromAncestor sets the value of a cgroup file from the first ancestor // that has content. It does nothing if the file in 'path' has already been set. func fillFromAncestor(path string) (string, error) { @@ -323,6 +331,22 @@ func (c *Cgroup) Join() (func(), error) { return undo, nil } +func (c *Cgroup) CPUQuota() (float64, error) { + path := c.makePath("cpu") + quota, err := getInt(path, "cpu.cfs_quota_us") + if err != nil { + return -1, err + } + period, err := getInt(path, "cpu.cfs_period_us") + if err != nil { + return -1, err + } + if quota <= 0 || period <= 0 { + return -1, err + } + return float64(quota) / float64(period), nil +} + // NumCPU returns the number of CPUs configured in 'cpuset/cpuset.cpus'. func (c *Cgroup) NumCPU() (int, error) { path := c.makePath("cpuset") diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 7313e473f..f37415810 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -32,16 +32,17 @@ import ( // Debug implements subcommands.Command for the "debug" command. type Debug struct { - pid int - stacks bool - signal int - profileHeap string - profileCPU string - profileDelay int - trace string - strace string - logLevel string - logPackets string + pid int + stacks bool + signal int + profileHeap string + profileCPU string + trace string + strace string + logLevel string + logPackets string + duration time.Duration + ps bool } // Name implements subcommands.Command. @@ -65,12 +66,13 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log") f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.") f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.") - f.IntVar(&d.profileDelay, "profile-delay", 5, "amount of time to wait before stoping CPU profile") + f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles") f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.") f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox") f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all`) f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).") f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.") + f.BoolVar(&d.ps, "ps", false, "lists processes") } // Execute implements subcommands.Command.Execute. @@ -163,7 +165,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if err := c.Sandbox.StartCPUProfile(f); err != nil { return Errorf(err.Error()) } - log.Infof("CPU profile started for %d sec, writing to %q", d.profileDelay, d.profileCPU) + log.Infof("CPU profile started for %v, writing to %q", d.duration, d.profileCPU) } if d.trace != "" { delay = true @@ -181,8 +183,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if err := c.Sandbox.StartTrace(f); err != nil { return Errorf(err.Error()) } - log.Infof("Tracing started for %d sec, writing to %q", d.profileDelay, d.trace) - + log.Infof("Tracing started for %v, writing to %q", d.duration, d.trace) } if d.strace != "" || len(d.logLevel) != 0 || len(d.logPackets) != 0 { @@ -241,9 +242,20 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } log.Infof("Logging options changed") } + if d.ps { + pList, err := c.Processes() + if err != nil { + Fatalf("getting processes for container: %v", err) + } + o, err := control.ProcessListToJSON(pList) + if err != nil { + Fatalf("generating JSON: %v", err) + } + log.Infof(o) + } if delay { - time.Sleep(time.Duration(d.profileDelay) * time.Second) + time.Sleep(d.duration) } return subcommands.ExitSuccess diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 26d1cd5ab..2bd12120d 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "container.go", "hook.go", + "state_file.go", "status.go", ], importpath = "gvisor.dev/gvisor/runsc/container", diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 7d67c3a75..5ed131a7f 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -28,6 +28,7 @@ import ( "github.com/kr/pty" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/control" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/testutil" @@ -219,9 +220,9 @@ func TestJobControlSignalExec(t *testing.T) { // Make sure all the processes are running. expectedPL := []*control.Process{ // Root container process. - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, // Bash from exec process. - {PID: 2, Cmd: "bash"}, + {PID: 2, Cmd: "bash", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(c, expectedPL); err != nil { t.Error(err) @@ -231,7 +232,7 @@ func TestJobControlSignalExec(t *testing.T) { ptyMaster.Write([]byte("sleep 100\n")) // Wait for it to start. Sleep's PPID is bash's PID. - expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep"}) + expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}}) if err := waitForProcessList(c, expectedPL); err != nil { t.Error(err) } @@ -361,7 +362,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { // Wait for bash to start. expectedPL := []*control.Process{ - {PID: 1, Cmd: "bash"}, + {PID: 1, Cmd: "bash", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(c, expectedPL); err != nil { t.Fatal(err) @@ -371,7 +372,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { ptyMaster.Write([]byte("sleep 100\n")) // Wait for sleep to start. - expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep"}) + expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{2}}) if err := waitForProcessList(c, expectedPL); err != nil { t.Fatal(err) } diff --git a/runsc/container/container.go b/runsc/container/container.go index a721c1c31..68782c4be 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -17,13 +17,11 @@ package container import ( "context" - "encoding/json" "fmt" "io/ioutil" "os" "os/exec" "os/signal" - "path/filepath" "regexp" "strconv" "strings" @@ -31,7 +29,6 @@ import ( "time" "github.com/cenkalti/backoff" - "github.com/gofrs/flock" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" @@ -41,17 +38,6 @@ import ( "gvisor.dev/gvisor/runsc/specutils" ) -const ( - // metadataFilename is the name of the metadata file relative to the - // container root directory that holds sandbox metadata. - metadataFilename = "meta.json" - - // metadataLockFilename is the name of a lock file in the container - // root directory that is used to prevent concurrent modifications to - // the container state and metadata. - metadataLockFilename = "meta.lock" -) - // validateID validates the container id. func validateID(id string) error { // See libcontainer/factory_linux.go. @@ -99,11 +85,6 @@ type Container struct { // BundleDir is the directory containing the container bundle. BundleDir string `json:"bundleDir"` - // Root is the directory containing the container metadata file. If this - // container is the root container, Root and RootContainerDir will be the - // same. - Root string `json:"root"` - // CreatedAt is the time the container was created. CreatedAt time.Time `json:"createdAt"` @@ -121,21 +102,24 @@ type Container struct { // be 0 if the gofer has been killed. GoferPid int `json:"goferPid"` + // Sandbox is the sandbox this container is running in. It's set when the + // container is created and reset when the sandbox is destroyed. + Sandbox *sandbox.Sandbox `json:"sandbox"` + + // Saver handles load from/save to the state file safely from multiple + // processes. + Saver StateFile `json:"saver"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + // goferIsChild is set if a gofer process is a child of the current process. // // This field isn't saved to json, because only a creator of a gofer // process will have it as a child process. goferIsChild bool - - // Sandbox is the sandbox this container is running in. It's set when the - // container is created and reset when the sandbox is destroyed. - Sandbox *sandbox.Sandbox `json:"sandbox"` - - // RootContainerDir is the root directory containing the metadata file of the - // sandbox root container. It's used to lock in order to serialize creating - // and deleting this Container's metadata directory. If this container is the - // root container, this is the same as Root. - RootContainerDir string } // loadSandbox loads all containers that belong to the sandbox with the given @@ -166,43 +150,35 @@ func loadSandbox(rootDir, id string) ([]*Container, error) { return containers, nil } -// Load loads a container with the given id from a metadata file. id may be an -// abbreviation of the full container id, in which case Load loads the -// container to which id unambiguously refers to. -// Returns ErrNotExist if container doesn't exist. -func Load(rootDir, id string) (*Container, error) { - log.Debugf("Load container %q %q", rootDir, id) - if err := validateID(id); err != nil { +// Load loads a container with the given id from a metadata file. partialID may +// be an abbreviation of the full container id, in which case Load loads the +// container to which id unambiguously refers to. Returns ErrNotExist if +// container doesn't exist. +func Load(rootDir, partialID string) (*Container, error) { + log.Debugf("Load container %q %q", rootDir, partialID) + if err := validateID(partialID); err != nil { return nil, fmt.Errorf("validating id: %v", err) } - cRoot, err := findContainerRoot(rootDir, id) + id, err := findContainerID(rootDir, partialID) if err != nil { // Preserve error so that callers can distinguish 'not found' errors. return nil, err } - // Lock the container metadata to prevent other runsc instances from - // writing to it while we are reading it. - unlock, err := lockContainerMetadata(cRoot) - if err != nil { - return nil, err + state := StateFile{ + RootDir: rootDir, + ID: id, } - defer unlock() + defer state.close() - // Read the container metadata file and create a new Container from it. - metaFile := filepath.Join(cRoot, metadataFilename) - metaBytes, err := ioutil.ReadFile(metaFile) - if err != nil { + c := &Container{} + if err := state.load(c); err != nil { if os.IsNotExist(err) { // Preserve error so that callers can distinguish 'not found' errors. return nil, err } - return nil, fmt.Errorf("reading container metadata file %q: %v", metaFile, err) - } - var c Container - if err := json.Unmarshal(metaBytes, &c); err != nil { - return nil, fmt.Errorf("unmarshaling container metadata from %q: %v", metaFile, err) + return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) } // If the status is "Running" or "Created", check that the sandbox @@ -223,57 +199,37 @@ func Load(rootDir, id string) (*Container, error) { } } - return &c, nil + return c, nil } -func findContainerRoot(rootDir, partialID string) (string, error) { +func findContainerID(rootDir, partialID string) (string, error) { // Check whether the id fully specifies an existing container. - cRoot := filepath.Join(rootDir, partialID) - if _, err := os.Stat(cRoot); err == nil { - return cRoot, nil + stateFile := buildStatePath(rootDir, partialID) + if _, err := os.Stat(stateFile); err == nil { + return partialID, nil } // Now see whether id could be an abbreviation of exactly 1 of the // container ids. If id is ambiguous (it could match more than 1 // container), it is an error. - cRoot = "" ids, err := List(rootDir) if err != nil { return "", err } + rv := "" for _, id := range ids { if strings.HasPrefix(id, partialID) { - if cRoot != "" { - return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, cRoot, id) + if rv != "" { + return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id) } - cRoot = id + rv = id } } - if cRoot == "" { + if rv == "" { return "", os.ErrNotExist } - log.Debugf("abbreviated id %q resolves to full id %q", partialID, cRoot) - return filepath.Join(rootDir, cRoot), nil -} - -// List returns all container ids in the given root directory. -func List(rootDir string) ([]string, error) { - log.Debugf("List containers %q", rootDir) - fs, err := ioutil.ReadDir(rootDir) - if err != nil { - return nil, fmt.Errorf("reading dir %q: %v", rootDir, err) - } - var out []string - for _, f := range fs { - // Filter out directories that do no belong to a container. - cid := f.Name() - if validateID(cid) == nil { - if _, err := os.Stat(filepath.Join(rootDir, cid, metadataFilename)); err == nil { - out = append(out, f.Name()) - } - } - } - return out, nil + log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv) + return rv, nil } // Args is used to configure a new container. @@ -316,44 +272,34 @@ func New(conf *boot.Config, args Args) (*Container, error) { return nil, err } - unlockRoot, err := maybeLockRootContainer(args.Spec, conf.RootDir) - if err != nil { - return nil, err + if err := os.MkdirAll(conf.RootDir, 0711); err != nil { + return nil, fmt.Errorf("creating container root directory: %v", err) } - defer unlockRoot() + + c := &Container{ + ID: args.ID, + Spec: args.Spec, + ConsoleSocket: args.ConsoleSocket, + BundleDir: args.BundleDir, + Status: Creating, + CreatedAt: time.Now(), + Owner: os.Getenv("USER"), + Saver: StateFile{ + RootDir: conf.RootDir, + ID: args.ID, + }, + } + // The Cleanup object cleans up partially created containers when an error + // occurs. Any errors occurring during cleanup itself are ignored. + cu := specutils.MakeCleanup(func() { _ = c.Destroy() }) + defer cu.Clean() // Lock the container metadata file to prevent concurrent creations of // containers with the same id. - containerRoot := filepath.Join(conf.RootDir, args.ID) - unlock, err := lockContainerMetadata(containerRoot) - if err != nil { + if err := c.Saver.lockForNew(); err != nil { return nil, err } - defer unlock() - - // Check if the container already exists by looking for the metadata - // file. - if _, err := os.Stat(filepath.Join(containerRoot, metadataFilename)); err == nil { - return nil, fmt.Errorf("container with id %q already exists", args.ID) - } else if !os.IsNotExist(err) { - return nil, fmt.Errorf("looking for existing container in %q: %v", containerRoot, err) - } - - c := &Container{ - ID: args.ID, - Spec: args.Spec, - ConsoleSocket: args.ConsoleSocket, - BundleDir: args.BundleDir, - Root: containerRoot, - Status: Creating, - CreatedAt: time.Now(), - Owner: os.Getenv("USER"), - RootContainerDir: conf.RootDir, - } - // The Cleanup object cleans up partially created containers when an error occurs. - // Any errors occuring during cleanup itself are ignored. - cu := specutils.MakeCleanup(func() { _ = c.Destroy() }) - defer cu.Clean() + defer c.Saver.unlock() // If the metadata annotations indicate that this container should be // started in an existing sandbox, we must do so. The metadata will @@ -431,7 +377,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { c.changeStatus(Created) // Save the metadata file. - if err := c.save(); err != nil { + if err := c.saveLocked(); err != nil { return nil, err } @@ -451,17 +397,12 @@ func New(conf *boot.Config, args Args) (*Container, error) { func (c *Container) Start(conf *boot.Config) error { log.Debugf("Start container %q", c.ID) - unlockRoot, err := maybeLockRootContainer(c.Spec, c.RootContainerDir) - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlockRoot() + unlock := specutils.MakeCleanup(func() { c.Saver.unlock() }) + defer unlock.Clean() - unlock, err := c.lock() - if err != nil { - return err - } - defer unlock() if err := c.requireStatus("start", Created); err != nil { return err } @@ -509,14 +450,15 @@ func (c *Container) Start(conf *boot.Config) error { } c.changeStatus(Running) - if err := c.save(); err != nil { + if err := c.saveLocked(); err != nil { return err } - // Adjust the oom_score_adj for sandbox. This must be done after - // save(). - err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false) - if err != nil { + // Release lock before adjusting OOM score because the lock is acquired there. + unlock.Clean() + + // Adjust the oom_score_adj for sandbox. This must be done after saveLocked(). + if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil { return err } @@ -529,11 +471,10 @@ func (c *Container) Start(conf *boot.Config) error { // to restore a container from its state file. func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error { log.Debugf("Restore container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if err := c.requireStatus("restore", Created); err != nil { return err @@ -551,7 +492,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str return err } c.changeStatus(Running) - return c.save() + return c.saveLocked() } // Run is a helper that calls Create + Start + Wait. @@ -711,11 +652,10 @@ func (c *Container) Checkpoint(f *os.File) error { // The call only succeeds if the container's status is created or running. func (c *Container) Pause() error { log.Debugf("Pausing container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if c.Status != Created && c.Status != Running { return fmt.Errorf("cannot pause container %q in state %v", c.ID, c.Status) @@ -725,18 +665,17 @@ func (c *Container) Pause() error { return fmt.Errorf("pausing container: %v", err) } c.changeStatus(Paused) - return c.save() + return c.saveLocked() } // Resume unpauses the container and its kernel. // The call only succeeds if the container's status is paused. func (c *Container) Resume() error { log.Debugf("Resuming container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if c.Status != Paused { return fmt.Errorf("cannot resume container %q in state %v", c.ID, c.Status) @@ -745,7 +684,7 @@ func (c *Container) Resume() error { return fmt.Errorf("resuming container: %v", err) } c.changeStatus(Running) - return c.save() + return c.saveLocked() } // State returns the metadata of the container. @@ -773,6 +712,17 @@ func (c *Container) Processes() ([]*control.Process, error) { func (c *Container) Destroy() error { log.Debugf("Destroy container %q", c.ID) + if err := c.Saver.lock(); err != nil { + return err + } + defer func() { + c.Saver.unlock() + c.Saver.close() + }() + + // Stored for later use as stop() sets c.Sandbox to nil. + sb := c.Sandbox + // We must perform the following cleanup steps: // * stop the container and gofer processes, // * remove the container filesystem on the host, and @@ -782,48 +732,43 @@ func (c *Container) Destroy() error { // do our best to perform all of the cleanups. Hence, we keep a slice // of errors return their concatenation. var errs []string - - unlock, err := maybeLockRootContainer(c.Spec, c.RootContainerDir) - if err != nil { - return err - } - defer unlock() - - // Stored for later use as stop() sets c.Sandbox to nil. - sb := c.Sandbox - if err := c.stop(); err != nil { err = fmt.Errorf("stopping container: %v", err) log.Warningf("%v", err) errs = append(errs, err.Error()) } - if err := os.RemoveAll(c.Root); err != nil && !os.IsNotExist(err) { - err = fmt.Errorf("deleting container root directory %q: %v", c.Root, err) + if err := c.Saver.destroy(); err != nil { + err = fmt.Errorf("deleting container state files: %v", err) log.Warningf("%v", err) errs = append(errs, err.Error()) } c.changeStatus(Stopped) - // Adjust oom_score_adj for the sandbox. This must be done after the - // container is stopped and the directory at c.Root is removed. - // We must test if the sandbox is nil because Destroy should be - // idempotent. - if sb != nil { - if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil { + // Adjust oom_score_adj for the sandbox. This must be done after the container + // is stopped and the directory at c.Root is removed. Adjustment can be + // skipped if the root container is exiting, because it brings down the entire + // sandbox. + // + // Use 'sb' to tell whether it has been executed before because Destroy must + // be idempotent. + if sb != nil && !isRoot(c.Spec) { + if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil { errs = append(errs, err.Error()) } } // "If any poststop hook fails, the runtime MUST log a warning, but the - // remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec. - // Based on the OCI, "The post-stop hooks MUST be called after the container is - // deleted but before the delete operation returns" + // remaining hooks and lifecycle continue as if the hook had + // succeeded" - OCI spec. + // + // Based on the OCI, "The post-stop hooks MUST be called after the container + // is deleted but before the delete operation returns" // Run it here to: // 1) Conform to the OCI. - // 2) Make sure it only runs once, because the root has been deleted, the container - // can't be loaded again. + // 2) Make sure it only runs once, because the root has been deleted, the + // container can't be loaded again. if c.Spec.Hooks != nil { executeHooksBestEffort(c.Spec.Hooks.Poststop, c.State()) } @@ -834,18 +779,13 @@ func (c *Container) Destroy() error { return fmt.Errorf(strings.Join(errs, "\n")) } -// save saves the container metadata to a file. +// saveLocked saves the container metadata to a file. // // Precondition: container must be locked with container.lock(). -func (c *Container) save() error { +func (c *Container) saveLocked() error { log.Debugf("Save container %q", c.ID) - metaFile := filepath.Join(c.Root, metadataFilename) - meta, err := json.Marshal(c) - if err != nil { - return fmt.Errorf("invalid container metadata: %v", err) - } - if err := ioutil.WriteFile(metaFile, meta, 0640); err != nil { - return fmt.Errorf("writing container metadata: %v", err) + if err := c.Saver.saveLocked(c); err != nil { + return fmt.Errorf("saving container metadata: %v", err) } return nil } @@ -1106,50 +1046,8 @@ func (c *Container) requireStatus(action string, statuses ...Status) error { return fmt.Errorf("cannot %s container %q in state %s", action, c.ID, c.Status) } -// lock takes a file lock on the container metadata lock file. -func (c *Container) lock() (func() error, error) { - return lockContainerMetadata(filepath.Join(c.Root, c.ID)) -} - -// lockContainerMetadata takes a file lock on the metadata lock file in the -// given container root directory. -func lockContainerMetadata(containerRootDir string) (func() error, error) { - if err := os.MkdirAll(containerRootDir, 0711); err != nil { - return nil, fmt.Errorf("creating container root directory %q: %v", containerRootDir, err) - } - f := filepath.Join(containerRootDir, metadataLockFilename) - l := flock.NewFlock(f) - if err := l.Lock(); err != nil { - return nil, fmt.Errorf("acquiring lock on container lock file %q: %v", f, err) - } - return l.Unlock, nil -} - -// maybeLockRootContainer locks the sandbox root container. It is used to -// prevent races to create and delete child container sandboxes. -func maybeLockRootContainer(spec *specs.Spec, rootDir string) (func() error, error) { - if isRoot(spec) { - return func() error { return nil }, nil - } - - sbid, ok := specutils.SandboxID(spec) - if !ok { - return nil, fmt.Errorf("no sandbox ID found when locking root container") - } - sb, err := Load(rootDir, sbid) - if err != nil { - return nil, err - } - - unlock, err := sb.lock() - if err != nil { - return nil, err - } - return unlock, nil -} - func isRoot(spec *specs.Spec) bool { - return specutils.ShouldCreateSandbox(spec) + return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer } // runInCgroup executes fn inside the specified cgroup. If cg is nil, execute @@ -1170,7 +1068,12 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error { func (c *Container) adjustGoferOOMScoreAdj() error { if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil { if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil { - return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + // Ignore NotExist error because it can be returned when the sandbox + // exited while OOM score was being adjusted. + if !os.IsNotExist(err) { + return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + } + log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid) } } @@ -1198,7 +1101,7 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Get the lowest score for all containers. var lowScore int scoreFound := false - if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 { + if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified { // This is a single-container sandbox. Set the oom_score_adj to // the value specified in the OCI bundle. if containers[0].Spec.Process.OOMScoreAdj != nil { @@ -1214,7 +1117,7 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // // We will use OOMScoreAdj in the single-container case where the // containerd container-type annotation is not present. - if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox { + if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox { continue } @@ -1252,7 +1155,12 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Set the lowest of all containers oom_score_adj to the sandbox. if err := setOOMScoreAdj(s.Pid, lowScore); err != nil { - return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) + // Ignore NotExist error because it can be returned when the sandbox + // exited while OOM score was being adjusted. + if !os.IsNotExist(err) { + return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) + } + log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid) } return nil diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index c4c56b2e0..c10f85992 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" @@ -52,8 +53,9 @@ func waitForProcessList(cont *Container, want []*control.Process) error { err = fmt.Errorf("error getting process data from container: %v", err) return &backoff.PermanentError{Err: err} } - if !procListsEqual(got, want) { - return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want)) + if r, err := procListsEqual(got, want); !r { + return fmt.Errorf("container got process list: %s, want: %s: error: %v", + procListToString(got), procListToString(want), err) } return nil } @@ -91,22 +93,34 @@ func blockUntilWaitable(pid int) error { // procListsEqual is used to check whether 2 Process lists are equal for all // implemented fields. -func procListsEqual(got, want []*control.Process) bool { +func procListsEqual(got, want []*control.Process) (bool, error) { if len(got) != len(want) { - return false + return false, nil } for i := range got { pd1 := got[i] pd2 := want[i] - // Zero out unimplemented and timing dependant fields. + // Zero out timing dependant fields. pd1.Time = "" pd1.STime = "" pd1.C = 0 - if *pd1 != *pd2 { - return false + // Ignore TTY field too, since it's not relevant in the cases + // where we use this method. Tests that care about the TTY + // field should check for it themselves. + pd1.TTY = "" + pd1Json, err := control.ProcessListToJSON([]*control.Process{pd1}) + if err != nil { + return false, err + } + pd2Json, err := control.ProcessListToJSON([]*control.Process{pd2}) + if err != nil { + return false, err + } + if pd1Json != pd2Json { + return false, nil } } - return true + return true, nil } // getAndCheckProcLists is similar to waitForProcessList, but does not wait and retry the @@ -116,7 +130,11 @@ func getAndCheckProcLists(cont *Container, want []*control.Process) error { if err != nil { return fmt.Errorf("error getting process data from container: %v", err) } - if procListsEqual(got, want) { + equal, err := procListsEqual(got, want) + if err != nil { + return err + } + if equal { return nil } return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want)) @@ -288,11 +306,12 @@ func TestLifecycle(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, } // Create the container. @@ -590,18 +609,20 @@ func TestExec(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: uid, + PID: 2, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{2}, }, } @@ -1062,18 +1083,20 @@ func TestPauseResume(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "bash", + UID: uid, + PID: 2, + PPID: 0, + C: 0, + Cmd: "bash", + Threads: []kernel.ThreadID{2}, }, } @@ -1126,11 +1149,12 @@ func TestPauseResume(t *testing.T) { expectedPL2 := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, } @@ -1241,18 +1265,20 @@ func TestCapabilities(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "exe", + UID: uid, + PID: 2, + PPID: 0, + C: 0, + Cmd: "exe", + Threads: []kernel.ThreadID{2}, }, } if err := waitForProcessList(cont, expectedPL[:1]); err != nil { @@ -1548,7 +1574,8 @@ func TestAbbreviatedIDs(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfigWithRoot(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir cids := []string{ "foo-" + testutil.UniqueContainerID(), @@ -2111,6 +2138,95 @@ func TestOverlayfsStaleRead(t *testing.T) { } } +// TestTTYField checks TTY field returned by container.Processes(). +func TestTTYField(t *testing.T) { + stop := testutil.StartReaper() + defer stop() + + testApp, err := testutil.FindFile("runsc/container/test_app/test_app") + if err != nil { + t.Fatal("error finding test_app:", err) + } + + testCases := []struct { + name string + useTTY bool + wantTTYField string + }{ + { + name: "no tty", + useTTY: false, + wantTTYField: "?", + }, + { + name: "tty used", + useTTY: true, + wantTTYField: "pts/0", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + conf := testutil.TestConfig() + + // We will run /bin/sleep, possibly with an open TTY. + cmd := []string{"/bin/sleep", "10000"} + if test.useTTY { + // Run inside the "pty-runner". + cmd = append([]string{testApp, "pty-runner"}, cmd...) + } + + spec := testutil.NewSpecWithArgs(cmd...) + rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) + if err != nil { + t.Fatalf("error setting up container: %v", err) + } + defer os.RemoveAll(rootDir) + defer os.RemoveAll(bundleDir) + + // Create and start the container. + args := Args{ + ID: testutil.UniqueContainerID(), + Spec: spec, + BundleDir: bundleDir, + } + c, err := New(conf, args) + if err != nil { + t.Fatalf("error creating container: %v", err) + } + defer c.Destroy() + if err := c.Start(conf); err != nil { + t.Fatalf("error starting container: %v", err) + } + + // Wait for sleep to be running, and check the TTY + // field. + var gotTTYField string + cb := func() error { + ps, err := c.Processes() + if err != nil { + err = fmt.Errorf("error getting process data from container: %v", err) + return &backoff.PermanentError{Err: err} + } + for _, p := range ps { + if strings.Contains(p.Cmd, "sleep") { + gotTTYField = p.TTY + return nil + } + } + return fmt.Errorf("sleep not running") + } + if err := testutil.Poll(cb, 30*time.Second); err != nil { + t.Fatalf("error waiting for sleep process: %v", err) + } + + if gotTTYField != test.wantTTYField { + t.Errorf("tty field got %q, want %q", gotTTYField, test.wantTTYField) + } + }) + } +} + // executeSync synchronously executes a new process. func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { pid, err := cont.Execute(args) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 9e02a825e..4ad09ceab 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -60,13 +60,8 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { } func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { - // Setup root dir if one hasn't been provided. if len(conf.RootDir) == 0 { - rootDir, err := testutil.SetupRootDir() - if err != nil { - return nil, nil, fmt.Errorf("error creating root dir: %v", err) - } - conf.RootDir = rootDir + panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } var containers []*Container @@ -78,7 +73,6 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*C for _, b := range bundles { os.RemoveAll(b) } - os.RemoveAll(conf.RootDir) } for i, spec := range specs { bundleDir, err := testutil.SetupBundleDir(spec) @@ -129,11 +123,11 @@ func execMany(execs []execDesc) error { func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { for _, spec := range pod { - spec.Annotations[path.Join(boot.MountPrefix, name, "source")] = mount.Source - spec.Annotations[path.Join(boot.MountPrefix, name, "type")] = mount.Type - spec.Annotations[path.Join(boot.MountPrefix, name, "share")] = "pod" + spec.Annotations[boot.MountPrefix+name+".source"] = mount.Source + spec.Annotations[boot.MountPrefix+name+".type"] = mount.Type + spec.Annotations[boot.MountPrefix+name+".share"] = "pod" if len(mount.Options) > 0 { - spec.Annotations[path.Join(boot.MountPrefix, name, "options")] = strings.Join(mount.Options, ",") + spec.Annotations[boot.MountPrefix+name+".options"] = strings.Join(mount.Options, ",") } } } @@ -144,6 +138,13 @@ func TestMultiContainerSanity(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep) @@ -155,13 +156,13 @@ func TestMultiContainerSanity(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) } expectedPL = []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -175,6 +176,13 @@ func TestMultiPIDNS(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} testSpecs, ids := createSpecs(sleep, sleep) @@ -194,13 +202,13 @@ func TestMultiPIDNS(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) } expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -213,6 +221,13 @@ func TestMultiPIDNSPath(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} testSpecs, ids := createSpecs(sleep, sleep, sleep) @@ -249,7 +264,7 @@ func TestMultiPIDNSPath(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -259,7 +274,7 @@ func TestMultiPIDNSPath(t *testing.T) { } expectedPL = []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -268,13 +283,21 @@ func TestMultiPIDNSPath(t *testing.T) { } func TestMultiContainerWait(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // The first container should run the entire duration of the test. cmd1 := []string{"sleep", "100"} // We'll wait on the second container, which is much shorter lived. cmd2 := []string{"sleep", "1"} specs, ids := createSpecs(cmd1, cmd2) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -283,7 +306,7 @@ func TestMultiContainerWait(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -328,7 +351,7 @@ func TestMultiContainerWait(t *testing.T) { // After Wait returns, ensure that the root container is running and // the child has finished. expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err) @@ -344,12 +367,14 @@ func TestExecWait(t *testing.T) { } defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + // The first container should run the entire duration of the test. cmd1 := []string{"sleep", "100"} // We'll wait on the second container, which is much shorter lived. cmd2 := []string{"sleep", "1"} specs, ids := createSpecs(cmd1, cmd2) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -358,7 +383,7 @@ func TestExecWait(t *testing.T) { // Check via ps that process is running. expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Fatalf("failed to wait for sleep to start: %v", err) @@ -393,7 +418,7 @@ func TestExecWait(t *testing.T) { // Wait for the exec'd process to exit. expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Fatalf("failed to wait for second container to stop: %v", err) @@ -432,7 +457,15 @@ func TestMultiContainerMount(t *testing.T) { }) // Setup the containers. + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + containers, cleanup, err := startContainers(conf, sps, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -454,6 +487,13 @@ func TestMultiContainerSignal(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep) @@ -465,7 +505,7 @@ func TestMultiContainerSignal(t *testing.T) { // Check via ps that container 1 process is running. expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { @@ -479,7 +519,7 @@ func TestMultiContainerSignal(t *testing.T) { // Make sure process 1 is still running. expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -548,6 +588,13 @@ func TestMultiContainerDestroy(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // First container will remain intact while the second container is killed. podSpecs, ids := createSpecs( []string{"sleep", "100"}, @@ -586,9 +633,10 @@ func TestMultiContainerDestroy(t *testing.T) { if err != nil { t.Fatalf("error getting process data from sandbox: %v", err) } - expectedPL := []*control.Process{{PID: 1, Cmd: "sleep"}} - if !procListsEqual(pss, expectedPL) { - t.Errorf("container got process list: %s, want: %s", procListToString(pss), procListToString(expectedPL)) + expectedPL := []*control.Process{{PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}} + if r, err := procListsEqual(pss, expectedPL); !r { + t.Errorf("container got process list: %s, want: %s: error: %v", + procListToString(pss), procListToString(expectedPL), err) } // Check that cont.Destroy is safe to call multiple times. @@ -599,13 +647,21 @@ func TestMultiContainerDestroy(t *testing.T) { } func TestMultiContainerProcesses(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Note: use curly braces to keep 'sh' process around. Otherwise, shell // will just execve into 'sleep' and both containers will look the // same. specs, ids := createSpecs( []string{"sleep", "100"}, []string{"sh", "-c", "{ sleep 100; }"}) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -614,7 +670,7 @@ func TestMultiContainerProcesses(t *testing.T) { // Check root's container process list doesn't include other containers. expectedPL0 := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL0); err != nil { t.Errorf("failed to wait for process to start: %v", err) @@ -622,8 +678,8 @@ func TestMultiContainerProcesses(t *testing.T) { // Same for the other container. expectedPL1 := []*control.Process{ - {PID: 2, Cmd: "sh"}, - {PID: 3, PPID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sh", Threads: []kernel.ThreadID{2}}, + {PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}}, } if err := waitForProcessList(containers[1], expectedPL1); err != nil { t.Errorf("failed to wait for process to start: %v", err) @@ -637,7 +693,7 @@ func TestMultiContainerProcesses(t *testing.T) { if _, err := containers[1].Execute(args); err != nil { t.Fatalf("error exec'ing: %v", err) } - expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep"}) + expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep", Threads: []kernel.ThreadID{4}}) if err := waitForProcessList(containers[1], expectedPL1); err != nil { t.Errorf("failed to wait for process to start: %v", err) } @@ -650,6 +706,15 @@ func TestMultiContainerProcesses(t *testing.T) { // TestMultiContainerKillAll checks that all process that belong to a container // are killed when SIGKILL is sent to *all* processes in that container. func TestMultiContainerKillAll(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + for _, tc := range []struct { killContainer bool }{ @@ -665,7 +730,6 @@ func TestMultiContainerKillAll(t *testing.T) { specs, ids := createSpecs( []string{app, "task-tree", "--depth=2", "--width=2"}, []string{app, "task-tree", "--depth=4", "--width=2"}) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -739,19 +803,13 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) { specs, ids := createSpecs( []string{"/bin/sleep", "100"}, []string{"/bin/sleep", "100"}) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfigWithRoot(rootDir) - // Create and start root container. - rootBundleDir, err := testutil.SetupBundleDir(specs[0]) + conf := testutil.TestConfig() + rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(rootBundleDir) rootArgs := Args{ @@ -800,19 +858,12 @@ func TestMultiContainerDestroyStarting(t *testing.T) { } specs, ids := createSpecs(cmds...) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfigWithRoot(rootDir) - - // Create and start root container. - rootBundleDir, err := testutil.SetupBundleDir(specs[0]) + conf := testutil.TestConfig() + rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(rootBundleDir) rootArgs := Args{ @@ -886,9 +937,17 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) { script := fmt.Sprintf("if [ -f %q ]; then exit 1; else touch %q; fi", filename, filename) cmd := []string{"sh", "-c", script} + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Make sure overlay is enabled, and none of the root filesystems are // read-only, otherwise we won't be able to create the file. - conf := testutil.TestConfig() conf.Overlay = true specs, ids := createSpecs(cmdRoot, cmd, cmd) for _, s := range specs { @@ -941,26 +1000,21 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { } allSpecs, allIDs := createSpecs(cmds...) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - // Split up the specs and IDs. rootSpec := allSpecs[0] rootID := allIDs[0] childrenSpecs := allSpecs[1:] childrenIDs := allIDs[1:] - bundleDir, err := testutil.SetupBundleDir(rootSpec) + conf := testutil.TestConfig() + rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf) if err != nil { - t.Fatalf("error setting up bundle dir: %v", err) + t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(bundleDir) // Start root container. - conf := testutil.TestConfigWithRoot(rootDir) rootArgs := Args{ ID: rootID, Spec: rootSpec, @@ -1029,6 +1083,13 @@ func TestMultiContainerSharedMount(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1137,6 +1198,13 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1197,6 +1265,13 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1300,8 +1375,14 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { // Test that unsupported pod mounts options are ignored when matching master and // slave mounts. func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() - t.Logf("Running test with conf: %+v", conf) + conf.RootDir = rootDir // Setup the containers. sleep := []string{"/bin/sleep", "100"} @@ -1376,6 +1457,15 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { Type: "tmpfs", } + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Create the specs. specs, ids := createSpecs( []string{"sleep", "1000"}, @@ -1386,7 +1476,6 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { specs[1].Mounts = append(specs[2].Mounts, sharedMnt, writeableMnt) specs[2].Mounts = append(specs[1].Mounts, sharedMnt) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1405,9 +1494,17 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { // Test that container is destroyed when Gofer is killed. func TestMultiContainerGoferKilled(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep, sleep) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1417,7 +1514,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Ensure container is running c := containers[2] expectedPL := []*control.Process{ - {PID: 3, Cmd: "sleep"}, + {PID: 3, Cmd: "sleep", Threads: []kernel.ThreadID{3}}, } if err := waitForProcessList(c, expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -1445,7 +1542,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { continue // container[2] has been killed. } pl := []*control.Process{ - {PID: kernel.ThreadID(i + 1), Cmd: "sleep"}, + {PID: kernel.ThreadID(i + 1), Cmd: "sleep", Threads: []kernel.ThreadID{kernel.ThreadID(i + 1)}}, } if err := waitForProcessList(c, pl); err != nil { t.Errorf("Container %q was affected by another container: %v", c.ID, err) @@ -1465,7 +1562,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Wait until sandbox stops. waitForProcessList will loop until sandbox exits // and RPC errors out. impossiblePL := []*control.Process{ - {PID: 100, Cmd: "non-existent-process"}, + {PID: 100, Cmd: "non-existent-process", Threads: []kernel.ThreadID{100}}, } if err := waitForProcessList(c, impossiblePL); err == nil { t.Fatalf("Sandbox was not killed after gofer death") @@ -1483,7 +1580,15 @@ func TestMultiContainerGoferKilled(t *testing.T) { func TestMultiContainerLoadSandbox(t *testing.T) { sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep, sleep) + + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir // Create containers for the sandbox. wants, cleanup, err := startContainers(conf, specs, ids) @@ -1576,7 +1681,15 @@ func TestMultiContainerRunNonRoot(t *testing.T) { Type: "bind", }) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + pod, cleanup, err := startContainers(conf, podSpecs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go new file mode 100644 index 000000000..d95151ea5 --- /dev/null +++ b/runsc/container/state_file.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package container + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sync" + + "github.com/gofrs/flock" + "gvisor.dev/gvisor/pkg/log" +) + +const stateFileExtension = ".state" + +// StateFile handles load from/save to container state safely from multiple +// processes. It uses a lock file to provide synchronization between operations. +// +// The lock file is located at: "${s.RootDir}/${s.ID}.lock". +// The state file is located at: "${s.RootDir}/${s.ID}.state". +type StateFile struct { + // RootDir is the directory containing the container metadata file. + RootDir string `json:"rootDir"` + + // ID is the container ID. + ID string `json:"id"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + + once sync.Once + flock *flock.Flock +} + +// List returns all container ids in the given root directory. +func List(rootDir string) ([]string, error) { + log.Debugf("List containers %q", rootDir) + list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension)) + if err != nil { + return nil, err + } + var out []string + for _, path := range list { + // Filter out files that do no belong to a container. + fileName := filepath.Base(path) + if len(fileName) < len(stateFileExtension) { + panic(fmt.Sprintf("invalid file match %q", path)) + } + // Remove the extension. + cid := fileName[:len(fileName)-len(stateFileExtension)] + if validateID(cid) == nil { + out = append(out, cid) + } + } + return out, nil +} + +// lock globally locks all locking operations for the container. +func (s *StateFile) lock() error { + s.once.Do(func() { + s.flock = flock.NewFlock(s.lockPath()) + }) + + if err := s.flock.Lock(); err != nil { + return fmt.Errorf("acquiring lock on %q: %v", s.flock, err) + } + return nil +} + +// lockForNew acquires the lock and checks if the state file doesn't exist. This +// is done to ensure that more than one creation didn't race to create +// containers with the same ID. +func (s *StateFile) lockForNew() error { + if err := s.lock(); err != nil { + return err + } + + // Checks if the container already exists by looking for the metadata file. + if _, err := os.Stat(s.statePath()); err == nil { + s.unlock() + return fmt.Errorf("container already exists") + } else if !os.IsNotExist(err) { + s.unlock() + return fmt.Errorf("looking for existing container: %v", err) + } + return nil +} + +// unlock globally unlocks all locking operations for the container. +func (s *StateFile) unlock() error { + if !s.flock.Locked() { + panic("unlock called without lock held") + } + + if err := s.flock.Unlock(); err != nil { + log.Warningf("Error to release lock on %q: %v", s.flock, err) + return fmt.Errorf("releasing lock on %q: %v", s.flock, err) + } + return nil +} + +// saveLocked saves 'v' to the state file. +// +// Preconditions: lock() must been called before. +func (s *StateFile) saveLocked(v interface{}) error { + if !s.flock.Locked() { + panic("saveLocked called without lock held") + } + + meta, err := json.Marshal(v) + if err != nil { + return err + } + if err := ioutil.WriteFile(s.statePath(), meta, 0640); err != nil { + return fmt.Errorf("writing json file: %v", err) + } + return nil +} + +func (s *StateFile) load(v interface{}) error { + if err := s.lock(); err != nil { + return err + } + defer s.unlock() + + metaBytes, err := ioutil.ReadFile(s.statePath()) + if err != nil { + return err + } + return json.Unmarshal(metaBytes, &v) +} + +func (s *StateFile) close() error { + if s.flock == nil { + return nil + } + if s.flock.Locked() { + panic("Closing locked file") + } + return s.flock.Close() +} + +func buildStatePath(rootDir, id string) string { + return filepath.Join(rootDir, id+stateFileExtension) +} + +// statePath is the full path to the state file. +func (s *StateFile) statePath() string { + return buildStatePath(s.RootDir, s.ID) +} + +// lockPath is the full path to the lock file. +func (s *StateFile) lockPath() string { + return filepath.Join(s.RootDir, s.ID+".lock") +} + +// destroy deletes all state created by the stateFile. It may be called with the +// lock file held. In that case, the lock file must still be unlocked and +// properly closed after destroy returns. +func (s *StateFile) destroy() error { + if err := os.Remove(s.statePath()); err != nil && !os.IsNotExist(err) { + return err + } + if err := os.Remove(s.lockPath()); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD index 9bf9e6e9d..bfd338bb6 100644 --- a/runsc/container/test_app/BUILD +++ b/runsc/container/test_app/BUILD @@ -15,5 +15,6 @@ go_binary( "//pkg/unet", "//runsc/testutil", "@com_github_google_subcommands//:go_default_library", + "@com_github_kr_pty//:go_default_library", ], ) diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go index 913d781c6..a1c8a741a 100644 --- a/runsc/container/test_app/test_app.go +++ b/runsc/container/test_app/test_app.go @@ -19,6 +19,7 @@ package main import ( "context" "fmt" + "io" "io/ioutil" "log" "net" @@ -31,6 +32,7 @@ import ( "flag" "github.com/google/subcommands" + "github.com/kr/pty" "gvisor.dev/gvisor/runsc/testutil" ) @@ -41,6 +43,7 @@ func main() { subcommands.Register(new(fdReceiver), "") subcommands.Register(new(fdSender), "") subcommands.Register(new(forkBomb), "") + subcommands.Register(new(ptyRunner), "") subcommands.Register(new(reaper), "") subcommands.Register(new(syscall), "") subcommands.Register(new(taskTree), "") @@ -352,3 +355,40 @@ func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...inter return subcommands.ExitSuccess } + +type ptyRunner struct{} + +// Name implements subcommands.Command. +func (*ptyRunner) Name() string { + return "pty-runner" +} + +// Synopsis implements subcommands.Command. +func (*ptyRunner) Synopsis() string { + return "runs the given command with an open pty terminal" +} + +// Usage implements subcommands.Command. +func (*ptyRunner) Usage() string { + return "pty-runner [command]" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (*ptyRunner) SetFlags(f *flag.FlagSet) {} + +// Execute implements subcommands.Command. +func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + c := exec.Command(fs.Args()[0], fs.Args()[1:]...) + f, err := pty.Start(c) + if err != nil { + fmt.Printf("pty.Start failed: %v", err) + return subcommands.ExitFailure + } + defer f.Close() + + // Copy stdout from the command to keep this process alive until the + // subprocess exits. + io.Copy(os.Stdout, f) + + return subcommands.ExitSuccess +} diff --git a/runsc/debian/description b/runsc/debian/description index 6e3b1b2c0..9e8e08805 100644 --- a/runsc/debian/description +++ b/runsc/debian/description @@ -1,5 +1 @@ -gVisor is a user-space kernel, written in Go, that implements a substantial -portion of the Linux system surface. It includes an Open Container Initiative -(OCI) runtime called runsc that provides an isolation boundary between the -application and the host kernel. The runsc runtime integrates with Docker and -Kubernetes, making it simple to run sandboxed containers. +gVisor container sandbox runtime diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go index 57f6ae8de..9b6346ca2 100644 --- a/runsc/dockerutil/dockerutil.go +++ b/runsc/dockerutil/dockerutil.go @@ -380,6 +380,16 @@ func (d *Docker) FindPort(sandboxPort int) (int, error) { return port, nil } +// FindIP returns the IP address of the container as a string. +func (d *Docker) FindIP() (string, error) { + const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}` + out, err := do("inspect", "-f", format, d.Name) + if err != nil { + return "", fmt.Errorf("error retrieving IP: %v", err) + } + return strings.TrimSpace(out), nil +} + // SandboxPid returns the PID to the sandbox process. func (d *Docker) SandboxPid() (int, error) { out, err := do("inspect", "-f={{.State.Pid}}", d.Name) diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 80a4aa2fe..afcb41801 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -6,6 +6,8 @@ go_library( name = "fsgofer", srcs = [ "fsgofer.go", + "fsgofer_amd64_unsafe.go", + "fsgofer_arm64_unsafe.go", "fsgofer_unsafe.go", ], importpath = "gvisor.dev/gvisor/runsc/fsgofer", diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD index 02168ad1b..bac73f89d 100644 --- a/runsc/fsgofer/filter/BUILD +++ b/runsc/fsgofer/filter/BUILD @@ -6,6 +6,8 @@ go_library( name = "filter", srcs = [ "config.go", + "config_amd64.go", + "config_arm64.go", "extra_filters.go", "extra_filters_msan.go", "extra_filters_race.go", diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 0bf7507b7..a1792330f 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -25,11 +25,7 @@ import ( // allowedSyscalls is the set of syscalls executed by the gofer. var allowedSyscalls = seccomp.SyscallRules{ - syscall.SYS_ACCEPT: {}, - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_GET_FS)}, - {seccomp.AllowValue(linux.ARCH_SET_FS)}, - }, + syscall.SYS_ACCEPT: {}, syscall.SYS_CLOCK_GETTIME: {}, syscall.SYS_CLONE: []seccomp.Rule{ { @@ -155,7 +151,6 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_MPROTECT: {}, syscall.SYS_MUNMAP: {}, syscall.SYS_NANOSLEEP: {}, - syscall.SYS_NEWFSTATAT: {}, syscall.SYS_OPENAT: {}, syscall.SYS_PPOLL: {}, syscall.SYS_PREAD64: {}, @@ -220,6 +215,18 @@ var udsSyscalls = seccomp.SyscallRules{ syscall.SYS_SOCKET: []seccomp.Rule{ { seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_STREAM), + seccomp.AllowValue(0), + }, + { + seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_DGRAM), + seccomp.AllowValue(0), + }, + { + seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_SEQPACKET), + seccomp.AllowValue(0), }, }, syscall.SYS_CONNECT: []seccomp.Rule{ diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go new file mode 100644 index 000000000..a4b28cb8b --- /dev/null +++ b/runsc/fsgofer/filter/config_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" +) + +func init() { + allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ + {seccomp.AllowValue(linux.ARCH_GET_FS)}, + {seccomp.AllowValue(linux.ARCH_SET_FS)}, + } + + allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} +} diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go new file mode 100644 index 000000000..d2697deb7 --- /dev/null +++ b/runsc/fsgofer/filter/config_arm64.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +func init() { + allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{} +} diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index ed8b02cf0..b59e1a70e 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -199,6 +199,7 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // The reason that the file is not opened initially as read-write is for better // performance with 'overlay2' storage driver. overlay2 eagerly copies the // entire file up when it's opened in write mode, and would perform badly when +// multiple files are only being opened for read (esp. startup). type localFile struct { p9.DefaultWalkGetAttr @@ -265,10 +266,10 @@ func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, erro // actual file open and is customizable by the caller. func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) { // Attempt to open file in the following mode in order: - // 1. RDONLY | NONBLOCK: for all files, works for directories and ro mounts too. - // Use non-blocking to prevent getting stuck inside open(2) for FIFOs. This option - // has no effect on regular files. - // 2. PATH: for symlinks + // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs. + // Use non-blocking to prevent getting stuck inside open(2) for + // FIFOs. This option has no effect on regular files. + // 2. PATH: for symlinks, sockets. modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH} var err error @@ -366,23 +367,24 @@ func fchown(fd int, uid p9.UID, gid p9.GID) error { } // Open implements p9.File. -func (l *localFile) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { +func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { if l.isOpen() { panic(fmt.Sprintf("attempting to open already opened file: %q", l.hostPath)) } // Check if control file can be used or if a new open must be created. var newFile *fd.FD - if mode == p9.ReadOnly { - log.Debugf("Open reusing control file, mode: %v, %q", mode, l.hostPath) + if flags == p9.ReadOnly { + log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath) newFile = l.file } else { // Ideally reopen would call name_to_handle_at (with empty name) and // open_by_handle_at to reopen the file without using 'hostPath'. However, // name_to_handle_at and open_by_handle_at aren't supported by overlay2. - log.Debugf("Open reopening file, mode: %v, %q", mode, l.hostPath) + log.Debugf("Open reopening file, flags: %v, %q", flags, l.hostPath) var err error - newFile, err = reopenProcFd(l.file, openFlags|mode.OSFlags()) + // Constrain open flags to the open mode and O_TRUNC. + newFile, err = reopenProcFd(l.file, openFlags|(flags.OSFlags()&(syscall.O_ACCMODE|syscall.O_TRUNC))) if err != nil { return nil, p9.QID{}, 0, extractErrno(err) } @@ -409,7 +411,7 @@ func (l *localFile) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { } l.file = newFile } - l.mode = mode + l.mode = flags & p9.OpenFlagsModeMask return fd, l.attachPoint.makeQID(stat), 0, nil } @@ -601,7 +603,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), GID: p9.GID(stat.Gid), - NLink: stat.Nlink, + NLink: uint64(stat.Nlink), RDev: stat.Rdev, Size: uint64(stat.Size), BlockSize: uint64(stat.Blksize), @@ -955,14 +957,14 @@ func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { } func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) ([]p9.Dirent, error) { + var dirents []p9.Dirent + // Limit 'count' to cap the slice size that is returned. const maxCount = 100000 if count > maxCount { count = maxCount } - dirents := make([]p9.Dirent, 0, count) - // Pre-allocate buffers that will be reused to get partial results. direntsBuf := make([]byte, 8192) names := make([]string, 0, 100) @@ -1032,12 +1034,48 @@ func (l *localFile) Flush() error { } // Connect implements p9.File. -func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) { - // Check to see if the CLI option has been set to allow the UDS mount. +func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { if !l.attachPoint.conf.HostUDS { return nil, syscall.ECONNREFUSED } - return fd.DialUnix(l.hostPath) + + // TODO(gvisor.dev/issue/1003): Due to different app vs replacement + // mappings, the app path may have fit in the sockaddr, but we can't + // fit f.path in our sockaddr. We'd need to redirect through a shorter + // path in order to actually connect to this socket. + if len(l.hostPath) > linux.UnixPathMax { + return nil, syscall.ECONNREFUSED + } + + var stype int + switch flags { + case p9.StreamSocket: + stype = syscall.SOCK_STREAM + case p9.DgramSocket: + stype = syscall.SOCK_DGRAM + case p9.SeqpacketSocket: + stype = syscall.SOCK_SEQPACKET + default: + return nil, syscall.ENXIO + } + + f, err := syscall.Socket(syscall.AF_UNIX, stype, 0) + if err != nil { + return nil, err + } + + if err := syscall.SetNonblock(f, true); err != nil { + syscall.Close(f) + return nil, err + } + + sa := syscall.SockaddrUnix{Name: l.hostPath} + if err := syscall.Connect(f, &sa); err != nil { + syscall.Close(f) + return nil, err + } + + return fd.New(f), nil } // Close implements p9.File. diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go new file mode 100644 index 000000000..5d4aab597 --- /dev/null +++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go @@ -0,0 +1,49 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package fsgofer + +import ( + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/syserr" +) + +func statAt(dirFd int, name string) (syscall.Stat_t, error) { + nameBytes, err := syscall.BytePtrFromString(name) + if err != nil { + return syscall.Stat_t{}, err + } + namePtr := unsafe.Pointer(nameBytes) + + var stat syscall.Stat_t + statPtr := unsafe.Pointer(&stat) + + if _, _, errno := syscall.Syscall6( + syscall.SYS_NEWFSTATAT, + uintptr(dirFd), + uintptr(namePtr), + uintptr(statPtr), + linux.AT_SYMLINK_NOFOLLOW, + 0, + 0); errno != 0 { + + return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + } + return stat, nil +} diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go new file mode 100644 index 000000000..8041fd352 --- /dev/null +++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go @@ -0,0 +1,49 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package fsgofer + +import ( + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/syserr" +) + +func statAt(dirFd int, name string) (syscall.Stat_t, error) { + nameBytes, err := syscall.BytePtrFromString(name) + if err != nil { + return syscall.Stat_t{}, err + } + namePtr := unsafe.Pointer(nameBytes) + + var stat syscall.Stat_t + statPtr := unsafe.Pointer(&stat) + + if _, _, errno := syscall.Syscall6( + syscall.SYS_FSTATAT, + uintptr(dirFd), + uintptr(namePtr), + uintptr(statPtr), + linux.AT_SYMLINK_NOFOLLOW, + 0, + 0); errno != 0 { + + return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + } + return stat, nil +} diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go index ff2556aee..542b54365 100644 --- a/runsc/fsgofer/fsgofer_unsafe.go +++ b/runsc/fsgofer/fsgofer_unsafe.go @@ -18,34 +18,9 @@ import ( "syscall" "unsafe" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) -func statAt(dirFd int, name string) (syscall.Stat_t, error) { - nameBytes, err := syscall.BytePtrFromString(name) - if err != nil { - return syscall.Stat_t{}, err - } - namePtr := unsafe.Pointer(nameBytes) - - var stat syscall.Stat_t - statPtr := unsafe.Pointer(&stat) - - if _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, - uintptr(dirFd), - uintptr(namePtr), - uintptr(statPtr), - linux.AT_SYMLINK_NOFOLLOW, - 0, - 0); errno != 0 { - - return syscall.Stat_t{}, syserr.FromHost(errno).ToError() - } - return stat, nil -} - func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) error { // utimensat(2) doesn't accept empty name, instead name must be nil to make it // operate directly on 'dirFd' unlike other *at syscalls. diff --git a/runsc/main.go b/runsc/main.go index 80b2d300c..abf929511 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -26,6 +26,7 @@ import ( "path/filepath" "strings" "syscall" + "time" "flag" @@ -46,6 +47,8 @@ var ( logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") debug = flag.Bool("debug", false, "enable debug logging.") showVersion = flag.Bool("version", false, "show version and exit.") + // TODO(gvisor.dev/issue/193): support systemd cgroups + systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") // These flags are unique to runsc, and are used to configure parts of the // system that are not covered by the runtime spec. @@ -66,7 +69,8 @@ var ( // Flags that control sandbox runtime behavior. platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") - gso = flag.Bool("gso", true, "enable generic segmenation offload.") + hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") + softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware ofload can't be enabled.") fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") @@ -78,6 +82,7 @@ var ( numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.") + cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") // Test flags, not to be used outside tests, ever. testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") @@ -135,6 +140,12 @@ func main() { os.Exit(0) } + // TODO(gvisor.dev/issue/193): support systemd cgroups + if *systemdCgroup { + fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") + os.Exit(1) + } + var errorLogger io.Writer if *logFD > -1 { errorLogger = os.NewFile(uintptr(*logFD), "error log file") @@ -200,7 +211,8 @@ func main() { FSGoferHostUDS: *fsGoferHostUDS, Overlay: *overlay, Network: netType, - GSO: *gso, + HardwareGSO: *hardwareGSO, + SoftwareGSO: *softwareGSO, LogPackets: *logPackets, Platform: platformType, Strace: *strace, @@ -214,6 +226,7 @@ func main() { AlsoLogToStderr: *alsoLogToStderr, ReferenceLeakMode: refsLeakMode, OverlayfsStaleRead: *overlayfsStaleRead, + CPUNumFromQuota: *cpuNumFromQuota, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, TestOnlyTestNameEnv: *testOnlyTestNameEnv, @@ -227,6 +240,18 @@ func main() { log.SetLevel(log.Debug) } + // Logging will include the local date and time via the time package. + // + // On first use, time.Local initializes the local time zone, which + // involves opening tzdata files on the host. Since this requires + // opening host files, it must be done before syscall filter + // installation. + // + // Generally there will be a log message before filter installation + // that will force initialization, but force initialization here in + // case that does not occur. + _ = time.Local.String() + subcommand := flag.CommandLine.Arg(0) var e log.Emitter diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 7fdceaab6..8001949d5 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -19,6 +19,8 @@ go_library( "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", "//pkg/urpc", "//runsc/boot", "//runsc/boot/platforms", diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 5634f0707..be8b72b3e 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -28,6 +28,8 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" @@ -61,7 +63,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi // Build the path to the net namespace of the sandbox process. // This is what we will copy. nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net") - if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.GSO, conf.NumNetworkChannels); err != nil { + if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.NumNetworkChannels); err != nil { return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err) } case boot.NetworkHost: @@ -136,7 +138,7 @@ func isRootNS() (bool, error) { // createInterfacesAndRoutesFromNS scrapes the interface and routes from the // net namespace with the given path, creates them in the sandbox, and removes // them from the host. -func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO bool, numNetworkChannels int) error { +func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, numNetworkChannels int) error { // Join the network namespace that we will be copying. restore, err := joinNetNS(nsPath) if err != nil { @@ -182,36 +184,39 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO continue } - // Keep only IPv4 addresses. - var ip4addrs []*net.IPNet + var ipAddrs []*net.IPNet for _, ifaddr := range allAddrs { ipNet, ok := ifaddr.(*net.IPNet) if !ok { return fmt.Errorf("address is not IPNet: %+v", ifaddr) } - if ipNet.IP.To4() == nil { - log.Warningf("IPv6 is not supported, skipping: %v", ipNet) - continue - } - ip4addrs = append(ip4addrs, ipNet) + ipAddrs = append(ipAddrs, ipNet) } - if len(ip4addrs) == 0 { - log.Warningf("No IPv4 address found for interface %q, skipping", iface.Name) + if len(ipAddrs) == 0 { + log.Warningf("No usable IP addresses found for interface %q, skipping", iface.Name) continue } // Scrape the routes before removing the address, since that // will remove the routes as well. - routes, def, err := routesForIface(iface) + routes, defv4, defv6, err := routesForIface(iface) if err != nil { return fmt.Errorf("getting routes for interface %q: %v", iface.Name, err) } - if def != nil { - if !args.DefaultGateway.Route.Empty() { - return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, def, args.DefaultGateway) + if defv4 != nil { + if !args.Defaultv4Gateway.Route.Empty() { + return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, defv4, args.Defaultv4Gateway) } - args.DefaultGateway.Route = *def - args.DefaultGateway.Name = iface.Name + args.Defaultv4Gateway.Route = *defv4 + args.Defaultv4Gateway.Name = iface.Name + } + + if defv6 != nil { + if !args.Defaultv6Gateway.Route.Empty() { + return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, defv6, args.Defaultv6Gateway) + } + args.Defaultv6Gateway.Route = *defv6 + args.Defaultv6Gateway.Name = iface.Name } link := boot.FDBasedLink{ @@ -232,7 +237,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO // Create the socket for the device. for i := 0; i < link.NumChannels; i++ { log.Debugf("Creating Channel %d", i) - socketEntry, err := createSocket(iface, ifaceLink, enableGSO) + socketEntry, err := createSocket(iface, ifaceLink, hardwareGSO) if err != nil { return fmt.Errorf("failed to createSocket for %s : %v", iface.Name, err) } @@ -247,9 +252,15 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO args.FilePayload.Files = append(args.FilePayload.Files, socketEntry.deviceFile) } + if link.GSOMaxSize == 0 && softwareGSO { + // Hardware GSO is disabled. Let's enable software GSO. + link.GSOMaxSize = stack.SoftwareGSOMaxSize + link.SoftwareGSOEnabled = true + } + // Collect the addresses for the interface, enable forwarding, // and remove them from the host. - for _, addr := range ip4addrs { + for _, addr := range ipAddrs { link.Addresses = append(link.Addresses, addr.IP) // Steal IP address from NIC. @@ -345,46 +356,56 @@ func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, } // routesForIface iterates over all routes for the given interface and converts -// them to boot.Routes. -func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, error) { +// them to boot.Routes. It also returns the a default v4/v6 route if found. +func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, *boot.Route, error) { link, err := netlink.LinkByIndex(iface.Index) if err != nil { - return nil, nil, err + return nil, nil, nil, err } rs, err := netlink.RouteList(link, netlink.FAMILY_ALL) if err != nil { - return nil, nil, fmt.Errorf("getting routes from %q: %v", iface.Name, err) + return nil, nil, nil, fmt.Errorf("getting routes from %q: %v", iface.Name, err) } - var def *boot.Route + var defv4, defv6 *boot.Route var routes []boot.Route for _, r := range rs { // Is it a default route? if r.Dst == nil { if r.Gw == nil { - return nil, nil, fmt.Errorf("default route with no gateway %q: %+v", iface.Name, r) - } - if r.Gw.To4() == nil { - log.Warningf("IPv6 is not supported, skipping default route: %v", r) - continue - } - if def != nil { - return nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, def, r) + return nil, nil, nil, fmt.Errorf("default route with no gateway %q: %+v", iface.Name, r) } // Create a catch all route to the gateway. - def = &boot.Route{ - Destination: net.IPNet{ - IP: net.IPv4zero, - Mask: net.IPMask(net.IPv4zero), - }, - Gateway: r.Gw, + switch len(r.Gw) { + case header.IPv4AddressSize: + if defv4 != nil { + return nil, nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, defv4, r) + } + defv4 = &boot.Route{ + Destination: net.IPNet{ + IP: net.IPv4zero, + Mask: net.IPMask(net.IPv4zero), + }, + Gateway: r.Gw, + } + case header.IPv6AddressSize: + if defv6 != nil { + return nil, nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, defv6, r) + } + + defv6 = &boot.Route{ + Destination: net.IPNet{ + IP: net.IPv6zero, + Mask: net.IPMask(net.IPv6zero), + }, + Gateway: r.Gw, + } + default: + return nil, nil, nil, fmt.Errorf("unexpected address size for gateway: %+v for route: %+v", r.Gw, r) } continue } - if r.Dst.IP.To4() == nil { - log.Warningf("IPv6 is not supported, skipping route: %v", r) - continue - } + dst := *r.Dst dst.IP = dst.IP.Mask(dst.Mask) routes = append(routes, boot.Route{ @@ -392,7 +413,7 @@ func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, error) { Gateway: r.Gw, }) } - return routes, def, nil + return routes, defv4, defv6, nil } // removeAddress removes IP address from network device. It's equivalent to: diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ee9327fc8..ce1452b87 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -18,6 +18,7 @@ package sandbox import ( "context" "fmt" + "math" "os" "os/exec" "strconv" @@ -631,6 +632,26 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF if err != nil { return fmt.Errorf("getting cpu count from cgroups: %v", err) } + if conf.CPUNumFromQuota { + // Dropping below 2 CPUs can trigger application to disable + // locks that can lead do hard to debug errors, so just + // leaving two cores as reasonable default. + const minCPUs = 2 + + quota, err := s.Cgroup.CPUQuota() + if err != nil { + return fmt.Errorf("getting cpu qouta from cgroups: %v", err) + } + if n := int(math.Ceil(quota)); n > 0 { + if n < minCPUs { + n = minCPUs + } + if n < cpuNum { + // Only lower the cpu number. + cpuNum = n + } + } + } cmd.Args = append(cmd.Args, "--cpu-num", strconv.Itoa(cpuNum)) mem, err := s.Cgroup.MemoryLimit() @@ -1004,16 +1025,22 @@ func (s *Sandbox) ChangeLogging(args control.LoggingArgs) error { // DestroyContainer destroys the given container. If it is the root container, // then the entire sandbox is destroyed. func (s *Sandbox) DestroyContainer(cid string) error { + if err := s.destroyContainer(cid); err != nil { + // If the sandbox isn't running, the container has already been destroyed, + // ignore the error in this case. + if s.IsRunning() { + return err + } + } + return nil +} + +func (s *Sandbox) destroyContainer(cid string) error { if s.IsRootContainer(cid) { log.Debugf("Destroying root container %q by destroying sandbox", cid) return s.destroy() } - if !s.IsRunning() { - // Sandbox isn't running anymore, container is already destroyed. - return nil - } - log.Debugf("Destroying container %q in sandbox %q", cid, s.ID) conn, err := s.sandboxConnect() if err != nil { diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index fa58313a0..205638803 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "specutils", srcs = [ + "cri.go", "fs.go", "namespace.go", "specutils.go", diff --git a/runsc/specutils/cri.go b/runsc/specutils/cri.go new file mode 100644 index 000000000..9c5877cd5 --- /dev/null +++ b/runsc/specutils/cri.go @@ -0,0 +1,110 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package specutils + +import ( + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +const ( + // ContainerdContainerTypeAnnotation is the OCI annotation set by + // containerd to indicate whether the container to create should have + // its own sandbox or a container within an existing sandbox. + ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type" + // ContainerdContainerTypeContainer is the container type value + // indicating the container should be created in an existing sandbox. + ContainerdContainerTypeContainer = "container" + // ContainerdContainerTypeSandbox is the container type value + // indicating the container should be created in a new sandbox. + ContainerdContainerTypeSandbox = "sandbox" + + // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate + // which sandbox the container should be created in when the container + // is not the first container in the sandbox. + ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id" + + // CRIOContainerTypeAnnotation is the OCI annotation set by + // CRI-O to indicate whether the container to create should have + // its own sandbox or a container within an existing sandbox. + CRIOContainerTypeAnnotation = "io.kubernetes.cri-o.ContainerType" + + // CRIOContainerTypeContainer is the container type value + // indicating the container should be created in an existing sandbox. + CRIOContainerTypeContainer = "container" + // CRIOContainerTypeSandbox is the container type value + // indicating the container should be created in a new sandbox. + CRIOContainerTypeSandbox = "sandbox" + + // CRIOSandboxIDAnnotation is the OCI annotation set to indicate + // which sandbox the container should be created in when the container + // is not the first container in the sandbox. + CRIOSandboxIDAnnotation = "io.kubernetes.cri-o.SandboxID" +) + +// ContainerType represents the type of container requested by the calling container manager. +type ContainerType int + +const ( + // ContainerTypeUnspecified indicates that no known container type + // annotation was found in the spec. + ContainerTypeUnspecified ContainerType = iota + // ContainerTypeUnknown indicates that a container type was specified + // but is unknown to us. + ContainerTypeUnknown + // ContainerTypeSandbox indicates that the container should be run in a + // new sandbox. + ContainerTypeSandbox + // ContainerTypeContainer indicates that the container should be run in + // an existing sandbox. + ContainerTypeContainer +) + +// SpecContainerType tries to determine the type of container specified by the +// container manager using well-known container annotations. +func SpecContainerType(spec *specs.Spec) ContainerType { + if t, ok := spec.Annotations[ContainerdContainerTypeAnnotation]; ok { + switch t { + case ContainerdContainerTypeSandbox: + return ContainerTypeSandbox + case ContainerdContainerTypeContainer: + return ContainerTypeContainer + default: + return ContainerTypeUnknown + } + } + if t, ok := spec.Annotations[CRIOContainerTypeAnnotation]; ok { + switch t { + case CRIOContainerTypeSandbox: + return ContainerTypeSandbox + case CRIOContainerTypeContainer: + return ContainerTypeContainer + default: + return ContainerTypeUnknown + } + } + return ContainerTypeUnspecified +} + +// SandboxID returns the ID of the sandbox to join and whether an ID was found +// in the spec. +func SandboxID(spec *specs.Spec) (string, bool) { + if id, ok := spec.Annotations[ContainerdSandboxIDAnnotation]; ok { + return id, true + } + if id, ok := spec.Annotations[CRIOSandboxIDAnnotation]; ok { + return id, true + } + return "", false +} diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 3d9ced1b6..d3c2e4e78 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -108,23 +108,18 @@ func ValidateSpec(spec *specs.Spec) error { } } - // Two annotations are use by containerd to support multi-container pods. - // "io.kubernetes.cri.container-type" - // "io.kubernetes.cri.sandbox-id" - containerType, hasContainerType := spec.Annotations[ContainerdContainerTypeAnnotation] - _, hasSandboxID := spec.Annotations[ContainerdSandboxIDAnnotation] - switch { - // Non-containerd use won't set a container type. - case !hasContainerType: - case containerType == ContainerdContainerTypeSandbox: - // When starting a container in an existing sandbox, the sandbox ID - // must be set. - case containerType == ContainerdContainerTypeContainer: - if !hasSandboxID { - return fmt.Errorf("spec has container-type of %s, but no sandbox ID set", containerType) + // CRI specifies whether a container should start a new sandbox, or run + // another container in an existing sandbox. + switch SpecContainerType(spec) { + case ContainerTypeContainer: + // When starting a container in an existing sandbox, the + // sandbox ID must be set. + if _, ok := SandboxID(spec); !ok { + return fmt.Errorf("spec has container-type of container, but no sandbox ID set") } + case ContainerTypeUnknown: + return fmt.Errorf("unknown container-type") default: - return fmt.Errorf("unknown container-type: %s", containerType) } return nil @@ -338,39 +333,6 @@ func IsSupportedDevMount(m specs.Mount) bool { return true } -const ( - // ContainerdContainerTypeAnnotation is the OCI annotation set by - // containerd to indicate whether the container to create should have - // its own sandbox or a container within an existing sandbox. - ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type" - // ContainerdContainerTypeContainer is the container type value - // indicating the container should be created in an existing sandbox. - ContainerdContainerTypeContainer = "container" - // ContainerdContainerTypeSandbox is the container type value - // indicating the container should be created in a new sandbox. - ContainerdContainerTypeSandbox = "sandbox" - - // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate - // which sandbox the container should be created in when the container - // is not the first container in the sandbox. - ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id" -) - -// ShouldCreateSandbox returns true if the spec indicates that a new sandbox -// should be created for the container. If false, the container should be -// started in an existing sandbox. -func ShouldCreateSandbox(spec *specs.Spec) bool { - t, ok := spec.Annotations[ContainerdContainerTypeAnnotation] - return !ok || t == ContainerdContainerTypeSandbox -} - -// SandboxID returns the ID of the sandbox to join and whether an ID was found -// in the spec. -func SandboxID(spec *specs.Spec) (string, bool) { - id, ok := spec.Annotations[ContainerdSandboxIDAnnotation] - return id, ok -} - // WaitForReady waits for a process to become ready. The process is ready when // the 'ready' function returns true. It continues to wait if 'ready' returns // false. It returns error on timeout, if the process stops or if 'ready' fails. diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index d44ebc906..c96ca2eb6 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/runsc/testutil", visibility = ["//:sandbox"], deps = [ + "//pkg/log", "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index edf8b126c..9632776d2 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -25,7 +25,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "math" "math/rand" "net/http" @@ -42,6 +41,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" ) @@ -151,13 +151,6 @@ func TestConfig() *boot.Config { } } -// TestConfigWithRoot returns the default configuration to use in tests. -func TestConfigWithRoot(rootDir string) *boot.Config { - conf := TestConfig() - conf.RootDir = rootDir - return conf -} - // NewSpecWithArgs creates a simple spec with the given args suitable for use // in tests. func NewSpecWithArgs(args ...string) *specs.Spec { @@ -286,7 +279,7 @@ func WaitForHTTP(port int, timeout time.Duration) error { url := fmt.Sprintf("http://localhost:%d/", port) resp, err := c.Get(url) if err != nil { - log.Printf("Waiting %s: %v", url, err) + log.Infof("Waiting %s: %v", url, err) return err } resp.Body.Close() diff --git a/scripts/benchmarks.sh b/scripts/benchmarks.sh new file mode 100755 index 000000000..6b9065b07 --- /dev/null +++ b/scripts/benchmarks.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/usr/bin/env bash + +if [ "$#" -lt "1" ]; then + echo "usage: $0 <--mock |--env=<filename>> ..." + echo "example: $0 --mock --runs=8" + exit 1 +fi + +source $(dirname $0)/common.sh + +readonly TIMESTAMP=`date "+%Y%m%d-%H%M%S"` +readonly OUTDIR="$(mktemp --tmpdir -d run-${TIMESTAMP}-XXX)" +readonly DEFAULT_RUNTIMES="--runtime=runc --runtime=runsc --runtime=runsc-kvm" +readonly ALL_RUNTIMES="--runtime=runc --runtime=runsc --runtime=runsc-kvm" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'fio.(read|write)' --metric=bandwidth --size=5g --ioengine=sync --blocksize=1m > "${OUTDIR}/fio.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} fio.rand --metric=bandwidth --size=5g --ioengine=sync --blocksize=4k --time=30 > "${OUTDIR}/tmp_fio.csv" +cat "${OUTDIR}/tmp_fio.csv" | grep "\(runc\|runsc\)" >> "${OUTDIR}/fio.csv" && rm "${OUTDIR}/tmp_fio.csv" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'fio.(read|write)' --metric=bandwidth --tmpfs=True --size=5g --ioengine=sync --blocksize=1m > "${OUTDIR}/fio-tmpfs.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} fio.rand --metric=bandwidth --tmpfs=True --size=5g --ioengine=sync --blocksize=4k --time=30 > "${OUTDIR}/tmp_fio.csv" +cat "${OUTDIR}/tmp_fio.csv" | grep "\(runc\|runsc\)" >> "${OUTDIR}/fio-tmpfs.csv" && rm "${OUTDIR}/tmp_fio.csv" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} startup --count=50 > "${OUTDIR}/startup.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} density > "${OUTDIR}/density.csv" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} sysbench.cpu --threads=1 --max_prime=50000 --options='--max-time=5' > "${OUTDIR}/sysbench-cpu.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} sysbench.memory --threads=1 --options='--memory-block-size=1M --memory-total-size=500G' > "${OUTDIR}/sysbench-memory.csv" +run //benchmarks:perf -- run "$@" ${ALL_RUNTIMES} syscall > "${OUTDIR}/syscall.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'network.(upload|download)' --runs=20 > "${OUTDIR}/iperf.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} ml.tensorflow > "${OUTDIR}/tensorflow.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} media.ffmpeg > "${OUTDIR}/ffmpeg.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} http.httpd --path=latin100k.txt --connections=1 --connections=5 --connections=10 --connections=25 > "${OUTDIR}/httpd100k.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} http.httpd --path=latin10240k.txt --connections=1 --connections=5 --connections=10 --connections=25 > "${OUTDIR}/httpd10240k.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} redis > "${OUTDIR}/redis.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'http.(ruby|node)' > "${OUTDIR}/applications.csv" + +echo "${OUTPUT}" && exit 0 diff --git a/scripts/build.sh b/scripts/build.sh index 0b3d1b316..8b2094cb0 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -17,63 +17,71 @@ source $(dirname $0)/common.sh # Install required packages for make_repository.sh et al. -sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils +sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils xz-utils # Build runsc. runsc=$(build -c opt //runsc) # Build packages. -pkg=$(build -c opt //runsc:runsc-debian) +pkgs=$(build -c opt //runsc:runsc-debian) + +# Stop here if we have no artifacts directory. +[[ -v KOKORO_ARTIFACTS_DIR ]] || exit 0 + +# install_raw installs raw artifacts. +install_raw() { + mkdir -p "$1" + cp -f "${runsc}" "$1"/runsc + sha512sum "$1"/runsc | awk '{print $1 " runsc"}' > "$1"/runsc.sha512 +} # Build a repository, if the key is available. +# +# Note that make_repository.sh script will install packages into the provided +# root, but will output to stdout a directory that can be copied arbitrarily +# into "${KOKORO_ARTIFACTS_DIR}"/dists/XXX. We do things this way because we +# will copy the same repository structure into multiple locations, below. if [[ -v KOKORO_REPO_KEY ]]; then - repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg}) + repo=$(tools/make_repository.sh \ + "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" \ + gvisor-bot@google.com \ + main \ + "${KOKORO_ARTIFACTS_DIR}" \ + ${pkgs}) fi -# Install installs artifacts. -install() { - local -r binaries_dir="$1" - local -r repo_dir="$2" - mkdir -p "${binaries_dir}" - cp -f "${runsc}" "${binaries_dir}"/runsc - sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512 +# install_repo installs a repository. +# +# Note that packages are already installed, as noted above. +install_repo() { if [[ -v repo ]]; then - rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")" - cp -a "${repo}" "${repo_dir}" + rm -rf "$1" && mkdir -p "$(dirname "$1")" && cp -a "${repo}" "$1" fi } -# Move the runsc binary into "latest" directory, and also a directory with the -# current date. If the current commit happens to correpond to a tag, then we -# will also move everything into a directory named after the given tag. -if [[ -v KOKORO_ARTIFACTS_DIR ]]; then - if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then - # The "latest" directory and current date. - stamp="$(date -Idate)" - install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \ - "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest" - install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \ - "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}" - else - # Is it a tagged release? Build that instead. In that case, we also try to - # update the base release directory, in case this is an update. Finally, we - # update the "release" directory, which has the last released version. - tags="$(git tag --points-at HEAD)" - if ! [[ -z "${tags}" ]]; then - # Note that a given commit can match any number of tags. We have to - # iterate through all possible tags and produce associated artifacts. - for tag in ${tags}; do - name=$(echo "${tag}" | cut -d'-' -f2) - base=$(echo "${name}" | cut -d'.' -f1) - install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \ - "${KOKORO_ARTIFACTS_DIR}/dists/${name}" - if [[ "${base}" != "${tag}" ]]; then - install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \ - "${KOKORO_ARTIFACTS_DIR}/dists/${base}" - fi - install "${KOKORO_ARTIFACTS_DIR}/release/latest" \ - "${KOKORO_ARTIFACTS_DIR}/dists/latest" - done - fi +# If nightly, install only nightly artifacts. +if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then + # The "latest" directory and current date. + stamp="$(date -Idate)" + install_raw "${KOKORO_ARTIFACTS_DIR}/nightly/latest" + install_raw "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/nightly" +else + # We keep only the latest master raw release. + install_raw "${KOKORO_ARTIFACTS_DIR}/master/latest" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/master" + + # Is it a tagged release? Build that too. + tags="$(git tag --points-at HEAD)" + if ! [[ -z "${tags}" ]]; then + # Note that a given commit can match any number of tags. We have to iterate + # through all possible tags and produce associated artifacts. + for tag in ${tags}; do + name=$(echo "${tag}" | cut -d'-' -f2) + base=$(echo "${name}" | cut -d'.' -f1) + install_raw "${KOKORO_ARTIFACTS_DIR}/release/${name}" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/release" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/${base}" + done fi fi diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh index f8ec967b1..bbc1a038e 100755 --- a/scripts/common_bazel.sh +++ b/scripts/common_bazel.sh @@ -71,6 +71,13 @@ function run_as_root() { function collect_logs() { # Zip out everything into a convenient form. if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then + # Merge results files of all shards for each test suite. + for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do + junitparser merge `find $d -name test.xml` $d/test.xml + cat $d/shard_*_of_*/test.log > $d/test.log + ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip + done + find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf # Move test logs to Kokoro directory. tar is used to conveniently perform # renames while moving files. find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | diff --git a/scripts/dev.sh b/scripts/dev.sh index c67003018..6238b4d0b 100755 --- a/scripts/dev.sh +++ b/scripts/dev.sh @@ -54,9 +54,10 @@ declare OUTPUT="$(build //runsc)" if [[ ${REFRESH} -eq 0 ]]; then install_runsc "${RUNTIME}" --net-raw install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets + install_runsc "${RUNTIME}-p" --net-raw --profile echo - echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup." + echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed." echo "Use --runtime="${RUNTIME}" with your Docker command." echo " docker run --rm --runtime="${RUNTIME}" hello-world" echo diff --git a/scripts/go.sh b/scripts/go.sh index 0dbfb7747..626ed8fa4 100755 --- a/scripts/go.sh +++ b/scripts/go.sh @@ -25,6 +25,8 @@ tools/go_branch.sh # Checkout the new branch. git checkout go && git clean -f +go version + # Build everything. go build ./... diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh new file mode 100755 index 000000000..c47cbd675 --- /dev/null +++ b/scripts/iptables_tests.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +install_runsc_for_test iptables + +# Build the docker image for the test. +run //test/iptables/runner --norun + +# TODO(gvisor.dev/issue/170): Also test this on runsc once iptables are better +# supported +test //test/iptables:iptables_test "--test_arg=--runtime=runc" \ + "--test_arg=--image=bazel/test/iptables/runner:runner" diff --git a/scripts/release.sh b/scripts/release.sh index b936bcc77..091abf87f 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -34,5 +34,16 @@ declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com} git config --get user.name || git config user.name "gVisor-bot" git config --get user.email || git config user.email "${EMAIL}" +# Provide a credential if available. +if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then + git config --global credential.helper cache + git credential approve <<EOF +protocol=https +host=github.com +username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}") +password=x-oauth-basic +EOF +fi + # Run the release tool, which pushes to the origin repository. tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}" diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh new file mode 100755 index 000000000..9ee991e42 --- /dev/null +++ b/scripts/runtime_tests.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +if [ ! -v RUNTIME_TEST_NAME ]; then + echo 'Must set $RUNTIME_TEST_NAME' >&2 + exit 1 +fi + +install_runsc_for_test runtimes +test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test" diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh index 585216aae..3a15050c2 100755 --- a/scripts/simple_tests.sh +++ b/scripts/simple_tests.sh @@ -17,4 +17,4 @@ source $(dirname $0)/common.sh # Run all simple tests (locally). -test //pkg/... //runsc/... //tools/... +test //pkg/... //runsc/... //tools/... //benchmarks/... //benchmarks/runner:runner_test diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh new file mode 100755 index 000000000..0de2df1d2 --- /dev/null +++ b/scripts/swgso_tests.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +# Install the runtime and perform basic tests. +install_runsc_for_test swgso --software-gso=true --gso=false +test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh new file mode 100755 index 000000000..de85daa5a --- /dev/null +++ b/scripts/syscall_kvm_tests.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +# TODO(b/112165693): "test --test_tag_filters=runsc_kvm" can be used +# when the "manual" tag will be removed for kvm tests. +test `bazel query "attr(tags, runsc_kvm, tests(//test/syscalls/...))"` diff --git a/test/README.md b/test/README.md index 09c36b461..97fe7ea04 100644 --- a/test/README.md +++ b/test/README.md @@ -10,9 +10,31 @@ they may need extra setup in the test machine and extra configuration to run. functionality. - **image:** basic end to end test for popular images. These require the same setup as integration tests. -- **root:** tests that require to be run as root. +- **root:** tests that require to be run as root. These require the same setup + as integration tests. - **util:** utilities library to support the tests. For the above noted cases, the relevant runtime must be installed via `runsc -install` before running. This is handled automatically by the test scripts in -the `kokoro` directory. +install` before running. Just note that they require specific configuration to +work. This is handled automatically by the test scripts in the `scripts` +directory and they can be used to run tests locally on your machine. They are +also used to run these tests in `kokoro`. + +**Example:** + +To run image and integration tests, run: + +`./scripts/docker_test.sh` + +To run root tests, run: + +`./scripts/root_test.sh` + +There are a few other interesting variations for image and integration tests: + +* overlay: sets writable overlay inside the sentry +* hostnet: configures host network pass-thru, instead of netstack +* kvm: runsc the test using the KVM platform, instead of ptrace + +The test will build runsc, configure it with your local docker, restart +`dockerd`, and run tests. The location for runsc logs is printed to the output. diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index c962a3159..4074d2285 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -109,7 +109,7 @@ func TestExecPrivileged(t *testing.T) { t.Logf("Exec CapEff: %v", got) want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW))) if got != want { - t.Errorf("wrong capabilities, got: %q, want: %q", got, want) + t.Errorf("Wrong capabilities, got: %q, want: %q. Make sure runsc is not using '--net-raw'", got, want) } } diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 7cc0de129..28064e557 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -175,6 +175,9 @@ func TestCheckpointRestore(t *testing.T) { t.Fatal(err) } + // TODO(b/143498576): Remove after github.com/moby/moby/issues/38963 is fixed. + time.Sleep(1 * time.Second) + if err := d.Restore("test"); err != nil { t.Fatal("docker restore failed:", err) } diff --git a/test/iptables/BUILD b/test/iptables/BUILD new file mode 100644 index 000000000..fa833c3b2 --- /dev/null +++ b/test/iptables/BUILD @@ -0,0 +1,31 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "iptables", + srcs = [ + "filter_input.go", + "iptables.go", + "iptables_util.go", + ], + importpath = "gvisor.dev/gvisor/test/iptables", + visibility = ["//test/iptables:__subpackages__"], +) + +go_test( + name = "iptables_test", + srcs = [ + "iptables_test.go", + ], + embed = [":iptables"], + tags = [ + "local", + "manual", + ], + deps = [ + "//pkg/log", + "//runsc/dockerutil", + "//runsc/testutil", + ], +) diff --git a/test/iptables/README.md b/test/iptables/README.md new file mode 100644 index 000000000..b37cb2a96 --- /dev/null +++ b/test/iptables/README.md @@ -0,0 +1,44 @@ +# iptables Tests + +iptables tests are run via `scripts/iptables\_test.sh`. + +## Test Structure + +Each test implements `TestCase`, providing (1) a function to run inside the +container and (2) a function to run locally. Those processes are given each +others' IP addresses. The test succeeds when both functions succeed. + +The function inside the container (`ContainerAction`) typically sets some +iptables rules and then tries to send or receive packets. The local function +(`LocalAction`) will typically just send or receive packets. + +### Adding Tests + +1) Add your test to the `iptables` package. + +2) Register the test in an `init` function via `RegisterTestCase` (see +`filter_input.go` as an example). + +3) Add it to `iptables_test.go` (see the other tests in that file). + +Your test is now runnable with bazel! + +## Run individual tests + +Build the testing Docker container: + +```bash +$ bazel run //test/iptables/runner -- --norun +``` + +Run an individual test via: + +```bash +$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> +``` + +To run an individual test with `runc`: + +```bash +$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> --test_arg=--runtime=runc +``` diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go new file mode 100644 index 000000000..923f44e68 --- /dev/null +++ b/test/iptables/filter_input.go @@ -0,0 +1,124 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "net" + "time" +) + +const ( + dropPort = 2401 + acceptPort = 2402 + sendloopDuration = 2 * time.Second + network = "udp4" +) + +func init() { + RegisterTestCase(FilterInputDropUDP{}) + RegisterTestCase(FilterInputDropUDPPort{}) + RegisterTestCase(FilterInputDropDifferentUDPPort{}) +} + +// FilterInputDropUDP tests that we can drop UDP traffic. +type FilterInputDropUDP struct{} + +// Name implements TestCase.Name. +func (FilterInputDropUDP) Name() string { + return "FilterInputDropUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropUDP) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropUDP) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// FilterInputDropUDPPort tests that we can drop UDP traffic by port. +type FilterInputDropUDPPort struct{} + +// Name implements TestCase.Name. +func (FilterInputDropUDPPort) Name() string { + return "FilterInputDropUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropUDPPort) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port +// doesn't drop packets on other ports. +type FilterInputDropDifferentUDPPort struct{} + +// Name implements TestCase.Name. +func (FilterInputDropDifferentUDPPort) Name() string { + return "FilterInputDropDifferentUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on another port. + if err := listenUDP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go new file mode 100644 index 000000000..2e565d988 --- /dev/null +++ b/test/iptables/iptables.go @@ -0,0 +1,53 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package iptables contains a set of iptables tests implemented as TestCases +package iptables + +import ( + "fmt" + "net" +) + +// IPExchangePort is the port the container listens on to receive the IP +// address of the local process. +const IPExchangePort = 2349 + +// A TestCase contains one action to run in the container and one to run +// locally. The actions run concurrently and each must succeed for the test +// pass. +type TestCase interface { + // Name returns the name of the test. + Name() string + + // ContainerAction runs inside the container. It receives the IP of the + // local process. + ContainerAction(ip net.IP) error + + // LocalAction runs locally. It receives the IP of the container. + LocalAction(ip net.IP) error +} + +// Tests maps test names to TestCase. +// +// New TestCases are added by calling RegisterTestCase in an init function. +var Tests = map[string]TestCase{} + +// RegisterTestCase registers tc so it can be run. +func RegisterTestCase(tc TestCase) { + if _, ok := Tests[tc.Name()]; ok { + panic(fmt.Sprintf("TestCase %s already registered.", tc.Name())) + } + Tests[tc.Name()] = tc +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go new file mode 100644 index 000000000..bfbf1bb87 --- /dev/null +++ b/test/iptables/iptables_test.go @@ -0,0 +1,179 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "net" + "os" + "path" + "testing" + "time" + + "flag" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/runsc/testutil" +) + +const timeout time.Duration = 10 * time.Second + +var image = flag.String("image", "bazel/test/iptables/runner:runner", "image to run tests in") + +type result struct { + output string + err error +} + +// singleTest runs a TestCase. Each test follows a pattern: +// - Create a container. +// - Get the container's IP. +// - Send the container our IP. +// - Start a new goroutine running the local action of the test. +// - Wait for both the container and local actions to finish. +// +// Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or +// to stderr. +func singleTest(test TestCase) error { + if _, ok := Tests[test.Name()]; !ok { + return fmt.Errorf("no test found with name %q. Has it been registered?", test.Name()) + } + + // Create and start the container. + cont := dockerutil.MakeDocker("gvisor-iptables") + defer cont.CleanUp() + resultChan := make(chan *result) + go func() { + output, err := cont.RunFg("--cap-add=NET_ADMIN", *image, "-name", test.Name()) + logContainer(output, err) + resultChan <- &result{output, err} + }() + + // Get the container IP. + ip, err := getIP(cont) + if err != nil { + return fmt.Errorf("failed to get container IP: %v", err) + } + + // Give the container our IP. + if err := sendIP(ip); err != nil { + return fmt.Errorf("failed to send IP to container: %v", err) + } + + // Run our side of the test. + errChan := make(chan error) + go func() { + errChan <- test.LocalAction(ip) + }() + + // Wait for both the container and local tests to finish. + var res *result + to := time.After(timeout) + for localDone := false; res == nil || !localDone; { + select { + case res = <-resultChan: + log.Infof("Container finished.") + case err, localDone = <-errChan: + log.Infof("Local finished.") + if err != nil { + return fmt.Errorf("local test failed: %v", err) + } + case <-to: + return fmt.Errorf("timed out after %f seconds", timeout.Seconds()) + } + } + + return res.err +} + +func getIP(cont dockerutil.Docker) (net.IP, error) { + // The container might not have started yet, so retry a few times. + var ipStr string + to := time.After(timeout) + for ipStr == "" { + ipStr, _ = cont.FindIP() + select { + case <-to: + return net.IP{}, fmt.Errorf("timed out getting IP after %f seconds", timeout.Seconds()) + default: + time.Sleep(250 * time.Millisecond) + } + } + ip := net.ParseIP(ipStr) + if ip == nil { + return net.IP{}, fmt.Errorf("invalid IP: %q", ipStr) + } + log.Infof("Container has IP of %s", ipStr) + return ip, nil +} + +func sendIP(ip net.IP) error { + contAddr := net.TCPAddr{ + IP: ip, + Port: IPExchangePort, + } + var conn *net.TCPConn + // The container may not be listening when we first connect, so retry + // upon error. + cb := func() error { + c, err := net.DialTCP("tcp4", nil, &contAddr) + conn = c + return err + } + if err := testutil.Poll(cb, timeout); err != nil { + return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err) + } + if _, err := conn.Write([]byte{0}); err != nil { + return fmt.Errorf("error writing to container: %v", err) + } + return nil +} + +func logContainer(output string, err error) { + msg := fmt.Sprintf("Container error: %v\nContainer output:\n%v", err, output) + if artifactsDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); artifactsDir != "" { + fpath := path.Join(artifactsDir, "container.log") + if file, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE, 0644); err != nil { + log.Warningf("Failed to open log file %q: %v", fpath, err) + } else { + defer file.Close() + if _, err := file.Write([]byte(msg)); err == nil { + return + } + log.Warningf("Failed to write to log file %s: %v", fpath, err) + } + } + + // We couldn't write to the output directory -- just log to stderr. + log.Infof(msg) +} + +func TestFilterInputDropUDP(t *testing.T) { + if err := singleTest(FilterInputDropUDP{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterInputDropUDPPort(t *testing.T) { + if err := singleTest(FilterInputDropUDPPort{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterInputDropDifferentUDPPort(t *testing.T) { + if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { + t.Fatal(err) + } +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go new file mode 100644 index 000000000..3a4d11f1a --- /dev/null +++ b/test/iptables/iptables_util.go @@ -0,0 +1,82 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "net" + "os/exec" + "time" +) + +const iptablesBinary = "iptables" + +// filterTable calls `iptables -t filter` with the given args. +func filterTable(args ...string) error { + args = append([]string{"-t", "filter"}, args...) + cmd := exec.Command(iptablesBinary, args...) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) + } + return nil +} + +// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for +// the first read on that port. +func listenUDP(port int, timeout time.Duration) error { + localAddr := net.UDPAddr{ + Port: port, + } + conn, err := net.ListenUDP(network, &localAddr) + if err != nil { + return err + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(timeout)) + _, err = conn.Read([]byte{0}) + return err +} + +// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified +// over a duration. +func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { + // Send packets for a few seconds. + remote := net.UDPAddr{ + IP: ip, + Port: port, + } + conn, err := net.DialUDP(network, nil, &remote) + if err != nil { + return err + } + defer conn.Close() + + to := time.After(duration) + for timedOut := false; !timedOut; { + // This may return an error (connection refused) if the remote + // hasn't started listening yet or they're dropping our + // packets. So we ignore Write errors and depend on the remote + // to report a failure if it doesn't get a packet it needs. + conn.Write([]byte{0}) + select { + case <-to: + timedOut = true + default: + time.Sleep(200 * time.Millisecond) + } + } + + return nil +} diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD new file mode 100644 index 000000000..c6c42d870 --- /dev/null +++ b/test/iptables/runner/BUILD @@ -0,0 +1,16 @@ +load("@io_bazel_rules_docker//go:image.bzl", "go_image") +load("@io_bazel_rules_docker//container:container.bzl", "container_image") + +package(licenses = ["notice"]) + +container_image( + name = "iptables-base", + base = "@iptables-test//image", +) + +go_image( + name = "runner", + srcs = ["main.go"], + base = ":iptables-base", + deps = ["//test/iptables"], +) diff --git a/test/iptables/runner/Dockerfile b/test/iptables/runner/Dockerfile new file mode 100644 index 000000000..b77db44a1 --- /dev/null +++ b/test/iptables/runner/Dockerfile @@ -0,0 +1,4 @@ +# This Dockerfile builds the image hosted at +# gcr.io/gvisor-presubmit/iptables-test. +FROM ubuntu +RUN apt update && apt install -y iptables diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go new file mode 100644 index 000000000..3c794114e --- /dev/null +++ b/test/iptables/runner/main.go @@ -0,0 +1,70 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main runs iptables tests from within a docker container. +package main + +import ( + "flag" + "fmt" + "log" + "net" + + "gvisor.dev/gvisor/test/iptables" +) + +var name = flag.String("name", "", "name of the test to run") + +func main() { + flag.Parse() + + // Find out which test we're running. + test, ok := iptables.Tests[*name] + if !ok { + log.Fatalf("No test found named %q", *name) + } + log.Printf("Running test %q", *name) + + // Get the IP of the local process. + ip, err := getIP() + if err != nil { + log.Fatal(err) + } + + // Run the test. + if err := test.ContainerAction(ip); err != nil { + log.Fatalf("Failed running test %q: %v", *name, err) + } +} + +// getIP listens for a connection from the local process and returns the source +// IP of that connection. +func getIP() (net.IP, error) { + localAddr := net.TCPAddr{ + Port: iptables.IPExchangePort, + } + listener, err := net.ListenTCP("tcp4", &localAddr) + if err != nil { + return net.IP{}, fmt.Errorf("failed listening for IP: %v", err) + } + defer listener.Close() + conn, err := listener.AcceptTCP() + if err != nil { + return net.IP{}, fmt.Errorf("failed accepting IP: %v", err) + } + defer conn.Close() + log.Printf("Connected to %v", conn.RemoteAddr()) + + return conn.RemoteAddr().(*net.TCPAddr).IP, nil +} diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 76f1e4f2a..4038661cb 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -24,6 +24,7 @@ import ( "strconv" "strings" "testing" + "time" "gvisor.dev/gvisor/runsc/cgroup" "gvisor.dev/gvisor/runsc/dockerutil" @@ -56,6 +57,59 @@ func verifyPid(pid int, path string) error { } // TestCgroup sets cgroup options and checks that cgroup was properly configured. +func TestMemCGroup(t *testing.T) { + allocMemSize := 128 << 20 + if err := dockerutil.Pull("python"); err != nil { + t.Fatal("docker pull failed:", err) + } + d := dockerutil.MakeDocker("memusage-test") + + // Start a new container and allocate the specified about of memory. + args := []string{ + "--memory=256MB", + "python", + "python", + "-c", + fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize), + } + if err := d.Run(args...); err != nil { + t.Fatal("docker create failed:", err) + } + defer d.CleanUp() + + gid, err := d.ID() + if err != nil { + t.Fatalf("Docker.ID() failed: %v", err) + } + t.Logf("cgroup ID: %s", gid) + + path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.usage_in_bytes") + memUsage := 0 + + // Wait when the container will allocate memory. + start := time.Now() + for time.Now().Sub(start) < 30*time.Second { + outRaw, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("failed to read %q: %v", path, err) + } + out := strings.TrimSpace(string(outRaw)) + memUsage, err = strconv.Atoi(out) + if err != nil { + t.Fatalf("Atoi(%v): %v", out, err) + } + + if memUsage > allocMemSize { + return + } + + time.Sleep(100 * time.Millisecond) + } + + t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) +} + +// TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroup(t *testing.T) { if err := dockerutil.Pull("alpine"); err != nil { t.Fatal("docker pull failed:", err) diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 6cd378a1b..126f0975a 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -40,6 +40,15 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -84,7 +93,6 @@ func TestOOMScoreAdjSingle(t *testing.T) { s := testutil.NewSpecWithArgs("sleep", "1000") s.Process.OOMScoreAdj = testCase.OOMScoreAdj - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -123,6 +131,15 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -240,7 +257,6 @@ func TestOOMScoreAdjMulti(t *testing.T) { } } - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -327,13 +343,8 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { } func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { - // Setup root dir if one hasn't been provided. if len(conf.RootDir) == 0 { - rootDir, err := testutil.SetupRootDir() - if err != nil { - return nil, nil, fmt.Errorf("error creating root dir: %v", err) - } - conf.RootDir = rootDir + panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } var containers []*container.Container @@ -345,7 +356,6 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*c for _, b := range bundles { os.RemoveAll(b) } - os.RemoveAll(conf.RootDir) } for i, spec := range specs { bundleDir, err := testutil.SetupBundleDir(spec) diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 2e125525b..367295206 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -16,32 +16,32 @@ go_binary( ) runtime_test( + name = "go1.12", blacklist_file = "blacklist_go1.12.csv", - image = "gcr.io/gvisor-presubmit/go1.12", lang = "go", ) runtime_test( + name = "java11", blacklist_file = "blacklist_java11.csv", - image = "gcr.io/gvisor-presubmit/java11", lang = "java", ) runtime_test( + name = "nodejs12.4.0", blacklist_file = "blacklist_nodejs12.4.0.csv", - image = "gcr.io/gvisor-presubmit/nodejs12.4.0", lang = "nodejs", ) runtime_test( + name = "php7.3.6", blacklist_file = "blacklist_php7.3.6.csv", - image = "gcr.io/gvisor-presubmit/php7.3.6", lang = "php", ) runtime_test( + name = "python3.7.3", blacklist_file = "blacklist_python3.7.3.csv", - image = "gcr.io/gvisor-presubmit/python3.7.3", lang = "python", ) diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index 7c11624b4..6f84ca852 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -2,32 +2,48 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") -# runtime_test is a macro that will create targets to run the given test target -# with different runtime options. def runtime_test( + name, lang, - image, + image_repo = "gcr.io/gvisor-presubmit", + image_name = None, + blacklist_file = None, shard_count = 50, - size = "enormous", - blacklist_file = ""): + size = "enormous"): + """Generates sh_test and blacklist test targets for a given runtime. + + Args: + name: The name of the runtime being tested. Typically, the lang + version. + This is used in the names of the generated test targets. + lang: The language being tested. + image_repo: The docker repository containing the proctor image to run. + i.e., the prefix to the fully qualified docker image id. + image_name: The name of the image in the image_repo. + Defaults to the test name. + blacklist_file: A test blacklist to pass to the runtime test's runner. + shard_count: See Bazel common test attributes. + size: See Bazel common test attributes. + """ + if image_name == None: + image_name = name args = [ "--lang", lang, "--image", - image, + "/".join([image_repo, image_name]), ] data = [ ":runner", ] - if blacklist_file != "": + if blacklist_file: args += ["--blacklist_file", "test/runtimes/" + blacklist_file] data += [blacklist_file] # Add a test that the blacklist parses correctly. - blacklist_test(lang, blacklist_file) + blacklist_test(name, blacklist_file) sh_test( - name = lang + "_test", + name = name + "_test", srcs = ["runner.sh"], args = args, data = data, @@ -35,15 +51,16 @@ def runtime_test( shard_count = shard_count, tags = [ # Requires docker and runsc to be configured before the test runs. - "manual", "local", + # Don't include test target in wildcard target patterns. + "manual", ], ) -def blacklist_test(lang, blacklist_file): +def blacklist_test(name, blacklist_file): """Test that a blacklist parses correctly.""" go_test( - name = lang + "_blacklist_test", + name = name + "_blacklist_test", embed = [":runner"], srcs = ["blacklist_test.go"], args = ["--blacklist_file", "test/runtimes/" + blacklist_file], diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go index e6178e82b..b54abe434 100644 --- a/test/runtimes/images/proctor/proctor.go +++ b/test/runtimes/images/proctor/proctor.go @@ -39,10 +39,10 @@ type TestRunner interface { } var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testName = flag.String("test", "", "run a single test from the list of available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") ) func main() { @@ -74,14 +74,23 @@ func main() { return } - // Run a single test. - if *test == "" { - log.Fatalf("test flag must be provided") + var tests []string + if *testName == "" { + // Run every test. + tests, err = tr.ListTests() + if err != nil { + log.Fatalf("failed to get all tests: %v", err) + } + } else { + // Run a single test. + tests = []string{*testName} } - cmd := tr.TestCmd(*test) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) + for _, test := range tests { + cmd := tr.TestCmd(test) + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } } } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 87ef87e07..a3a85917d 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -9,7 +9,7 @@ syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", ) @@ -79,6 +79,12 @@ syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test") syscall_test(test = "//test/syscalls/linux:concurrency_test") syscall_test( + add_uds_tree = True, + test = "//test/syscalls/linux:connect_external_test", + use_tmpfs = True, +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:creat_test", ) @@ -370,6 +376,8 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:rlimits_test") +syscall_test(test = "//test/syscalls/linux:rseq_test") + syscall_test(test = "//test/syscalls/linux:rtsignal_test") syscall_test(test = "//test/syscalls/linux:sched_test") @@ -428,7 +436,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", ) @@ -439,7 +447,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", ) @@ -452,19 +460,19 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", ) @@ -475,13 +483,13 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", ) syscall_test( size = "medium", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) @@ -492,7 +500,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", ) @@ -505,8 +513,12 @@ syscall_test(test = "//test/syscalls/linux:socket_ip_unbound_test") syscall_test(test = "//test/syscalls/linux:socket_netdevice_test") +syscall_test(test = "//test/syscalls/linux:socket_netlink_test") + syscall_test(test = "//test/syscalls/linux:socket_netlink_route_test") +syscall_test(test = "//test/syscalls/linux:socket_netlink_uevent_test") + syscall_test(test = "//test/syscalls/linux:socket_blocking_local_test") syscall_test(test = "//test/syscalls/linux:socket_blocking_ip_test") @@ -550,7 +562,7 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -589,7 +601,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) @@ -659,6 +671,7 @@ syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( size = "medium", + add_hostinet = True, shard_count = 10, test = "//test/syscalls/linux:udp_socket_test", ) @@ -704,6 +717,11 @@ syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test") syscall_test(test = "//test/syscalls/linux:proc_net_udp_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:xattr_test", +) + go_binary( name = "syscall_test_runner", testonly = 1, @@ -716,6 +734,7 @@ go_binary( "//runsc/specutils", "//runsc/testutil", "//test/syscalls/gtest", + "//test/uds", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index e94ef5602..aaf77c65b 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -8,6 +8,8 @@ def syscall_test( size = "small", use_tmpfs = False, add_overlay = False, + add_uds_tree = False, + add_hostinet = False, tags = None): _syscall_test( test = test, @@ -15,6 +17,7 @@ def syscall_test( size = size, platform = "native", use_tmpfs = False, + add_uds_tree = add_uds_tree, tags = tags, ) @@ -24,6 +27,7 @@ def syscall_test( size = size, platform = "kvm", use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, tags = tags, ) @@ -33,6 +37,7 @@ def syscall_test( size = size, platform = "ptrace", use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, tags = tags, ) @@ -43,6 +48,7 @@ def syscall_test( size = size, platform = "ptrace", use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. + add_uds_tree = add_uds_tree, tags = tags, overlay = True, ) @@ -55,10 +61,23 @@ def syscall_test( size = size, platform = "ptrace", use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, tags = tags, file_access = "shared", ) + if add_hostinet: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = use_tmpfs, + network = "host", + add_uds_tree = add_uds_tree, + tags = tags, + ) + def _syscall_test( test, shard_count, @@ -66,8 +85,10 @@ def _syscall_test( platform, use_tmpfs, tags, + network = "none", file_access = "exclusive", - overlay = False): + overlay = False, + add_uds_tree = False): test_name = test.split(":")[1] # Prepend "runsc" to non-native platform names. @@ -78,6 +99,8 @@ def _syscall_test( name += "_shared" if overlay: name += "_overlay" + if network != "none": + name += "_" + network + "net" if tags == None: tags = [] @@ -100,9 +123,11 @@ def _syscall_test( # Arguments are passed directly to syscall_test_runner binary. "--test-name=" + test_name, "--platform=" + platform, + "--network=" + network, "--use-tmpfs=" + str(use_tmpfs), "--file-access=" + file_access, "--overlay=" + str(overlay), + "--add-uds-tree=" + str(add_uds_tree), ] sh_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 84a8eb76c..064ce8429 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -6,6 +6,19 @@ package( licenses = ["notice"], ) +exports_files( + [ + "socket.cc", + "socket_inet_loopback.cc", + "socket_ip_loopback_blocking.cc", + "socket_ip_tcp_loopback.cc", + "socket_ipv4_udp_unbound_loopback.cc", + "tcp_socket.cc", + "udp_socket.cc", + ], + visibility = ["//:sandbox"], +) + cc_binary( name = "sigaltstack_check", testonly = 1, @@ -480,6 +493,21 @@ cc_binary( ) cc_binary( + name = "connect_external_test", + testonly = 1, + srcs = ["connect_external.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "creat_test", testonly = 1, srcs = ["creat.cc"], @@ -655,6 +683,7 @@ cc_binary( "//test/util:thread_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest", ], ) @@ -727,6 +756,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:multiprocess_util", "//test/util:posix_error", + "//test/util:save_util", "//test/util:temp_path", "//test/util:test_util", "//test/util:timer_util", @@ -1552,11 +1582,14 @@ cc_binary( srcs = ["proc_net.cc"], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) @@ -1776,7 +1809,6 @@ cc_binary( name = "readv_socket_test", testonly = 1, srcs = [ - "file_base.h", "readv_common.cc", "readv_common.h", "readv_socket.cc", @@ -1824,6 +1856,22 @@ cc_binary( ) cc_binary( + name = "rseq_test", + testonly = 1, + srcs = ["rseq.cc"], + data = ["//test/syscalls/linux/rseq"], + linkstatic = 1, + deps = [ + "//test/syscalls/linux/rseq:lib", + "//test/util:logging", + "//test/util:multiprocess_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "rtsignal_test", testonly = 1, srcs = ["rtsignal.cc"], @@ -2122,6 +2170,8 @@ cc_library( deps = [ ":socket_test_util", "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], alwayslink = 1, @@ -2214,6 +2264,7 @@ cc_library( ":ip_socket_test_util", ":socket_test_util", "//test/util:test_util", + "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], alwayslink = 1, @@ -2657,6 +2708,20 @@ cc_binary( ) cc_binary( + name = "socket_netlink_test", + testonly = 1, + srcs = ["socket_netlink.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "socket_netlink_route_test", testonly = 1, srcs = ["socket_netlink_route.cc"], @@ -2673,6 +2738,21 @@ cc_binary( ], ) +cc_binary( + name = "socket_netlink_uevent_test", + testonly = 1, + srcs = ["socket_netlink_uevent.cc"], + linkstatic = 1, + deps = [ + ":socket_netlink_util", + ":socket_test_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + # These socket tests are in a library because the test cases are shared # across several test build targets. cc_library( @@ -3195,8 +3275,6 @@ cc_binary( testonly = 1, srcs = ["tcp_socket.cc"], linkstatic = 1, - # FIXME(b/135470853) - tags = ["flaky"], deps = [ ":socket_test_util", "//test/util:file_descriptor", @@ -3301,11 +3379,15 @@ cc_binary( ], ) -cc_binary( - name = "udp_socket_test", +cc_library( + name = "udp_socket_test_cases", testonly = 1, - srcs = ["udp_socket.cc"], - linkstatic = 1, + srcs = [ + "udp_socket_test_cases.cc", + ] + select_for_linux([ + "udp_socket_errqueue_test_case.cc", + ]), + hdrs = ["udp_socket_test_cases.h"], deps = [ ":socket_test_util", ":unix_domain_socket_test_util", @@ -3316,6 +3398,17 @@ cc_binary( "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], + alwayslink = 1, +) + +cc_binary( + name = "udp_socket_test", + testonly = 1, + srcs = ["udp_socket.cc"], + linkstatic = 1, + deps = [ + ":udp_socket_test_cases", + ], ) cc_binary( @@ -3632,3 +3725,24 @@ cc_binary( "@com_google_googletest//:gtest", ], ) + +cc_binary( + name = "xattr_test", + testonly = 1, + srcs = [ + "file_base.h", + "xattr.cc", + ], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc index 1122ea240..e08c578f0 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -14,9 +14,10 @@ #include <stdio.h> #include <sys/un.h> + #include <algorithm> #include <vector> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -140,6 +141,18 @@ TEST_P(AllSocketPairTest, Connect) { SyscallSucceeds()); } +TEST_P(AllSocketPairTest, ConnectNonListening) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallFailsWithErrno(ECONNREFUSED)); +} + TEST_P(AllSocketPairTest, ConnectToFilePath) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc index b6cdb3f4f..4857f160b 100644 --- a/test/syscalls/linux/accept_bind_stream.cc +++ b/test/syscalls/linux/accept_bind_stream.cc @@ -14,9 +14,10 @@ #include <stdio.h> #include <sys/un.h> + #include <algorithm> #include <vector> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index b27d4e10a..a33daff17 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -129,7 +129,7 @@ TEST_F(AIOTest, BasicWrite) { // aio implementation uses aio_ring. gVisor doesn't and returns all zeroes. // Linux implements aio_ring, so skip the zeroes check. // - // TODO(b/65486370): Remove when gVisor implements aio_ring. + // TODO(gvisor.dev/issue/204): Remove when gVisor implements aio_ring. auto ring = reinterpret_cast<struct aio_ring*>(ctx_); auto magic = IsRunningOnGvisor() ? 0 : AIO_RING_MAGIC; EXPECT_EQ(ring->magic, magic); diff --git a/test/syscalls/linux/bind.cc b/test/syscalls/linux/bind.cc index de8cca53b..9547c4ab2 100644 --- a/test/syscalls/linux/bind.cc +++ b/test/syscalls/linux/bind.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc index 7e918b9b2..a06b5cfd6 100644 --- a/test/syscalls/linux/chmod.cc +++ b/test/syscalls/linux/chmod.cc @@ -16,6 +16,7 @@ #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> + #include <string> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 498c45f16..04bc2d7b9 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -19,12 +19,12 @@ #include <sys/stat.h> #include <syscall.h> #include <unistd.h> + #include <string> #include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" diff --git a/test/syscalls/linux/clock_gettime.cc b/test/syscalls/linux/clock_gettime.cc index c9e3ed6b2..7f6015049 100644 --- a/test/syscalls/linux/clock_gettime.cc +++ b/test/syscalls/linux/clock_gettime.cc @@ -14,6 +14,7 @@ #include <pthread.h> #include <sys/time.h> + #include <cerrno> #include <cstdint> #include <ctime> @@ -55,11 +56,6 @@ void spin_ns(int64_t ns) { // Test that CLOCK_PROCESS_CPUTIME_ID is a superset of CLOCK_THREAD_CPUTIME_ID. TEST(ClockGettime, CputimeId) { - // TODO(b/128871825,golang.org/issue/10958): Test times out when there is a - // small number of core because one goroutine starves the others. - printf("CPUS: %d\n", std::thread::hardware_concurrency()); - SKIP_IF(std::thread::hardware_concurrency() <= 2); - constexpr int kNumThreads = 13; // arbitrary absl::Duration spin_time = absl::Seconds(1); diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc index 4e0a13f8b..00b96b34a 100644 --- a/test/syscalls/linux/concurrency.cc +++ b/test/syscalls/linux/concurrency.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <signal.h> + #include <atomic> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc new file mode 100644 index 000000000..bfe1da82e --- /dev/null +++ b/test/syscalls/linux/connect_external.cc @@ -0,0 +1,163 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <stdlib.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <string> +#include <tuple> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/test_util.h" + +// This file contains tests specific to connecting to host UDS managed outside +// the sandbox / test. +// +// A set of ultity sockets will be created externally in $TEST_UDS_TREE and +// $TEST_UDS_ATTACH_TREE for these tests to interact with. + +namespace gvisor { +namespace testing { + +namespace { + +struct ProtocolSocket { + int protocol; + std::string name; +}; + +// Parameter is (socket root dir, ProtocolSocket). +using GoferStreamSeqpacketTest = + ::testing::TestWithParam<std::tuple<std::string, ProtocolSocket>>; + +// Connect to a socket and verify that write/read work. +// +// An "echo" socket doesn't work for dgram sockets because our socket is +// unnamed. The server thus has no way to reply to us. +TEST_P(GoferStreamSeqpacketTest, Echo) { + std::string env; + ProtocolSocket proto; + std::tie(env, proto) = GetParam(); + + char *val = getenv(env.c_str()); + ASSERT_NE(val, nullptr); + std::string root(val); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0)); + + std::string socket_path = JoinPath(root, proto.name, "echo"); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); + + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + constexpr int kBufferSize = 64; + char send_buffer[kBufferSize]; + memset(send_buffer, 'a', sizeof(send_buffer)); + + ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)), + SyscallSucceedsWithValue(sizeof(send_buffer))); + + char recv_buffer[kBufferSize]; + ASSERT_THAT(ReadFd(sock.get(), recv_buffer, sizeof(recv_buffer)), + SyscallSucceedsWithValue(sizeof(recv_buffer))); + ASSERT_EQ(0, memcmp(send_buffer, recv_buffer, sizeof(send_buffer))); +} + +// It is not possible to connect to a bound but non-listening socket. +TEST_P(GoferStreamSeqpacketTest, NonListening) { + std::string env; + ProtocolSocket proto; + std::tie(env, proto) = GetParam(); + + char *val = getenv(env.c_str()); + ASSERT_NE(val, nullptr); + std::string root(val); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0)); + + std::string socket_path = JoinPath(root, proto.name, "nonlistening"); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); + + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + sizeof(addr)), + SyscallFailsWithErrno(ECONNREFUSED)); +} + +INSTANTIATE_TEST_SUITE_P( + StreamSeqpacket, GoferStreamSeqpacketTest, + ::testing::Combine( + // Test access via standard path and attach point. + ::testing::Values("TEST_UDS_TREE", "TEST_UDS_ATTACH_TREE"), + ::testing::Values(ProtocolSocket{SOCK_STREAM, "stream"}, + ProtocolSocket{SOCK_SEQPACKET, "seqpacket"}))); + +// Parameter is socket root dir. +using GoferDgramTest = ::testing::TestWithParam<std::string>; + +// Connect to a socket and verify that write works. +// +// An "echo" socket doesn't work for dgram sockets because our socket is +// unnamed. The server thus has no way to reply to us. +TEST_P(GoferDgramTest, Null) { + std::string env = GetParam(); + char *val = getenv(env.c_str()); + ASSERT_NE(val, nullptr); + std::string root(val); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_DGRAM, 0)); + + std::string socket_path = JoinPath(root, "dgram/null"); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); + + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + constexpr int kBufferSize = 64; + char send_buffer[kBufferSize]; + memset(send_buffer, 'a', sizeof(send_buffer)); + + ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)), + SyscallSucceedsWithValue(sizeof(send_buffer))); +} + +INSTANTIATE_TEST_SUITE_P(Dgram, GoferDgramTest, + // Test access via standard path and attach point. + ::testing::Values("TEST_UDS_TREE", + "TEST_UDS_ATTACH_TREE")); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc index 370e85166..3d564e720 100644 --- a/test/syscalls/linux/exceptions.cc +++ b/test/syscalls/linux/exceptions.cc @@ -22,6 +22,23 @@ namespace gvisor { namespace testing { +// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5 +// "x87 FPU Control Word". +constexpr uint16_t kX87ControlWordDefault = 0x37f; + +// Mask for the divide-by-zero exception. +constexpr uint16_t kX87ControlWordDiv0Mask = 1 << 2; + +// Default value for the SSE control register (MXCSR). See Intel SDM Vol 1, Ch +// 11.6.4 "Initialization of SSE/SSE3 Extensions". +constexpr uint32_t kMXCSRDefault = 0x1f80; + +// Mask for the divide-by-zero exception. +constexpr uint32_t kMXCSRDiv0Mask = 1 << 9; + +// Flag for a pending divide-by-zero exception. +constexpr uint32_t kMXCSRDiv0Flag = 1 << 2; + void inline Halt() { asm("hlt\r\n"); } void inline SetAlignmentCheck() { @@ -107,6 +124,170 @@ TEST(ExceptionTest, DivideByZero) { ::testing::KilledBySignal(SIGFPE), ""); } +// By default, x87 exceptions are masked and simply return a default value. +TEST(ExceptionTest, X87DivideByZeroMasked) { + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm("fildl %[value]\r\n" + "fidivl %[divisor]\r\n" + "fistpl %[quotient]\r\n" + : [ quotient ] "=m"(quotient) + : [ value ] "m"(value), [ divisor ] "m"(divisor)); + + EXPECT_EQ(quotient, INT32_MIN); +} + +// When unmasked, division by zero raises SIGFPE. +TEST(ExceptionTest, X87DivideByZeroUnmasked) { + // See above. + struct sigaction sa = {}; + sa.sa_handler = SIG_DFL; + auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); + + EXPECT_EXIT( + { + // Clear the divide by zero exception mask. + constexpr uint16_t kControlWord = + kX87ControlWordDefault & ~kX87ControlWordDiv0Mask; + + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm volatile( + "fldcw %[cw]\r\n" + "fildl %[value]\r\n" + "fidivl %[divisor]\r\n" + "fistpl %[quotient]\r\n" + : [ quotient ] "=m"(quotient) + : [ cw ] "m"(kControlWord), [ value ] "m"(value), + [ divisor ] "m"(divisor)); + }, + ::testing::KilledBySignal(SIGFPE), ""); +} + +// Pending exceptions in the x87 status register are not clobbered by syscalls. +TEST(ExceptionTest, X87StatusClobber) { + // See above. + struct sigaction sa = {}; + sa.sa_handler = SIG_DFL; + auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); + + EXPECT_EXIT( + { + // Clear the divide by zero exception mask. + constexpr uint16_t kControlWord = + kX87ControlWordDefault & ~kX87ControlWordDiv0Mask; + + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm volatile( + "fildl %[value]\r\n" + "fidivl %[divisor]\r\n" + // Exception is masked, so it does not occur here. + "fistpl %[quotient]\r\n" + + // SYS_getpid placed in rax by constraint. + "syscall\r\n" + + // Unmask exception. The syscall didn't clobber the pending + // exception, so now it can be raised. + // + // N.B. "a floating-point exception will be generated upon execution + // of the *next* floating-point instruction". + "fldcw %[cw]\r\n" + "fwait\r\n" + : [ quotient ] "=m"(quotient) + : [ value ] "m"(value), [ divisor ] "m"(divisor), "a"(SYS_getpid), + [ cw ] "m"(kControlWord) + : "rcx", "r11"); + }, + ::testing::KilledBySignal(SIGFPE), ""); +} + +// By default, SSE exceptions are masked and simply return a default value. +TEST(ExceptionTest, SSEDivideByZeroMasked) { + uint32_t status; + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm("cvtsi2ssl %[value], %%xmm0\r\n" + "cvtsi2ssl %[divisor], %%xmm1\r\n" + "divss %%xmm1, %%xmm0\r\n" + "cvtss2sil %%xmm0, %[quotient]\r\n" + : [ quotient ] "=r"(quotient), [ status ] "=r"(status) + : [ value ] "r"(value), [ divisor ] "r"(divisor) + : "xmm0", "xmm1"); + + EXPECT_EQ(quotient, INT32_MIN); +} + +// When unmasked, division by zero raises SIGFPE. +TEST(ExceptionTest, SSEDivideByZeroUnmasked) { + // See above. + struct sigaction sa = {}; + sa.sa_handler = SIG_DFL; + auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); + + EXPECT_EXIT( + { + // Clear the divide by zero exception mask. + constexpr uint32_t kMXCSR = kMXCSRDefault & ~kMXCSRDiv0Mask; + + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm volatile( + "ldmxcsr %[mxcsr]\r\n" + "cvtsi2ssl %[value], %%xmm0\r\n" + "cvtsi2ssl %[divisor], %%xmm1\r\n" + "divss %%xmm1, %%xmm0\r\n" + "cvtss2sil %%xmm0, %[quotient]\r\n" + : [ quotient ] "=r"(quotient) + : [ mxcsr ] "m"(kMXCSR), [ value ] "r"(value), + [ divisor ] "r"(divisor) + : "xmm0", "xmm1"); + }, + ::testing::KilledBySignal(SIGFPE), ""); +} + +// Pending exceptions in the SSE status register are not clobbered by syscalls. +TEST(ExceptionTest, SSEStatusClobber) { + uint32_t mxcsr; + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm("cvtsi2ssl %[value], %%xmm0\r\n" + "cvtsi2ssl %[divisor], %%xmm1\r\n" + "divss %%xmm1, %%xmm0\r\n" + // Exception is masked, so it does not occur here. + "cvtss2sil %%xmm0, %[quotient]\r\n" + + // SYS_getpid placed in rax by constraint. + "syscall\r\n" + + // Intel SDM Vol 1, Ch 10.2.3.1 "SIMD Floating-Point Mask and Flag Bits": + // "If LDMXCSR or FXRSTOR clears a mask bit and sets the corresponding + // exception flag bit, a SIMD floating-point exception will not be + // generated as a result of this change. The unmasked exception will be + // generated only upon the execution of the next SSE/SSE2/SSE3 instruction + // that detects the unmasked exception condition." + // + // Though ambiguous, empirical evidence indicates that this means that + // exception flags set in the status register will never cause an + // exception to be raised; only a new exception condition will do so. + // + // Thus here we just check for the flag itself rather than trying to raise + // the exception. + "stmxcsr %[mxcsr]\r\n" + : [ quotient ] "=r"(quotient), [ mxcsr ] "+m"(mxcsr) + : [ value ] "r"(value), [ divisor ] "r"(divisor), "a"(SYS_getpid) + : "xmm0", "xmm1", "rcx", "r11"); + + EXPECT_TRUE(mxcsr & kMXCSRDiv0Flag); +} + TEST(ExceptionTest, IOAccessFault) { // See above. struct sigaction sa = {}; diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index 4947271ba..b5e0a512b 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -33,6 +33,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/multiprocess_util.h" @@ -46,33 +47,25 @@ namespace testing { namespace { -constexpr char kBasicWorkload[] = "exec_basic_workload"; -constexpr char kExitScript[] = "exit_script"; -constexpr char kStateWorkload[] = "exec_state_workload"; -constexpr char kProcExeWorkload[] = "exec_proc_exe_workload"; -constexpr char kAssertClosedWorkload[] = "exec_assert_closed_workload"; -constexpr char kPriorityWorkload[] = "priority_execve"; - -std::string WorkloadPath(absl::string_view binary) { - std::string full_path; - char* test_src = getenv("TEST_SRCDIR"); - if (test_src) { - full_path = JoinPath(test_src, "__main__/test/syscalls/linux", binary); - } - - TEST_CHECK(full_path.empty() == false); - return full_path; -} +constexpr char kBasicWorkload[] = "test/syscalls/linux/exec_basic_workload"; +constexpr char kExitScript[] = "test/syscalls/linux/exit_script"; +constexpr char kStateWorkload[] = "test/syscalls/linux/exec_state_workload"; +constexpr char kProcExeWorkload[] = + "test/syscalls/linux/exec_proc_exe_workload"; +constexpr char kAssertClosedWorkload[] = + "test/syscalls/linux/exec_assert_closed_workload"; +constexpr char kPriorityWorkload[] = "test/syscalls/linux/priority_execve"; constexpr char kExit42[] = "--exec_exit_42"; constexpr char kExecWithThread[] = "--exec_exec_with_thread"; constexpr char kExecFromThread[] = "--exec_exec_from_thread"; -// Runs filename with argv and checks that the exit status is expect_status and -// that stderr contains expect_stderr. -void CheckOutput(const std::string& filename, const ExecveArray& argv, - const ExecveArray& envv, int expect_status, - const std::string& expect_stderr) { +// Runs file specified by dirfd and pathname with argv and checks that the exit +// status is expect_status and that stderr contains expect_stderr. +void CheckExecHelper(const absl::optional<int32_t> dirfd, + const std::string& pathname, const ExecveArray& argv, + const ExecveArray& envv, const int flags, + int expect_status, const std::string& expect_stderr) { int pipe_fds[2]; ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); @@ -110,8 +103,15 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv, // CloexecEventfd depend on that not happening. }; - auto kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(filename, argv, envv, remap_stderr, &child, &execve_errno)); + Cleanup kill; + if (dirfd.has_value()) { + kill = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(*dirfd, pathname, argv, + envv, flags, remap_stderr, + &child, &execve_errno)); + } else { + kill = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(pathname, argv, envv, remap_stderr, &child, &execve_errno)); + } ASSERT_EQ(0, execve_errno); @@ -140,6 +140,21 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv, EXPECT_TRUE(absl::StrContains(output, expect_stderr)) << output; } +void CheckExec(const std::string& filename, const ExecveArray& argv, + const ExecveArray& envv, int expect_status, + const std::string& expect_stderr) { + CheckExecHelper(/*dirfd=*/absl::optional<int32_t>(), filename, argv, envv, + /*flags=*/0, expect_status, expect_stderr); +} + +void CheckExecveat(const int32_t dirfd, const std::string& pathname, + const ExecveArray& argv, const ExecveArray& envv, + const int flags, int expect_status, + const std::string& expect_stderr) { + CheckExecHelper(absl::optional<int32_t>(dirfd), pathname, argv, envv, flags, + expect_status, expect_stderr); +} + TEST(ExecTest, EmptyPath) { int execve_errno; ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec("", {}, {}, nullptr, &execve_errno)); @@ -147,73 +162,72 @@ TEST(ExecTest, EmptyPath) { } TEST(ExecTest, Basic) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {}, - ArgEnvExitStatus(0, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {}, + ArgEnvExitStatus(0, 0), + absl::StrCat(RunfilePath(kBasicWorkload), "\n")); } TEST(ExecTest, OneArg) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"}, - {}, ArgEnvExitStatus(1, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "1"}, {}, + ArgEnvExitStatus(1, 0), + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); } TEST(ExecTest, FiveArg) { - CheckOutput(WorkloadPath(kBasicWorkload), - {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, - ArgEnvExitStatus(5, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); + CheckExec(RunfilePath(kBasicWorkload), + {RunfilePath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, + ArgEnvExitStatus(5, 0), + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } TEST(ExecTest, OneEnv) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, - {"1"}, ArgEnvExitStatus(0, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1"}, + ArgEnvExitStatus(0, 1), + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); } TEST(ExecTest, FiveEnv) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, - {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, + {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } TEST(ExecTest, OneArgOneEnv) { - CheckOutput(WorkloadPath(kBasicWorkload), - {WorkloadPath(kBasicWorkload), "arg"}, {"env"}, - ArgEnvExitStatus(1, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "arg"}, + {"env"}, ArgEnvExitStatus(1, 1), + absl::StrCat(RunfilePath(kBasicWorkload), "\narg\nenv\n")); } TEST(ExecTest, InterpreterScript) { - CheckOutput(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {}, - ArgEnvExitStatus(25, 0), ""); + CheckExec(RunfilePath(kExitScript), {RunfilePath(kExitScript), "25"}, {}, + ArgEnvExitStatus(25, 0), ""); } // Everything after the path in the interpreter script is a single argument. TEST(ExecTest, InterpreterScriptArgSplit) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), + absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n")); } // Original argv[0] is replaced with the script path. TEST(ExecTest, InterpreterScriptArgvZero) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - CheckOutput(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); + CheckExec(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script.path(), "\n")); } // Original argv[0] is replaced with the script path, exactly as passed to @@ -221,7 +235,7 @@ TEST(ExecTest, InterpreterScriptArgvZero) { TEST(ExecTest, InterpreterScriptArgvZeroRelative) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -230,62 +244,62 @@ TEST(ExecTest, InterpreterScriptArgvZeroRelative) { auto script_relative = ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path())); - CheckOutput(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script_relative, "\n")); + CheckExec(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script_relative, "\n")); } // argv[0] is added as the script path, even if there was none. TEST(ExecTest, InterpreterScriptArgvZeroAdded) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - CheckOutput(script.path(), {}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); + CheckExec(script.path(), {}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script.path(), "\n")); } // A NUL byte in the script line ends parsing. TEST(ExecTest, InterpreterScriptArgNUL) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo", std::string(1, '\0'), "bar"), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), + absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); } // Trailing whitespace following interpreter path is ignored. TEST(ExecTest, InterpreterScriptTrailingWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script.path(), "\n")); } // Multiple whitespace characters between interpreter and arg allowed. TEST(ExecTest, InterpreterScriptArgWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), + absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); } TEST(ExecTest, InterpreterScriptNoPath) { @@ -302,7 +316,7 @@ TEST(ExecTest, InterpreterScriptNoPath) { TEST(ExecTest, ExecFn) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " PrintExecFn"), @@ -314,21 +328,21 @@ TEST(ExecTest, ExecFn) { auto script_relative = ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path())); - CheckOutput(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(script_relative, "\n")); + CheckExec(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0), + absl::StrCat(script_relative, "\n")); } TEST(ExecTest, ExecName) { - std::string path = WorkloadPath(kStateWorkload); + std::string path = RunfilePath(kStateWorkload); - CheckOutput(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(Basename(path).substr(0, 15), "\n")); + CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), + absl::StrCat(Basename(path).substr(0, 15), "\n")); } TEST(ExecTest, ExecNameScript) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), @@ -336,21 +350,21 @@ TEST(ExecTest, ExecNameScript) { std::string script_path = script.path(); - CheckOutput(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(Basename(script_path).substr(0, 15), "\n")); + CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), + absl::StrCat(Basename(script_path).substr(0, 15), "\n")); } // execve may be called by a multithreaded process. TEST(ExecTest, WithSiblingThread) { - CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {}, - W_EXITCODE(42, 0), ""); + CheckExec("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {}, + W_EXITCODE(42, 0), ""); } // execve may be called from a thread other than the leader of a multithreaded // process. TEST(ExecTest, FromSiblingThread) { - CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {}, - W_EXITCODE(42, 0), ""); + CheckExec("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {}, + W_EXITCODE(42, 0), ""); } TEST(ExecTest, NotFound) { @@ -382,13 +396,13 @@ TEST(ExecStateTest, HandlerReset) { ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigHandler", absl::StrCat(SIGUSR1), absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))), }; - CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Ignored signal dispositions are not reset. @@ -398,13 +412,13 @@ TEST(ExecStateTest, IgnorePreserved) { ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigHandler", absl::StrCat(SIGUSR1), absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))), }; - CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Signal masks are not reset on exec @@ -415,12 +429,12 @@ TEST(ExecStateTest, SignalMask) { ASSERT_THAT(sigprocmask(SIG_BLOCK, &s, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigBlocked", absl::StrCat(SIGUSR1), }; - CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // itimers persist across execve. @@ -448,7 +462,7 @@ TEST(ExecStateTest, ItimerPreserved) { } }; - std::string filename = WorkloadPath(kStateWorkload); + std::string filename = RunfilePath(kStateWorkload); ExecveArray argv = { filename, "CheckItimerEnabled", @@ -472,10 +486,10 @@ TEST(ExecStateTest, ItimerPreserved) { TEST(ProcSelfExe, ChangesAcrossExecve) { // See exec_proc_exe_workload for more details. We simply // assert that the /proc/self/exe link changes across execve. - CheckOutput(WorkloadPath(kProcExeWorkload), - {WorkloadPath(kProcExeWorkload), - ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, - {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kProcExeWorkload), + {RunfilePath(kProcExeWorkload), + ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, + {}, W_EXITCODE(0, 0), ""); } TEST(ExecTest, CloexecNormalFile) { @@ -484,20 +498,20 @@ TEST(ExecTest, CloexecNormalFile) { const FileDescriptor fd_closed_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - CheckOutput(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), - absl::StrCat(fd_closed_on_exec.get())}, - {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), + absl::StrCat(fd_closed_on_exec.get())}, + {}, W_EXITCODE(0, 0), ""); // The assert closed workload exits with code 2 if the file still exists. We // can use this to do a negative test. const FileDescriptor fd_open_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY)); - CheckOutput(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), - absl::StrCat(fd_open_on_exec.get())}, - {}, W_EXITCODE(2, 0), ""); + CheckExec( + RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_open_on_exec.get())}, + {}, W_EXITCODE(2, 0), ""); } TEST(ExecTest, CloexecEventfd) { @@ -505,9 +519,239 @@ TEST(ExecTest, CloexecEventfd) { ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds()); FileDescriptor fd(efd); - CheckOutput(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, - W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, + W_EXITCODE(0, 0), ""); +} + +constexpr int kLinuxMaxSymlinks = 40; + +TEST(ExecTest, SymlinkLimitExceeded) { + std::string path = RunfilePath(kBasicWorkload); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> symlinks; + for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) { + symlinks.push_back( + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path))); + path = symlinks[i].path(); + } + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { + std::string tmp_dir = "/tmp"; + std::string interpreter_path = "/bin/echo"; + TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + tmp_dir, absl::StrCat("#!", interpreter_path), 0755)); + std::string script_path = script.path(); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> interpreter_symlinks; + std::vector<TempPath> script_symlinks; + for (int i = 0; i < kLinuxMaxSymlinks; i++) { + interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); + interpreter_path = interpreter_symlinks[i].path(); + script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, script_path))); + script_path = script_symlinks[i].path(); + } + + CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, BasicWithFDCWD) { + std::string path = RunfilePath(kBasicWorkload); + CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, Basic) { + std::string absolute_path = RunfilePath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + CheckExecveat(dirfd.get(), base, {absolute_path}, {}, /*flags=*/0, + ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); +} + +TEST(ExecveatTest, FDNotADirectory) { + std::string absolute_path = RunfilePath(kBasicWorkload); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), base, {absolute_path}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOTDIR); +} + +TEST(ExecveatTest, AbsolutePathWithFDCWD) { + std::string path = RunfilePath(kBasicWorkload); + CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, AbsolutePath) { + std::string path = RunfilePath(kBasicWorkload); + // File descriptor should be ignored when an absolute path is given. + const int32_t badFD = -1; + CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, EmptyPathBasic) { + std::string path = RunfilePath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0), + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, EmptyPathWithDirFD) { + std::string path = RunfilePath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), "", {path}, {}, + AT_EMPTY_PATH, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, EACCES); +} + +TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { + std::string path = RunfilePath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( + fd.get(), "", {path}, {}, /*flags=*/0, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { + std::string path = RunfilePath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { + std::string absolute_path = RunfilePath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + CheckExecveat(dirfd.get(), base, {absolute_path}, {}, AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); +} + +TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + std::string base = std::string(Basename(link.path())); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + AT_SYMLINK_NOFOLLOW, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); + std::string path = link.path(); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(AT_FDCWD, path, {path}, {}, + AT_SYMLINK_NOFOLLOW, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) { + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); + std::string path = link.path(); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH | AT_SYMLINK_NOFOLLOW, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, SymlinkNoFollowIgnoreSymlinkAncestor) { + TempPath parent_link = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", "/bin")); + std::string path_with_symlink = JoinPath(parent_link.path(), "echo"); + + CheckExecveat(AT_FDCWD, path_with_symlink, {path_with_symlink}, {}, + AT_SYMLINK_NOFOLLOW, ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) { + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/bin", O_DIRECTORY)); + + CheckExecveat(dirfd.get(), "echo", {"echo"}, {}, AT_SYMLINK_NOFOLLOW, + ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, BasicWithCloexecFD) { + std::string path = RunfilePath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { + std::string path = RunfilePath(kExitScript); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {}, + AT_EMPTY_PATH, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) { + std::string absolute_path = RunfilePath(kExitScript); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InvalidFlags) { + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( + /*dirfd=*/-1, "", {}, {}, /*flags=*/0xFFFF, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, EINVAL); } // Priority consistent across calls to execve() @@ -522,9 +766,8 @@ TEST(GetpriorityTest, ExecveMaintainsPriority) { // Program run (priority_execve) will exit(X) where // X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio. - CheckOutput(WorkloadPath(kPriorityWorkload), - {WorkloadPath(kPriorityWorkload)}, {}, - W_EXITCODE(expected_exit_code, 0), ""); + CheckExec(RunfilePath(kPriorityWorkload), {RunfilePath(kPriorityWorkload)}, + {}, W_EXITCODE(expected_exit_code, 0), ""); } void ExecWithThread() { diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 0a3931e5a..736452b0c 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -20,6 +20,7 @@ #include <sys/types.h> #include <sys/user.h> #include <unistd.h> + #include <algorithm> #include <functional> #include <iterator> diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 8a45be12a..4f3aa81d6 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -14,6 +14,7 @@ #include <fcntl.h> #include <signal.h> +#include <sys/types.h> #include <syscall.h> #include <unistd.h> @@ -32,6 +33,7 @@ #include "test/util/eventfd_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/timer_util.h" @@ -910,8 +912,166 @@ TEST(FcntlTest, GetOwn) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), 0); + MaybeSave(); +} + +TEST(FcntlTest, GetOwnEx) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); +} + +TEST(FcntlTest, SetOwnExInvalidType) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = __pid_type(-1); + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(FcntlTest, SetOwnExInvalidTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_TID; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExInvalidPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExInvalidPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PGRP; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_TID; + EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceeds()); + + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceeds()); + + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PGRP; + EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceeds()); + + // NOTE(igudger): I don't understand why, but this is flaky on Linux. + // GetOwnExPgrp (below) does not have this issue. + SKIP_IF(!IsRunningOnGvisor()); + + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), -owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, GetOwnExTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_TID; + EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceeds()); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +TEST(FcntlTest, GetOwnExPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PID; + EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceeds()); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +TEST(FcntlTest, GetOwnExPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PGRP; + EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceeds()); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); } } // namespace diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 36efabcae..6f80bc97c 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -27,12 +27,12 @@ #include <sys/types.h> #include <sys/uio.h> #include <unistd.h> + #include <cstring> #include <string> #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -111,95 +111,6 @@ class FileTest : public ::testing::Test { int test_pipe_[2]; }; -class SocketTest : public ::testing::Test { - public: - void SetUp() override { - test_unix_stream_socket_[0] = -1; - test_unix_stream_socket_[1] = -1; - test_unix_dgram_socket_[0] = -1; - test_unix_dgram_socket_[1] = -1; - test_unix_seqpacket_socket_[0] = -1; - test_unix_seqpacket_socket_[1] = -1; - test_tcp_socket_[0] = -1; - test_tcp_socket_[1] = -1; - - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT( - socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - } - - void TearDown() override { - close(test_unix_stream_socket_[0]); - close(test_unix_stream_socket_[1]); - - close(test_unix_dgram_socket_[0]); - close(test_unix_dgram_socket_[1]); - - close(test_unix_seqpacket_socket_[0]); - close(test_unix_seqpacket_socket_[1]); - - close(test_tcp_socket_[0]); - close(test_tcp_socket_[1]); - } - - int test_unix_stream_socket_[2]; - int test_unix_dgram_socket_[2]; - int test_unix_seqpacket_socket_[2]; - int test_tcp_socket_[2]; -}; - -// MatchesStringLength checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string length strlen. -MATCHER_P(MatchesStringLength, strlen, "") { - struct iovec* iovs = arg.first; - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - offset += iovs[i].iov_len; - } - if (offset != static_cast<int>(strlen)) { - *result_listener << offset; - return false; - } - return true; -} - -// MatchesStringValue checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string value str. -MATCHER_P(MatchesStringValue, str, "") { - struct iovec* iovs = arg.first; - int len = strlen(str); - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - struct iovec iov = iovs[i]; - if (len < offset) { - *result_listener << "strlen " << len << " < offset " << offset; - return false; - } - if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { - absl::string_view iovec_string(static_cast<char*>(iov.iov_base), - iov.iov_len); - *result_listener << iovec_string << " @offset " << offset; - return false; - } - offset += iov.iov_len; - } - return true; -} - } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index b4a91455d..3ecb8db8e 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <sys/file.h> + #include <string> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index dd6e1a422..371890110 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -20,6 +20,7 @@ #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> + #include <atomic> #include <cstdlib> diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index d3e3f998c..40c80a6e1 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -239,6 +239,27 @@ TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) { EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1)); } +TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) { + constexpr int kInitialValue = 1; + std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); + + // Prevent save/restore from interrupting futex_wait, which will cause it to + // return EAGAIN instead of the expected result if futex_wait is restarted + // after we change the value of a below. + DisableSave ds; + ScopedThread thread([&] { + EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), + SyscallSucceedsWithValue(0)); + }); + absl::SleepFor(kWaiterStartupDelay); + + // Change a so that if futex_wake happens before futex_wait, the latter + // returns EAGAIN instead of hanging the test. + a.fetch_add(1); + // The Linux kernel wakes one waiter even if val is 0 or negative. + EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1)); +} + TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index fe9cfafe8..ad2dbacb8 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -23,6 +23,7 @@ #include <sys/types.h> #include <syscall.h> #include <unistd.h> + #include <map> #include <string> #include <unordered_map> diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index 4948a76f0..b0a07a064 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -25,7 +25,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -216,7 +215,8 @@ TEST_F(IoctlTest, FIOASYNCSelfTarget2) { auto mask_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - pid_t pid = getpid(); + pid_t pid = -1; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); int set = 1; diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 57e99596f..8398fc95f 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/syscalls/linux/ip_socket_test_util.h" + #include <net/if.h> #include <netinet/in.h> #include <sys/ioctl.h> #include <sys/socket.h> -#include <cstring> -#include "test/syscalls/linux/ip_socket_test_util.h" +#include <cstring> namespace gvisor { namespace testing { diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 072230d85..9cb4566db 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -26,25 +26,6 @@ namespace gvisor { namespace testing { -// Possible values of the "st" field in a /proc/net/{tcp,udp} entry. Source: -// Linux kernel, include/net/tcp_states.h. -enum { - TCP_ESTABLISHED = 1, - TCP_SYN_SENT, - TCP_SYN_RECV, - TCP_FIN_WAIT1, - TCP_FIN_WAIT2, - TCP_TIME_WAIT, - TCP_CLOSE, - TCP_CLOSE_WAIT, - TCP_LAST_ACK, - TCP_LISTEN, - TCP_CLOSING, - TCP_NEW_SYN_RECV, - - TCP_MAX_STATES -}; - // Extracts the IP address from an inet sockaddr in network byte order. uint32_t IPFromInetSockaddr(const struct sockaddr* addr); diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index 930d2b940..b77e4cbd1 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -267,6 +267,9 @@ int TestSIGPROFFairness(absl::Duration sleep) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { + // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( @@ -288,6 +291,9 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) { + // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index 08ff4052c..7fd0ea20c 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -25,7 +25,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/file_descriptor.h" #include "test/util/logging.h" #include "test/util/memory_util.h" diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc index a6e20f9c3..94aea4077 100644 --- a/test/syscalls/linux/memory_accounting.cc +++ b/test/syscalls/linux/memory_accounting.cc @@ -13,10 +13,10 @@ // limitations under the License. #include <sys/mman.h> + #include <map> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 283c21ed3..620b4f8b4 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -16,6 +16,7 @@ #include <sys/resource.h> #include <sys/syscall.h> #include <unistd.h> + #include <cerrno> #include <cstring> diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index a112316e9..1c4d9f1c7 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -28,6 +28,7 @@ #include <sys/types.h> #include <sys/wait.h> #include <unistd.h> + #include <vector> #include "gmock/gmock.h" @@ -813,23 +814,27 @@ class MMapFileTest : public MMapTest { } }; +class MMapFileParamTest + : public MMapFileTest, + public ::testing::WithParamInterface<std::tuple<int, int>> { + protected: + int prot() const { return std::get<0>(GetParam()); } + + int flags() const { return std::get<1>(GetParam()); } +}; + // MAP_POPULATE allowed. // There isn't a good way to verify it actually did anything. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, MapPopulate) { - ASSERT_THAT( - Map(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_POPULATE, fd_.get(), 0), - SyscallSucceeds()); +TEST_P(MMapFileParamTest, MapPopulate) { + ASSERT_THAT(Map(0, kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0), + SyscallSucceeds()); } // MAP_POPULATE on a short file. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, MapPopulateShort) { - ASSERT_THAT(Map(0, 2 * kPageSize, PROT_READ, MAP_PRIVATE | MAP_POPULATE, - fd_.get(), 0), - SyscallSucceeds()); +TEST_P(MMapFileParamTest, MapPopulateShort) { + ASSERT_THAT( + Map(0, 2 * kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0), + SyscallSucceeds()); } // Read contents from mapped file. @@ -900,16 +905,6 @@ TEST_F(MMapFileTest, WritePrivateOnReadOnlyFd) { reinterpret_cast<volatile char*>(addr)); } -// MAP_PRIVATE PROT_READ is not allowed on write-only FDs. -TEST_F(MMapFileTest, ReadPrivateOnWriteOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); - - uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd.get(), 0), - SyscallFailsWithErrno(EACCES)); -} - // MAP_SHARED PROT_WRITE not allowed on read-only FDs. TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { const FileDescriptor fd = @@ -921,28 +916,13 @@ TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { SyscallFailsWithErrno(EACCES)); } -// MAP_SHARED PROT_READ not allowed on write-only FDs. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, ReadSharedOnWriteOnlyFd) { +// The FD must be readable. +TEST_P(MMapFileParamTest, WriteOnlyFd) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd.get(), 0), - SyscallFailsWithErrno(EACCES)); -} - -// MAP_SHARED PROT_WRITE not allowed on write-only FDs. -// The FD must always be readable. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, WriteSharedOnWriteOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); - - uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_WRITE, MAP_SHARED, fd.get(), 0), + EXPECT_THAT(addr = Map(0, kPageSize, prot(), flags(), fd.get(), 0), SyscallFailsWithErrno(EACCES)); } @@ -1181,7 +1161,7 @@ TEST_F(MMapFileTest, ReadSharedTruncateDownThenUp) { ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), SyscallSucceeds()); - // Check that the memory contains he file data. + // Check that the memory contains the file data. EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize)); // Truncate down, then up. @@ -1370,125 +1350,68 @@ TEST_F(MMapFileTest, WritePrivate) { EqualsMemory(std::string(len, '\0'))); } -// SIGBUS raised when writing past end of file to a private mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, SigBusDeathWritePrivate) { +// SIGBUS raised when reading or writing past end of a mapped file. +TEST_P(MMapFileParamTest, SigBusDeath) { SetupGvisorDeathTest(); uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), + ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), SyscallSucceeds()); - // MMapFileTest makes a file kPageSize/2 long. The entire first page will be - // accessible. Write just beyond that. - size_t len = strlen(kFileContents); - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + kPageSize)), - ::testing::KilledBySignal(SIGBUS), ""); -} - -// SIGBUS raised when reading past end of file on a shared mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, SigBusDeathReadShared) { - SetupGvisorDeathTest(); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // MMapFileTest makes a file kPageSize/2 long. The entire first page will be - // accessible. Read just beyond that. - std::vector<char> in(kPageSize); - EXPECT_EXIT( - std::copy(reinterpret_cast<volatile char*>(addr + kPageSize), - reinterpret_cast<volatile char*>(addr + kPageSize) + kPageSize, - in.data()), - ::testing::KilledBySignal(SIGBUS), ""); + auto* start = reinterpret_cast<volatile char*>(addr + kPageSize); + + // MMapFileTest makes a file kPageSize/2 long. The entire first page should be + // accessible, but anything beyond it should not. + if (prot() & PROT_WRITE) { + // Write beyond first page. + size_t len = strlen(kFileContents); + EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, start), + ::testing::KilledBySignal(SIGBUS), ""); + } else { + // Read beyond first page. + std::vector<char> in(kPageSize); + EXPECT_EXIT(std::copy(start, start + kPageSize, in.data()), + ::testing::KilledBySignal(SIGBUS), ""); + } } -// SIGBUS raised when reading past end of file on a shared mapping. +// Tests that SIGBUS is not raised when reading or writing to a file-mapped +// page before EOF, even if part of the mapping extends beyond EOF. // -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, SigBusDeathWriteShared) { - SetupGvisorDeathTest(); - +// See b/27877699. +TEST_P(MMapFileParamTest, NoSigBusOnPagesBeforeEOF) { uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // MMapFileTest makes a file kPageSize/2 long. The entire first page will be - // accessible. Write just beyond that. - size_t len = strlen(kFileContents); - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + kPageSize)), - ::testing::KilledBySignal(SIGBUS), ""); -} - -// Tests that SIGBUS is not raised when writing to a file-mapped page before -// EOF, even if part of the mapping extends beyond EOF. -TEST_F(MMapFileTest, NoSigBusOnPagesBeforeEOF) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), + ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), SyscallSucceeds()); // The test passes if this survives. - size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr)); -} - -// Tests that SIGBUS is not raised when writing to a file-mapped page containing -// EOF, *after* the EOF for a private mapping. -TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFWritePrivate) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), - SyscallSucceeds()); - - // The test passes if this survives. (Technically addr+kPageSize/2 is already - // beyond EOF, but +1 to check for fencepost errors.) - size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1)); -} - -// Tests that SIGBUS is not raised when reading from a file-mapped page -// containing EOF, *after* the EOF for a shared mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFReadShared) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // The test passes if this survives. (Technically addr+kPageSize/2 is already - // beyond EOF, but +1 to check for fencepost errors.) auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1); size_t len = strlen(kFileContents); - std::vector<char> in(len); - std::copy(start, start + len, in.data()); + if (prot() & PROT_WRITE) { + std::copy(kFileContents, kFileContents + len, start); + } else { + std::vector<char> in(len); + std::copy(start, start + len, in.data()); + } } -// Tests that SIGBUS is not raised when writing to a file-mapped page containing -// EOF, *after* the EOF for a shared mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFWriteShared) { +// Tests that SIGBUS is not raised when reading or writing from a file-mapped +// page containing EOF, *after* the EOF. +TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) { uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), + ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), SyscallSucceeds()); // The test passes if this survives. (Technically addr+kPageSize/2 is already // beyond EOF, but +1 to check for fencepost errors.) + auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1); size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1)); + if (prot() & PROT_WRITE) { + std::copy(kFileContents, kFileContents + len, start); + } else { + std::vector<char> in(len); + std::copy(start, start + len, in.data()); + } } // Tests that reading from writable shared file-mapped pages succeeds. @@ -1732,6 +1655,15 @@ TEST(MMapNoFixtureTest, Map32Bit) { #endif // defined(__x86_64__) +INSTANTIATE_TEST_SUITE_P( + ReadWriteSharedPrivate, MMapFileParamTest, + ::testing::Combine(::testing::ValuesIn({ + PROT_READ, + PROT_WRITE, + PROT_READ | PROT_WRITE, + }), + ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE}))); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index e35be3cab..a3e9745cf 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -18,6 +18,7 @@ #include <sys/mount.h> #include <sys/stat.h> #include <unistd.h> + #include <functional> #include <memory> #include <string> diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 2b1df52ce..267ae19f6 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -73,6 +73,28 @@ class OpenTest : public FileTest { const std::string test_data_ = "hello world\n"; }; +TEST_F(OpenTest, OTrunc) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, OTruncAndReadOnlyDir) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, OTruncAndReadOnlyFile) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); + const FileDescriptor existing = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666)); + const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE( + Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); +} + TEST_F(OpenTest, ReadOnly) { char buf; const FileDescriptor ro_file = diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index e5a85ef9d..431733dbe 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -88,6 +88,30 @@ TEST(CreateTest, CreateExclusively) { SyscallFailsWithErrno(EEXIST)); } +TEST(CreateTeast, CreatWithOTrunc) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); + int dirfd; + ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + SyscallSucceeds()); + ASSERT_THAT(close(dirfd), SyscallSucceeds()); +} + TEST(CreateTest, CreateFailsOnUnpermittedDir) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override directory permissions. diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 37b4e6575..92ae55eec 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -61,6 +61,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + constexpr char kMessage[] = "soweoneul malhaebwa"; constexpr in_port_t kPort = 0x409c; // htons(40000) @@ -83,17 +86,14 @@ void SendUDPMessage(int sock) { // Send an IP packet and make sure ETH_P_<something else> doesn't pick it up. TEST(BasicCookedPacketTest, WrongType) { - // (b/129292371): Remove once we support packet sockets. - SKIP_IF(IsRunningOnGvisor()); - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP), SyscallFailsWithErrno(EPERM)); GTEST_SKIP(); } - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP)); + FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP))); // Let's use a simple IP payload: a UDP datagram. FileDescriptor udp_sock = @@ -124,23 +124,31 @@ class CookedPacketTest : public ::testing::TestWithParam<int> { }; void CookedPacketTest::SetUp() { - // (b/129292371): Remove once we support packet sockets. - SKIP_IF(IsRunningOnGvisor()); - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), SyscallFailsWithErrno(EPERM)); GTEST_SKIP(); } + if (!IsRunningOnGvisor()) { + FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + char enabled; + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), SyscallSucceeds()); } void CookedPacketTest::TearDown() { - // (b/129292371): Remove once we support packet sockets. - SKIP_IF(IsRunningOnGvisor()); - // TearDown will be run even if we skip the test. if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { EXPECT_THAT(close(socket_), SyscallSucceeds()); @@ -177,13 +185,16 @@ TEST_P(CookedPacketTest, Receive) { ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); - ASSERT_EQ(src_len, sizeof(src)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); - EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -213,6 +224,9 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { + // TODO(b/129292371): Remove once we support packet socket writing. + SKIP_IF(IsRunningOnGvisor()); + // Let's send a UDP packet and receive it using a regular UDP socket. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 6491453b6..d258d353c 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -26,6 +26,7 @@ #include <sys/types.h> #include <unistd.h> +#include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/base/internal/endian.h" #include "test/syscalls/linux/socket_test_util.h" @@ -61,6 +62,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + constexpr char kMessage[] = "soweoneul malhaebwa"; constexpr in_port_t kPort = 0x409c; // htons(40000) @@ -97,9 +101,6 @@ class RawPacketTest : public ::testing::TestWithParam<int> { }; void RawPacketTest::SetUp() { - // (b/129292371): Remove once we support packet sockets. - SKIP_IF(IsRunningOnGvisor()); - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())), SyscallFailsWithErrno(EPERM)); @@ -125,9 +126,6 @@ void RawPacketTest::SetUp() { } void RawPacketTest::TearDown() { - // (b/129292371): Remove once we support packet sockets. - SKIP_IF(IsRunningOnGvisor()); - // TearDown will be run even if we skip the test. if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { EXPECT_THAT(close(socket_), SyscallSucceeds()); @@ -164,16 +162,16 @@ TEST_P(RawPacketTest, Receive) { ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); - // sizeof(src) is the size of a struct sockaddr_ll. sockaddr_ll ends with an 8 - // byte physical address field, but ethernet (MAC) addresses only use 6 bytes. - // Thus src_len should get modified to be 2 less than the size of sockaddr_ll. - ASSERT_EQ(src_len, sizeof(src) - 2); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); - EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -214,6 +212,9 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { + // TODO(b/129292371): Remove once we support packet socket writing. + SKIP_IF(IsRunningOnGvisor()); + // Let's send a UDP packet and receive it using a regular UDP socket. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); @@ -309,7 +310,7 @@ TEST_P(RawPacketTest, Send) { } INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, - ::testing::Values(ETH_P_IP /*, ETH_P_ALL*/)); + ::testing::Values(ETH_P_IP, ETH_P_ALL)); } // namespace diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 10e2a6dfc..ac9b21b24 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -20,7 +20,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/notification.h" #include "absl/time/clock.h" @@ -213,6 +212,20 @@ TEST(Pipe2Test, BadOptions) { EXPECT_THAT(pipe2(fds, 0xDEAD), SyscallFailsWithErrno(EINVAL)); } +// Tests that opening named pipes with O_TRUNC shouldn't cause an error, but +// calls to (f)truncate should. +TEST(NamedPipeTest, Truncate) { + const std::string tmp_path = NewTempAbsPath(); + SKIP_IF(mkfifo(tmp_path.c_str(), 0644) != 0); + + ASSERT_THAT(open(tmp_path.c_str(), O_NONBLOCK | O_RDONLY), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmp_path.c_str(), O_RDWR | O_NONBLOCK | O_TRUNC)); + + ASSERT_THAT(truncate(tmp_path.c_str(), 0), SyscallFailsWithErrno(EINVAL)); + ASSERT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); +} + TEST_P(PipeTest, Seek) { SKIP_IF(!CreateBlocking()); diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 5e3eb1735..2cecf2e5f 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -20,7 +20,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index eebd129f2..f7ea44054 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -22,7 +22,6 @@ #include <string> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc index aac960130..c9246367d 100644 --- a/test/syscalls/linux/preadv2.cc +++ b/test/syscalls/linux/preadv2.cc @@ -21,7 +21,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/memory/memory.h" #include "test/syscalls/linux/file_base.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index e4c030bbb..8cf08991b 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -37,6 +37,7 @@ #include <map> #include <memory> #include <ostream> +#include <regex> #include <string> #include <unordered_set> #include <utility> @@ -51,6 +52,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/capability_util.h" @@ -183,7 +185,8 @@ PosixError WithSubprocess(SubprocessCallback const& running, siginfo_t info; // Wait until the child process has exited (WEXITED flag) but don't // reap the child (WNOWAIT flag). - waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED); + EXPECT_THAT(waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED), + SyscallSucceeds()); if (zombied) { // Arg of "Z" refers to a Zombied Process. @@ -1987,6 +1990,44 @@ TEST(Proc, GetdentsEnoent) { SyscallFailsWithErrno(ENOENT)); } +void CheckSyscwFromIOFile(const std::string& path, const std::string& regex) { + std::string output; + ASSERT_NO_ERRNO(GetContents(path, &output)); + ASSERT_THAT(output, ContainsRegex(absl::StrCat("syscw:\\s+", regex, "\n"))); +} + +// Checks that there is variable accounting of IO between threads/tasks. +TEST(Proc, PidTidIOAccounting) { + absl::Notification notification; + + // Run a thread with a bunch of writes. Check that io account records exactly + // the number of write calls. File open/close is there to prevent buffering. + ScopedThread writer([¬ification] { + const int num_writes = 100; + for (int i = 0; i < num_writes; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + ASSERT_NO_ERRNO(SetContents(path.path(), "a")); + } + notification.Notify(); + const std::string& writer_dir = + absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); + + CheckSyscwFromIOFile(writer_dir, std::to_string(num_writes)); + }); + + // Run a thread and do no writes. Check that no writes are recorded. + ScopedThread noop([¬ification] { + notification.WaitForNotification(); + const std::string& noop_dir = + absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); + + CheckSyscwFromIOFile(noop_dir, "0"); + }); + + writer.Join(); + noop.Join(); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index efdaf202b..65bad06d4 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -12,8 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> +#include <errno.h> +#include <netinet/in.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/syscall.h> +#include <sys/types.h> + #include "gtest/gtest.h" -#include "gtest/gtest.h" +#include "absl/strings/str_split.h" +#include "absl/time/clock.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" @@ -57,6 +67,265 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, + const std::string &type, + const std::string &item) { + std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n'); + + // /proc/net/snmp prints a line of headers followed by a line of metrics. + // Only search the headers. + for (unsigned i = 0; i < snmp_vec.size(); i = i + 2) { + if (!absl::StartsWith(snmp_vec[i], type)) continue; + + std::vector<std::string> fields = + absl::StrSplit(snmp_vec[i], ' ', absl::SkipWhitespace()); + + EXPECT_TRUE((i + 1) < snmp_vec.size()); + std::vector<std::string> values = + absl::StrSplit(snmp_vec[i + 1], ' ', absl::SkipWhitespace()); + + EXPECT_TRUE(!fields.empty() && fields.size() == values.size()); + + // Metrics start at the first index. + for (unsigned j = 1; j < fields.size(); j++) { + if (fields[j] == item) { + uint64_t val; + if (!absl::SimpleAtoi(values[j], &val)) { + return PosixError(EINVAL, + absl::StrCat("field is not a number: ", values[j])); + } + + return val; + } + } + } + // We should never get here. + return PosixError( + EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp)); +} + +TEST(ProcNetSnmp, TcpReset_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldAttemptFails; + uint64_t oldActiveOpens; + uint64_t oldOutRsts; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + oldOutRsts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); + oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(1234), + }; + + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(connect(s.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallFailsWithErrno(ECONNREFUSED)); + + uint64_t newAttemptFails; + uint64_t newActiveOpens; + uint64_t newOutRsts; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + newOutRsts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); + newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); + + EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); + EXPECT_EQ(oldOutRsts, newOutRsts - 1); + EXPECT_EQ(oldAttemptFails, newAttemptFails - 1); +} + +TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldEstabResets; + uint64_t oldActiveOpens; + uint64_t oldPassiveOpens; + uint64_t oldCurrEstab; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); + oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); + + FileDescriptor s_listen = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = 0, + }; + + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(bind(s_listen.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = sizeof(sin); + ASSERT_THAT( + getsockname(s_listen.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + SyscallSucceeds()); + + FileDescriptor s_connect = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_THAT(connect(s_connect.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + + auto s_accept = + ASSERT_NO_ERRNO_AND_VALUE(Accept(s_listen.get(), nullptr, nullptr)); + + uint64_t newEstabResets; + uint64_t newActiveOpens; + uint64_t newPassiveOpens; + uint64_t newCurrEstab; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); + newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + + EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); + EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1); + EXPECT_EQ(oldCurrEstab, newCurrEstab - 2); + + // Send 1 byte from client to server. + ASSERT_THAT(send(s_connect.get(), "a", 1, 0), SyscallSucceedsWithValue(1)); + + constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. + + // Wait until server-side fd sees the data on its side but don't read it. + struct pollfd poll_fd = {s_accept.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close server-side fd without reading the data which leads to a RST + // packet sent to client side. + s_accept.reset(-1); + + // Wait until client-side fd sees RST packet. + struct pollfd poll_fd1 = {s_connect.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd1, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close client-side fd. + s_connect.reset(-1); + + // Wait until the process of the netstack. + absl::SleepFor(absl::Seconds(1)); + + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + newEstabResets = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); + + EXPECT_EQ(oldCurrEstab, newCurrEstab); + EXPECT_EQ(oldEstabResets, newEstabResets - 2); +} + +TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldOutDatagrams; + uint64_t oldNoPorts; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + oldNoPorts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(4444), + }; + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceedsWithValue(1)); + + uint64_t newOutDatagrams; + uint64_t newNoPorts; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + newNoPorts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); + + EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); + EXPECT_EQ(oldNoPorts, newNoPorts - 1); +} + +TEST(ProcNetSnmp, UdpIn) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + const DisableSave ds; + + uint64_t oldOutDatagrams; + uint64_t oldInDatagrams; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); + + std::cerr << "snmp: " << std::endl << snmp << std::endl; + FileDescriptor server = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(0), + }; + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(bind(server.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + // Get the port bound by the server socket. + socklen_t addrlen = sizeof(sin); + ASSERT_THAT( + getsockname(server.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + SyscallSucceeds()); + + FileDescriptor client = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + ASSERT_THAT( + sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceedsWithValue(1)); + + char buf[128]; + ASSERT_THAT(recvfrom(server.get(), buf, sizeof(buf), 0, NULL, NULL), + SyscallSucceedsWithValue(1)); + + uint64_t newOutDatagrams; + uint64_t newInDatagrams; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + std::cerr << "new snmp: " << std::endl << snmp << std::endl; + newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); + + EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); + EXPECT_EQ(oldInDatagrams, newInDatagrams - 1); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index f61795592..5b6e3e3cd 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <netinet/tcp.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc index 369df8e0e..786b4b4af 100644 --- a/test/syscalls/linux/proc_net_udp.cc +++ b/test/syscalls/linux/proc_net_udp.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <netinet/tcp.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 83dbd1364..66db0acaa 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 99a0df235..dafe64d20 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -70,6 +70,8 @@ constexpr absl::Duration kTimeout = absl::Seconds(20); // The maximum line size in bytes returned per read from a pty file. constexpr int kMaxLineSize = 4096; +constexpr char kMasterPath[] = "/dev/ptmx"; + // glibc defines its own, different, version of struct termios. We care about // what the kernel does, not glibc. #define KERNEL_NCCS 19 @@ -376,9 +378,25 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, return PosixError(ETIMEDOUT, "Poll timed out"); } +TEST(PtyTrunc, Truncate) { + // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to + // (f)truncate should. + FileDescriptor master = + ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); + int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + std::string spath = absl::StrCat("/dev/pts/", n); + FileDescriptor slave = + ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); + + EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); +} + TEST(BasicPtyTest, StatUnopenedMaster) { struct stat s; - ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds()); + ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds()); EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor)); EXPECT_EQ(s.st_size, 0); diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index e1603fc2d..b48fe540d 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index f6a0fc96c..1dbc0d6df 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -21,7 +21,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 8bcaba6f1..3de898df7 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -129,7 +129,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) { EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } -// + // Send and receive an ICMP packet. TEST_F(RawSocketICMPTest, SendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 4430fa3c2..2633ba31b 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -14,6 +14,7 @@ #include <fcntl.h> #include <unistd.h> + #include <vector> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index f327ec3a9..4069cbc7e 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc index 35d2dd9e3..491d5f40f 100644 --- a/test/syscalls/linux/readv_common.cc +++ b/test/syscalls/linux/readv_common.cc @@ -19,13 +19,53 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +// MatchesStringLength checks that a tuple argument of (struct iovec *, int) +// corresponding to an iovec array and its length, contains data that matches +// the string length strlen. +MATCHER_P(MatchesStringLength, strlen, "") { + struct iovec* iovs = arg.first; + int niov = arg.second; + int offset = 0; + for (int i = 0; i < niov; i++) { + offset += iovs[i].iov_len; + } + if (offset != static_cast<int>(strlen)) { + *result_listener << offset; + return false; + } + return true; +} + +// MatchesStringValue checks that a tuple argument of (struct iovec *, int) +// corresponding to an iovec array and its length, contains data that matches +// the string value str. +MATCHER_P(MatchesStringValue, str, "") { + struct iovec* iovs = arg.first; + int len = strlen(str); + int niov = arg.second; + int offset = 0; + for (int i = 0; i < niov; i++) { + struct iovec iov = iovs[i]; + if (len < offset) { + *result_listener << "strlen " << len << " < offset " << offset; + return false; + } + if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { + absl::string_view iovec_string(static_cast<char*>(iov.iov_base), + iov.iov_len); + *result_listener << iovec_string << " @offset " << offset; + return false; + } + offset += iov.iov_len; + } + return true; +} + extern const char kReadvTestData[] = "127.0.0.1 localhost" "" diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc index 3c315cc02..dd6fb7008 100644 --- a/test/syscalls/linux/readv_socket.cc +++ b/test/syscalls/linux/readv_socket.cc @@ -19,8 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/test_util.h" @@ -29,9 +27,30 @@ namespace testing { namespace { -class ReadvSocketTest : public SocketTest { +class ReadvSocketTest : public ::testing::Test { + public: void SetUp() override { - SocketTest::SetUp(); + test_unix_stream_socket_[0] = -1; + test_unix_stream_socket_[1] = -1; + test_unix_dgram_socket_[0] = -1; + test_unix_dgram_socket_[1] = -1; + test_unix_seqpacket_socket_[0] = -1; + test_unix_seqpacket_socket_[1] = -1; + + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT( + socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT( write(test_unix_stream_socket_[1], kReadvTestData, kReadvTestDataSize), SyscallSucceedsWithValue(kReadvTestDataSize)); @@ -41,11 +60,22 @@ class ReadvSocketTest : public SocketTest { ASSERT_THAT(write(test_unix_seqpacket_socket_[1], kReadvTestData, kReadvTestDataSize), SyscallSucceedsWithValue(kReadvTestDataSize)); - // FIXME(b/69821513): Enable when possible. - // ASSERT_THAT(write(test_tcp_socket_[1], kReadvTestData, - // kReadvTestDataSize), - // SyscallSucceedsWithValue(kReadvTestDataSize)); } + + void TearDown() override { + close(test_unix_stream_socket_[0]); + close(test_unix_stream_socket_[1]); + + close(test_unix_dgram_socket_[0]); + close(test_unix_dgram_socket_[1]); + + close(test_unix_seqpacket_socket_[0]); + close(test_unix_seqpacket_socket_[1]); + } + + int test_unix_stream_socket_[2]; + int test_unix_dgram_socket_[2]; + int test_unix_seqpacket_socket_[2]; }; TEST_F(ReadvSocketTest, ReadOneBufferPerByte_StreamSocket) { diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index c9d76c2e2..833c0dc4f 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -14,10 +14,10 @@ #include <fcntl.h> #include <stdio.h> + #include <string> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/util/capability_util.h" #include "test/util/cleanup.h" diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc new file mode 100644 index 000000000..106c045e3 --- /dev/null +++ b/test/syscalls/linux/rseq.cc @@ -0,0 +1,198 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <signal.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/rseq/test.h" +#include "test/syscalls/linux/rseq/uapi.h" +#include "test/util/logging.h" +#include "test/util/multiprocess_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Syscall test for rseq (restartable sequences). +// +// We must be very careful about how these tests are written. Each thread may +// only have one struct rseq registration, which may be done automatically at +// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does +// not do so). +// +// Testing of rseq is thus done primarily in a child process with no +// registration. This means exec'ing a nostdlib binary, as rseq registration can +// only be cleared by execve (or knowing the old rseq address), and glibc (based +// on the current unmerged patches) register rseq before calling main()). + +int RSeq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { + return syscall(kRseqSyscall, rseq, rseq_len, flags, sig); +} + +// Returns true if this kernel supports the rseq syscall. +PosixErrorOr<bool> RSeqSupported() { + // We have to be careful here, there are three possible cases: + // + // 1. rseq is not supported -> ENOSYS + // 2. rseq is supported and not registered -> success, but we should + // unregister. + // 3. rseq is supported and registered -> EINVAL (most likely). + + // The only validation done on new registrations is that rseq is aligned and + // writable. + rseq rseq = {}; + int ret = RSeq(&rseq, sizeof(rseq), 0, 0); + if (ret == 0) { + // Successfully registered, rseq is supported. Unregister. + ret = RSeq(&rseq, sizeof(rseq), kRseqFlagUnregister, 0); + if (ret != 0) { + return PosixError(errno); + } + return true; + } + + switch (errno) { + case ENOSYS: + // Not supported. + return false; + case EINVAL: + // Supported, but already registered. EINVAL returned because we provided + // a different address. + return true; + default: + // Unknown error. + return PosixError(errno); + } +} + +constexpr char kRseqBinary[] = "test/syscalls/linux/rseq/rseq"; + +void RunChildTest(std::string test_case, int want_status) { + std::string path = RunfilePath(kRseqBinary); + + pid_t child_pid = -1; + int execve_errno = 0; + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path, test_case}, {}, &child_pid, &execve_errno)); + + ASSERT_GT(child_pid, 0); + ASSERT_EQ(execve_errno, 0); + + int status = 0; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); + ASSERT_EQ(status, want_status); +} + +// Test that rseq must be aligned. +TEST(RseqTest, Unaligned) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnaligned, 0); +} + +// Sanity test that registration works. +TEST(RseqTest, Register) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestRegister, 0); +} + +// Registration can't be done twice. +TEST(RseqTest, DoubleRegister) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestDoubleRegister, 0); +} + +// Registration can be done again after unregister. +TEST(RseqTest, RegisterUnregister) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestRegisterUnregister, 0); +} + +// The pointer to rseq must match on register/unregister. +TEST(RseqTest, UnregisterDifferentPtr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnregisterDifferentPtr, 0); +} + +// The signature must match on register/unregister. +TEST(RseqTest, UnregisterDifferentSignature) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnregisterDifferentSignature, 0); +} + +// The CPU ID is initialized. +TEST(RseqTest, CPU) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestCPU, 0); +} + +// Critical section is eventually aborted. +TEST(RseqTest, Abort) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbort, 0); +} + +// Abort may be before the critical section. +TEST(RseqTest, AbortBefore) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortBefore, 0); +} + +// Signature must match. +TEST(RseqTest, AbortSignature) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortSignature, SIGSEGV); +} + +// Abort must not be in the critical section. +TEST(RseqTest, AbortPreCommit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortPreCommit, SIGSEGV); +} + +// rseq.rseq_cs is cleared on abort. +TEST(RseqTest, AbortClearsCS) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortClearsCS, 0); +} + +// rseq.rseq_cs is cleared on abort outside of critical section. +TEST(RseqTest, InvalidAbortClearsCS) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestInvalidAbortClearsCS, 0); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD new file mode 100644 index 000000000..5cfe4e56f --- /dev/null +++ b/test/syscalls/linux/rseq/BUILD @@ -0,0 +1,59 @@ +# This package contains a standalone rseq test binary. This binary must not +# depend on libc, which might use rseq itself. + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") +load("@rules_cc//cc:defs.bzl", "cc_library") + +package(licenses = ["notice"]) + +genrule( + name = "rseq_binary", + srcs = [ + "critical.h", + "critical.S", + "rseq.cc", + "syscalls.h", + "start.S", + "test.h", + "types.h", + "uapi.h", + ], + outs = ["rseq"], + cmd = " ".join([ + "$(CC)", + "$(CC_FLAGS) ", + "-I.", + "-Wall", + "-Werror", + "-O2", + "-std=c++17", + "-static", + "-nostdlib", + "-ffreestanding", + "-o", + "$(location rseq)", + "$(location critical.S)", + "$(location rseq.cc)", + "$(location start.S)", + ]), + toolchains = [ + ":no_pie_cc_flags", + "@bazel_tools//tools/cpp:current_cc_toolchain", + ], + visibility = ["//:sandbox"], +) + +cc_flags_supplier( + name = "no_pie_cc_flags", + features = ["-pie"], +) + +cc_library( + name = "lib", + testonly = 1, + hdrs = [ + "test.h", + "uapi.h", + ], + visibility = ["//:sandbox"], +) diff --git a/test/syscalls/linux/rseq/critical.S b/test/syscalls/linux/rseq/critical.S new file mode 100644 index 000000000..8c0687e6d --- /dev/null +++ b/test/syscalls/linux/rseq/critical.S @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Restartable sequences critical sections. + +// Loops continuously until aborted. +// +// void rseq_loop(struct rseq* r, struct rseq_cs* cs) + + .text + .globl rseq_loop + .type rseq_loop, @function + +rseq_loop: + jmp begin + + // Abort block before the critical section. + // Abort signature is 4 nops for simplicity. + .byte 0x90, 0x90, 0x90, 0x90 + .globl rseq_loop_early_abort +rseq_loop_early_abort: + ret + +begin: + // r->rseq_cs = cs + movq %rsi, 8(%rdi) + + // N.B. rseq_cs will be cleared by any preempt, even outside the critical + // section. Thus it must be set in or immediately before the critical section + // to ensure it is not cleared before the section begins. + .globl rseq_loop_start +rseq_loop_start: + jmp rseq_loop_start + + // "Pre-commit": extra instructions inside the critical section. These are + // used as the abort point in TestAbortPreCommit, which is not valid. + .globl rseq_loop_pre_commit +rseq_loop_pre_commit: + // Extra abort signature + nop for TestAbortPostCommit. + .byte 0x90, 0x90, 0x90, 0x90 + nop + + // "Post-commit": never reached in this case. + .globl rseq_loop_post_commit +rseq_loop_post_commit: + + // Abort signature is 4 nops for simplicity. + .byte 0x90, 0x90, 0x90, 0x90 + + .globl rseq_loop_abort +rseq_loop_abort: + ret + + .size rseq_loop,.-rseq_loop + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/critical.h b/test/syscalls/linux/rseq/critical.h new file mode 100644 index 000000000..ac987a25e --- /dev/null +++ b/test/syscalls/linux/rseq/critical.h @@ -0,0 +1,39 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ + +#include "test/syscalls/linux/rseq/types.h" +#include "test/syscalls/linux/rseq/uapi.h" + +constexpr uint32_t kRseqSignature = 0x90909090; + +extern "C" { + +extern void rseq_loop(struct rseq* r, struct rseq_cs* cs); +extern void* rseq_loop_early_abort; +extern void* rseq_loop_start; +extern void* rseq_loop_pre_commit; +extern void* rseq_loop_post_commit; +extern void* rseq_loop_abort; + +extern int rseq_getpid(struct rseq* r, struct rseq_cs* cs); +extern void* rseq_getpid_start; +extern void* rseq_getpid_post_commit; +extern void* rseq_getpid_abort; + +} // extern "C" + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc new file mode 100644 index 000000000..f036db26d --- /dev/null +++ b/test/syscalls/linux/rseq/rseq.cc @@ -0,0 +1,366 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/rseq/critical.h" +#include "test/syscalls/linux/rseq/syscalls.h" +#include "test/syscalls/linux/rseq/test.h" +#include "test/syscalls/linux/rseq/types.h" +#include "test/syscalls/linux/rseq/uapi.h" + +namespace gvisor { +namespace testing { + +extern "C" int main(int argc, char** argv, char** envp); + +// Standalone initialization before calling main(). +extern "C" void __init(uintptr_t* sp) { + int argc = sp[0]; + char** argv = reinterpret_cast<char**>(&sp[1]); + char** envp = &argv[argc + 1]; + + // Call main() and exit. + sys_exit_group(main(argc, argv, envp)); + + // sys_exit_group does not return +} + +int strcmp(const char* s1, const char* s2) { + const unsigned char* p1 = reinterpret_cast<const unsigned char*>(s1); + const unsigned char* p2 = reinterpret_cast<const unsigned char*>(s2); + + while (*p1 == *p2) { + if (!*p1) { + return 0; + } + ++p1; + ++p2; + } + return static_cast<int>(*p1) - static_cast<int>(*p2); +} + +int sys_rseq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { + return raw_syscall(kRseqSyscall, rseq, rseq_len, flags, sig); +} + +// Test that rseq must be aligned. +int TestUnaligned() { + constexpr uintptr_t kRequiredAlignment = alignof(rseq); + + char buf[2 * kRequiredAlignment] = {}; + uintptr_t ptr = reinterpret_cast<uintptr_t>(&buf[0]); + if ((ptr & (kRequiredAlignment - 1)) == 0) { + // buf is already aligned. Misalign it. + ptr++; + } + + int ret = sys_rseq(reinterpret_cast<rseq*>(ptr), sizeof(rseq), 0, 0); + if (sys_errno(ret) != EINVAL) { + return 1; + } + return 0; +} + +// Sanity test that registration works. +int TestRegister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + return 0; +}; + +// Registration can't be done twice. +int TestDoubleRegister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { + return 1; + } + + return 0; +}; + +// Registration can be done again after unregister. +int TestRegisterUnregister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); + sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + return 0; +}; + +// The pointer to rseq must match on register/unregister. +int TestUnregisterDifferentPtr() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + struct rseq r2 = {}; + if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); + sys_errno(ret) != EINVAL) { + return 1; + } + + return 0; +}; + +// The signature must match on register/unregister. +int TestUnregisterDifferentSignature() { + constexpr int kSignature = 0; + + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); + sys_errno(ret) != EPERM) { + return 1; + } + + return 0; +}; + +// The CPU ID is initialized. +int TestCPU() { + struct rseq r = {}; + r.cpu_id = kRseqCPUIDUninitialized; + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (__atomic_load_n(&r.cpu_id, __ATOMIC_RELAXED) < 0) { + return 1; + } + if (__atomic_load_n(&r.cpu_id_start, __ATOMIC_RELAXED) < 0) { + return 1; + } + + return 0; +}; + +// Critical section is eventually aborted. +int TestAbort() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + return 0; +}; + +// Abort may be before the critical section. +int TestAbortBefore() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_early_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + return 0; +}; + +// Signature must match. +int TestAbortSignature() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. This should SIGSEGV on abort. + rseq_loop(&r, &cs); + + return 1; +}; + +// Abort must not be in the critical section. +int TestAbortPreCommit() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_pre_commit); + + // Loops until abort. This should SIGSEGV on abort. + rseq_loop(&r, &cs); + + return 1; +}; + +// rseq.rseq_cs is cleared on abort. +int TestAbortClearsCS() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + if (__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { + return 1; + } + + return 0; +}; + +// rseq.rseq_cs is cleared on abort outside of critical section. +int TestInvalidAbortClearsCS() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + __atomic_store_n(&r.rseq_cs, &cs, __ATOMIC_RELAXED); + + // When the next abort condition occurs, the kernel will clear cs once it + // determines we aren't in the critical section. + while (1) { + if (!__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { + break; + } + } + + return 0; +}; + +// Exit codes: +// 0 - Pass +// 1 - Fail +// 2 - Missing argument +// 3 - Unknown test case +extern "C" int main(int argc, char** argv, char** envp) { + if (argc != 2) { + // Usage: rseq <test case> + return 2; + } + + if (strcmp(argv[1], kRseqTestUnaligned) == 0) { + return TestUnaligned(); + } + if (strcmp(argv[1], kRseqTestRegister) == 0) { + return TestRegister(); + } + if (strcmp(argv[1], kRseqTestDoubleRegister) == 0) { + return TestDoubleRegister(); + } + if (strcmp(argv[1], kRseqTestRegisterUnregister) == 0) { + return TestRegisterUnregister(); + } + if (strcmp(argv[1], kRseqTestUnregisterDifferentPtr) == 0) { + return TestUnregisterDifferentPtr(); + } + if (strcmp(argv[1], kRseqTestUnregisterDifferentSignature) == 0) { + return TestUnregisterDifferentSignature(); + } + if (strcmp(argv[1], kRseqTestCPU) == 0) { + return TestCPU(); + } + if (strcmp(argv[1], kRseqTestAbort) == 0) { + return TestAbort(); + } + if (strcmp(argv[1], kRseqTestAbortBefore) == 0) { + return TestAbortBefore(); + } + if (strcmp(argv[1], kRseqTestAbortSignature) == 0) { + return TestAbortSignature(); + } + if (strcmp(argv[1], kRseqTestAbortPreCommit) == 0) { + return TestAbortPreCommit(); + } + if (strcmp(argv[1], kRseqTestAbortClearsCS) == 0) { + return TestAbortClearsCS(); + } + if (strcmp(argv[1], kRseqTestInvalidAbortClearsCS) == 0) { + return TestInvalidAbortClearsCS(); + } + + return 3; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rseq/start.S b/test/syscalls/linux/rseq/start.S new file mode 100644 index 000000000..b9611b276 --- /dev/null +++ b/test/syscalls/linux/rseq/start.S @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + + .text + .align 4 + .type _start,@function + .globl _start + +_start: + movq %rsp,%rdi + call __init + hlt + + .size _start,.-_start + .section .note.GNU-stack,"",@progbits + + .text + .globl raw_syscall + .type raw_syscall, @function + +raw_syscall: + mov %rdi,%rax // syscall # + mov %rsi,%rdi // arg0 + mov %rdx,%rsi // arg1 + mov %rcx,%rdx // arg2 + mov %r8,%r10 // arg3 (goes in r10 instead of rcx for system calls) + mov %r9,%r8 // arg4 + mov 0x8(%rsp),%r9 // arg5 + syscall + ret + + .size raw_syscall,.-raw_syscall + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h new file mode 100644 index 000000000..e5299c188 --- /dev/null +++ b/test/syscalls/linux/rseq/syscalls.h @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ + +#include "test/syscalls/linux/rseq/types.h" + +#ifdef __x86_64__ +// Syscall numbers. +constexpr int kGetpid = 39; +constexpr int kExitGroup = 231; +#else +#error "Unknown architecture" +#endif + +namespace gvisor { +namespace testing { + +// Standalone system call interfaces. +// Note that these are all "raw" system call interfaces which encode +// errors by setting the return value to a small negative number. +// Use sys_errno() to check system call return values for errors. + +// Maximum Linux error number. +constexpr int kMaxErrno = 4095; + +// Errno values. +#define EPERM 1 +#define EFAULT 14 +#define EBUSY 16 +#define EINVAL 22 + +// Get the error number from a raw system call return value. +// Returns a positive error number or 0 if there was no error. +static inline int sys_errno(uintptr_t rval) { + if (rval >= static_cast<uintptr_t>(-kMaxErrno)) { + return -static_cast<int>(rval); + } + return 0; +} + +extern "C" uintptr_t raw_syscall(int number, ...); + +static inline void sys_exit_group(int status) { + raw_syscall(kExitGroup, status); +} +static inline int sys_getpid() { + return static_cast<int>(raw_syscall(kGetpid)); +} + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h new file mode 100644 index 000000000..3b7bb74b1 --- /dev/null +++ b/test/syscalls/linux/rseq/test.h @@ -0,0 +1,43 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ + +namespace gvisor { +namespace testing { + +// Test cases supported by rseq binary. + +inline constexpr char kRseqTestUnaligned[] = "unaligned"; +inline constexpr char kRseqTestRegister[] = "register"; +inline constexpr char kRseqTestDoubleRegister[] = "double-register"; +inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; +inline constexpr char kRseqTestUnregisterDifferentPtr[] = + "unregister-different-ptr"; +inline constexpr char kRseqTestUnregisterDifferentSignature[] = + "unregister-different-signature"; +inline constexpr char kRseqTestCPU[] = "cpu"; +inline constexpr char kRseqTestAbort[] = "abort"; +inline constexpr char kRseqTestAbortBefore[] = "abort-before"; +inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; +inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; +inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; +inline constexpr char kRseqTestInvalidAbortClearsCS[] = + "invalid-abort-clears-cs"; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ diff --git a/test/syscalls/linux/rseq/types.h b/test/syscalls/linux/rseq/types.h new file mode 100644 index 000000000..b6afe9817 --- /dev/null +++ b/test/syscalls/linux/rseq/types.h @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ + +using size_t = __SIZE_TYPE__; +using uintptr_t = __UINTPTR_TYPE__; + +using uint8_t = __UINT8_TYPE__; +using uint16_t = __UINT16_TYPE__; +using uint32_t = __UINT32_TYPE__; +using uint64_t = __UINT64_TYPE__; + +using int8_t = __INT8_TYPE__; +using int16_t = __INT16_TYPE__; +using int32_t = __INT32_TYPE__; +using int64_t = __INT64_TYPE__; + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h new file mode 100644 index 000000000..e3ff0579a --- /dev/null +++ b/test/syscalls/linux/rseq/uapi.h @@ -0,0 +1,54 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ + +// User-kernel ABI for restartable sequences. + +// Standard types. +// +// N.B. This header will be included in targets that do have the standard +// library, so we can't shadow the standard type names. +using __u32 = __UINT32_TYPE__; +using __u64 = __UINT64_TYPE__; + +#ifdef __x86_64__ +// Syscall numbers. +constexpr int kRseqSyscall = 334; +#else +#error "Unknown architecture" +#endif // __x86_64__ + +struct rseq_cs { + __u32 version; + __u32 flags; + __u64 start_ip; + __u64 post_commit_offset; + __u64 abort_ip; +} __attribute__((aligned(4 * sizeof(__u64)))); + +// N.B. alignment is enforced by the kernel. +struct rseq { + __u32 cpu_id_start; + __u32 cpu_id; + struct rseq_cs* rseq_cs; + __u32 flags; +} __attribute__((aligned(4 * sizeof(__u64)))); + +constexpr int kRseqFlagUnregister = 1 << 0; + +constexpr int kRseqCPUIDUninitialized = -1; + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc index e77586852..7e41fe7d8 100644 --- a/test/syscalls/linux/seccomp.cc +++ b/test/syscalls/linux/seccomp.cc @@ -25,6 +25,7 @@ #include <time.h> #include <ucontext.h> #include <unistd.h> + #include <atomic> #include "gmock/gmock.h" diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index 88c010aec..424e2a67f 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -16,12 +16,12 @@ #include <sys/resource.h> #include <sys/select.h> #include <sys/time.h> + #include <climits> #include <csignal> #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/time.h" #include "test/syscalls/linux/base_poll_test.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 40c57f543..e9b131ca9 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -447,9 +447,8 @@ TEST(SemaphoreTest, SemCtlGetPidFork) { const pid_t child_pid = fork(); if (child_pid == 0) { - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); - ASSERT_THAT(semctl(sem.get(), 0, GETPID), - SyscallSucceedsWithValue(getpid())); + TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0); + TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid()); _exit(0); } diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index 1c56540bc..3331288b7 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -185,7 +185,7 @@ TEST_P(SendFileTest, Shutdown) { // Create a socket. std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, released below. + FileDescriptor server(std::get<1>(fds)); // non-const, reset below. // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -210,14 +210,14 @@ TEST_P(SendFileTest, Shutdown) { // checking the contents (other tests do that), so we just re-use the same // buffer as above. ScopedThread t([&]() { - int done = 0; + size_t done = 0; while (done < data.size()) { - int n = read(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(server.get(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - ASSERT_THAT(close(server.release()), SyscallSucceeds()); + server.reset(); }); // Continuously stream from the file to the socket. Note we do not assert diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index eb7a3966f..7ba752599 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -13,7 +13,6 @@ // limitations under the License. #include <stdio.h> - #include <sys/ipc.h> #include <sys/mman.h> #include <sys/shm.h> diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc index 69b6e4f90..a778fa639 100644 --- a/test/syscalls/linux/sigaltstack.cc +++ b/test/syscalls/linux/sigaltstack.cc @@ -22,7 +22,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/cleanup.h" #include "test/util/fs_util.h" #include "test/util/multiprocess_util.h" @@ -96,13 +95,7 @@ TEST(SigaltstackTest, ResetByExecve) { auto const cleanup_sigstack = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - std::string full_path; - char* test_src = getenv("TEST_SRCDIR"); - if (test_src) { - full_path = JoinPath(test_src, "../../linux/sigaltstack_check"); - } - - ASSERT_FALSE(full_path.empty()); + std::string full_path = RunfilePath("test/syscalls/linux/sigaltstack_check"); pid_t child_pid = -1; int execve_errno = 0; diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 9379d5878..09ecad34a 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -24,7 +24,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/synchronization/mutex.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc index d20821cac..6b27f6eab 100644 --- a/test/syscalls/linux/socket_bind_to_device.cc +++ b/test/syscalls/linux/socket_bind_to_device.cc @@ -32,7 +32,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index 4d2400328..5767181a1 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -33,7 +33,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index a7365d139..34b1058a9 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -33,7 +33,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" @@ -98,12 +97,22 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { sockets_to_close_.erase(socket_id); } - // Bind a socket with the reuse option and bind_to_device options. Checks + // SetDevice changes the bind_to_device option. It does not bind or re-bind. + void SetDevice(int socket_id, int device_id) { + auto socket_fd = sockets_to_close_[socket_id]->get(); + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + } + + // Bind a socket with the reuse options and bind_to_device options. Checks // that all steps succeed and that the bind command's error matches want. // Sets the socket_id to uniquely identify the socket bound if it is not // nullptr. - void BindSocket(bool reuse, int device_id = 0, int want = 0, - int *socket_id = nullptr) { + void BindSocket(bool reuse_port, bool reuse_addr, int device_id = 0, + int want = 0, int *socket_id = nullptr) { next_socket_id_++; sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket_fd = sockets_to_close_[next_socket_id_]->get(); @@ -111,13 +120,20 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { *socket_id = next_socket_id_; } - // If reuse is indicated, do that. - if (reuse) { + // If reuse_port is indicated, do that. + if (reuse_port) { EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); } + // If reuse_addr is indicated, do that. + if (reuse_addr) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + // If the device is non-zero, bind to that device. if (device_id != 0) { string device_name; @@ -183,129 +199,308 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { }; TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 3)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindToDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 1)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 2)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 2)); } TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithReuse) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); + BindSocket(/* reusePort */ true, /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); } TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse_port */ true, /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 999, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 999, 0)); } TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindAndRelease) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); int to_release; - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); // Release the bind to device 0 and try again. ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 345)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 345)); } TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + BindSocket(/* reusePort */ false, /* reuse_addr */ true)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, + CannotBindTo0AfterMixingReuseAddrAndNotReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReusePortThenReuseAddrReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +// This behavior seems like a bug? +TEST_P(BindToDeviceSequenceTest, + BindReuseAddrThenReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); +} + +// Repro test for gvisor.dev/issue/1217. Not replicated in ports_test.go as this +// test is different from the others and wouldn't fit well there. +TEST_P(BindToDeviceSequenceTest, BindAndReleaseDifferentDevice) { + int to_release; + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, EADDRINUSE)); + // Change the device. Since the socket was already bound, this should have no + // effect. + SetDevice(to_release, 2); + // Release the bind to device 3 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); } INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc index 00c50d1bf..7e88aa2d9 100644 --- a/test/syscalls/linux/socket_blocking.cc +++ b/test/syscalls/linux/socket_blocking.cc @@ -17,10 +17,10 @@ #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 51d614639..e8f24a59e 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 322ee07ad..619d41901 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -14,9 +14,9 @@ #include <arpa/inet.h> #include <netinet/in.h> +#include <netinet/tcp.h> #include <poll.h> #include <string.h> -#include <sys/socket.h> #include <atomic> #include <iostream> @@ -30,6 +30,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" @@ -43,6 +44,8 @@ namespace testing { namespace { +using ::testing::Gt; + PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { switch (family) { case AF_INET: @@ -99,19 +102,17 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { SyscallFailsWithErrno(EAFNOSUPPORT)); } -TEST_P(SocketInetLoopbackTest, TCP) { - auto const& param = GetParam(); - - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - +void tcpSimpleConnectTest(TestAddress const& listener, + TestAddress const& connector, bool unbound) { // Create the listening socket. const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + if (!unbound) { + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + } ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. @@ -145,6 +146,23 @@ TEST_P(SocketInetLoopbackTest, TCP) { ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); } +TEST_P(SocketInetLoopbackTest, TCP) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + tcpSimpleConnectTest(listener, connector, true); +} + +TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + tcpSimpleConnectTest(listener, connector, false); +} + TEST_P(SocketInetLoopbackTest, TCPListenClose) { auto const& param = GetParam(); @@ -203,7 +221,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked // before function end. - // ds.reset() + // ds.reset(); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -266,6 +284,394 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { } } +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple + // reservations which causes the bind() to succeed on gVisor but connect + // correctly fails. + if (IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } else { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + ds.reset(); + + if (!IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + } + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpont. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPResetAfterClose creates a pair of connected sockets then closes +// one end to trigger FIN_WAIT2 state for the closed endpoint verifies +// that we generate RSTs for any new data after the socket is fully +// closed. +TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + int data = 1234; + + // Now send data which should trigger a RST as the other end should + // have timed out and closed the socket. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceeds()); + // Sleep for a shortwhile to get a RST back. + absl::SleepFor(absl::Seconds(1)); + + // Try writing again and we should get an EPIPE back. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallFailsWithErrno(EPIPE)); + + // Trying to read should return zero as the other end did send + // us a FIN. We do it twice to verify that the RST does not cause an + // ECONNRESET on the read after EOF has been read by applicaiton. + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); +} + +// This test is disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the accept FD to trigger TIME_WAIT on the accepted socket which + // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of + // TIME_WAIT. + accepted.reset(); + absl::SleepFor(absl::Seconds(1)); + conn_fd.reset(); + absl::SleepFor(absl::Seconds(1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the userTimeout on the listening socket. + constexpr int kUserTimeout = 10; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kUserTimeout, sizeof(kUserTimeout)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + // Verify that the accepted socket inherited the user timeout set on + // listening socket. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kUserTimeout); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( @@ -298,7 +704,9 @@ INSTANTIATE_TEST_SUITE_P( using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; -TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { +// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is +// saved/restored. +TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -306,6 +714,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { sockaddr_storage listen_addr = listener.addr; sockaddr_storage conn_addr = connector.addr; constexpr int kThreadCount = 3; + constexpr int kConnectAttempts = 4096; // Create the listening socket. FileDescriptor listener_fds[kThreadCount]; @@ -320,7 +729,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { ASSERT_THAT( bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), SyscallSucceeds()); - ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + ASSERT_THAT(listen(fd, kConnectAttempts / 3), SyscallSucceeds()); // On the first bind we need to determine which port was bound. if (i != 0) { @@ -339,7 +748,6 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); } - constexpr int kConnectAttempts = 10000; std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); std::unique_ptr<ScopedThread> listen_thread[kThreadCount]; int accept_counts[kThreadCount] = {}; @@ -516,6 +924,115 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // TODO(b/141211329): endpointsByNic.seed has to be saved/restored. + const DisableSave ds141211329; + + // Create listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10; + FileDescriptor client_fds[kConnectAttempts]; + + // Do the first run without save/restore. + DisableSave ds; + for (int i = 0; i < kConnectAttempts; i++) { + client_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + ds.reset(); + + // Check that a mapping of client and server sockets has + // not been change after save/restore. + for (int i = 0; i < kConnectAttempts; i++) { + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + + struct pollfd pollfds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + pollfds[i].fd = listener_fds[i].get(); + pollfds[i].events = POLLIN; + } + + std::map<uint16_t, int> portToFD; + + int received = 0; + while (received < kConnectAttempts * 2) { + ASSERT_THAT(poll(pollfds, kThreadCount, -1), + SyscallSucceedsWithValue(Gt(0))); + + for (int i = 0; i < kThreadCount; i++) { + if ((pollfds[i].revents & POLLIN) == 0) { + continue; + } + + received++; + + const int fd = pollfds[i].fd; + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + EXPECT_THAT(RetryEINTR(recvfrom)( + fd, &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceedsWithValue(sizeof(data))); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); + auto prev_port = portToFD.find(port); + // Check that all packets from one client have been delivered to the + // same server socket. + if (prev_port == portToFD.end()) { + portToFD[port] = fd; + } else { + EXPECT_EQ(portToFD[port], fd); + } + } + } +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetReusePortTest, ::testing::Values( @@ -713,10 +1230,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { sockaddr_storage addr_dual = test_addr_dual.addr; const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_dual.family(), param.type, 0)); - int one = 1; - EXPECT_THAT( - setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, &one, sizeof(one)), - SyscallSucceeds()); + EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), test_addr_dual.addr_len), SyscallSucceeds()); @@ -764,7 +1280,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/114268588) + // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make + // it disabled by default. SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); for (int i = 0; true; i++) { @@ -862,10 +1379,76 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { } } +TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { + auto const& param = GetParam(); + + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); + + // Bind the v6 loopback on a dual stack socket. + TestAddress const& test_addr = V6Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(connect(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), + reinterpret_cast<sockaddr*>(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), + connected_addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/114268588) + // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make + // it disabled by default. SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); for (int i = 0; true; i++) { @@ -965,9 +1548,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { // v6-only socket. const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6_any.family(), param.type, 0)); - int one = 1; EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &one, sizeof(one)), + &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); ret = bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), @@ -986,10 +1568,78 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { } } +TEST_P(SocketMultiProtocolInetLoopbackTest, + V4MappedEphemeralPortReservedResueAddr) { + auto const& param = GetParam(); + + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); + + // Bind the v4 loopback on a dual stack socket. + TestAddress const& test_addr = V4MappedLoopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(connect(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), + reinterpret_cast<sockaddr*>(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), + connected_addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/114268588) + // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make + // it disabled by default. SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); for (int i = 0; true; i++) { @@ -1090,9 +1740,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { // v6-only socket. const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6_any.family(), param.type, 0)); - int one = 1; EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &one, sizeof(one)), + &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); ret = bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), @@ -1111,6 +1760,75 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { } } +TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { + auto const& param = GetParam(); + + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); + + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(connect(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), + reinterpret_cast<sockaddr*>(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), + connected_addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { auto const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); @@ -1197,7 +1915,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { } INSTANTIATE_TEST_SUITE_P( - AllFamlies, SocketMultiProtocolInetLoopbackTest, + AllFamilies, SocketMultiProtocolInetLoopbackTest, ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, ProtocolTestParam{"UDP", SOCK_DGRAM}), DescribeProtocolTestParam); diff --git a/test/syscalls/linux/socket_ip_loopback_blocking.cc b/test/syscalls/linux/socket_ip_loopback_blocking.cc index d7fc9715b..fda252dd7 100644 --- a/test/syscalls/linux/socket_ip_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_loopback_blocking.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" @@ -22,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>( @@ -42,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P( BlockingIPSockets, BlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index bfa7943b1..57ce8e169 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,14 +24,16 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { -TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -40,7 +42,7 @@ TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -49,7 +51,7 @@ TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -244,6 +246,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { SyscallSucceedsWithValue(0)); } +// This test verifies that a shutdown(wr) by the server after sending +// data allows the client to still read() the queued data and a client +// close after sending response allows server to read the incoming +// response. +TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char buf[10] = {}; + ScopedThread t([&]() { + ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(close(sockets->release_first_fd()), + SyscallSucceedsWithValue(0)); + }); + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), + SyscallSucceedsWithValue(0)); + t.Join(); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); +} + TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -697,5 +724,156 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); } +// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. +constexpr int kDefaultTCPLingerTimeout = 60; + +TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + + constexpr int kNegative = -1234; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kNegative, sizeof(kNegative)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kAboveDefault, sizeof(kAboveDefault)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ScopedThread t([&]() { + // Close one end to trigger sending of a FIN. + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds()); + char buf[3]; + ASSERT_THAT(read(sockets->second_fd(), buf, 3), + SyscallSucceedsWithValue(3)); + absl::SleepFor(absl::Milliseconds(50)); + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + }); + + absl::SleepFor(absl::Milliseconds(50)); + // Send some data then close. + constexpr char kStr[] = "abc"; + ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3)); + t.Join(); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); +} + +TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kNeg = -10; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAbove = 10; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kAbove, sizeof(kAbove)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kAbove); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc index 0dc274e2d..d11f7cc23 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" diff --git a/test/syscalls/linux/socket_ip_tcp_loopback.cc b/test/syscalls/linux/socket_ip_tcp_loopback.cc index 831de53b8..9db3037bc 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -34,5 +35,6 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, AllSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc index cd3ad97d0..fcd20102f 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc index 1acdecc17..63a05b799 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" diff --git a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc index de63f79d9..f178f1af9 100644 --- a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc @@ -22,7 +22,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 044394ba7..66eb68857 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -24,7 +24,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" @@ -36,7 +35,7 @@ TEST_P(UDPSocketPairTest, MulticastTTLDefault) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -53,7 +52,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMin) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -70,7 +69,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMax) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -92,7 +91,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLNegativeOne) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -127,7 +126,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLChar) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -148,7 +147,7 @@ TEST_P(UDPSocketPairTest, MulticastLoopDefault) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -164,7 +163,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoop) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -174,7 +173,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoop) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -193,7 +192,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -203,12 +202,120 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { &kSockOptOnChar, sizeof(kSockOptOnChar)), SyscallSucceeds()); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); EXPECT_EQ(get, kSockOptOn); } +TEST_P(UDPSocketPairTest, ReuseAddrDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, SetReuseAddr) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, ReusePortDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, SetReusePort) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index fa9a9df6f..b6754111f 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -23,7 +23,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" @@ -355,6 +354,38 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) { EXPECT_EQ(get, expect); } +TEST_P(IPUnboundSocketTest, NullTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + TOSOption t = GetTOSOption(GetParam().domain); + int set_sz = sizeof(int); + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallFailsWithErrno(EFAULT)); + } else { // AF_INET6 + // The AF_INET6 behavior is not yet compatible. gVisor will try to read + // optval from user memory at syscall handler, it needs substantial + // refactoring to implement this behavior just for IPv6. + if (IsRunningOnGvisor()) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallFailsWithErrno(EFAULT)); + } else { + // Linux's IPv6 stack treats nullptr optval as input of 0, so the call + // succeeds. (net/ipv6/ipv6_sockglue.c, do_ipv6_setsockopt()) + // + // Linux's implementation would need fixing as passing a nullptr as optval + // and non-zero optlen may not be valid. + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallSucceedsWithValue(0)); + } + } + socklen_t get_sz = sizeof(int); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, nullptr, &get_sz), + SyscallFailsWithErrno(EFAULT)); + int get = -1; + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, nullptr), + SyscallFailsWithErrno(EFAULT)); +} + INSTANTIATE_TEST_SUITE_P( IPUnboundSockets, IPUnboundSocketTest, ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc index 3a068aacf..80f12b0a9 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc @@ -18,12 +18,12 @@ #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <cstdio> #include <cstring> #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc index 92f03e045..3ac790873 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h" + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 67d29af0a..aa6fb4e3f 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -18,10 +18,11 @@ #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/un.h> + #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" @@ -52,26 +53,27 @@ TestAddress V4Broadcast() { // Check that packets are not received without a group membership. Default send // interface configured by bind. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address. If multicast worked like unicast, // this would ensure that we get the packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -83,33 +85,33 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } // Check that not setting a default send interface prevents multicast packets // from being sent. Group membership interface configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the second FD to the v4 any address to ensure that we can receive any // unicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -119,8 +121,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -129,27 +131,27 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallFailsWithErrno(ENETUNREACH)); } // Check that not setting a default send interface prevents multicast packets // from being sent. Group membership interface configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the second FD to the v4 any address to ensure that we can receive any // unicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -159,8 +161,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -169,35 +171,35 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallFailsWithErrno(ENETUNREACH)); } // Check that multicast works when the default send interface is configured by // bind and the group membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -207,8 +209,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -217,43 +219,42 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } // Check that multicast works when the default send interface is configured by // bind and the group membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -263,8 +264,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -273,17 +274,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -291,25 +290,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -319,8 +319,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -329,17 +329,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -347,25 +345,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -375,8 +374,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -385,17 +384,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -403,25 +400,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -431,8 +429,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -440,22 +438,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -463,25 +459,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -491,8 +488,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -500,22 +497,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -523,25 +518,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -551,8 +547,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -561,17 +557,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -579,25 +573,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -607,8 +602,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -617,17 +612,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -635,25 +628,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -663,8 +657,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -672,20 +666,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; EXPECT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } @@ -693,25 +686,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -721,8 +715,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -730,20 +724,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } @@ -751,29 +744,30 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP, &kSockOptOff, sizeof(kSockOptOff)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -783,8 +777,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -793,17 +787,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -811,29 +803,30 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP, &kSockOptOff, sizeof(kSockOptOff)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -843,8 +836,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -853,57 +846,57 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } // Check that dropping a group membership that does not exist fails. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastInvalidDrop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Unregister from a membership that we didn't have. ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } // Check that dropping a group membership prevents multicast packets from being // delivered. Default send address configured by bind and group membership // interface configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -913,11 +906,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -926,15 +919,14 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } @@ -942,26 +934,27 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { // Check that dropping a group membership prevents multicast packets from being // delivered. Default send address configured by bind and group membership // interface configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -971,11 +964,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -984,50 +977,53 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfInvalidNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; iface.imr_ifindex = -1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfInvalidAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreq iface = {}; iface.imr_interface.s_addr = inet_addr("255.255.255"); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetShort) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Create a valid full-sized request. ip_mreqn iface = {}; @@ -1035,29 +1031,31 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetShort) { // Send an optlen of 1 to check that optlen is enforced. EXPECT_THAT( - setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), + setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), SyscallFailsWithErrno(EINVAL)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefaultReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the @@ -1072,19 +1070,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefaultReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddrGetReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); in_addr set = {}; set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the @@ -1096,19 +1095,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddrGetReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddrGetReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreq set = {}; set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the @@ -1120,19 +1120,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddrGetReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNicGetReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn set = {}; set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(in_addr)); EXPECT_EQ(get.imr_multiaddr.s_addr, 0); @@ -1140,87 +1141,93 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNicGetReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); in_addr set = {}; set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, set.s_addr); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreq set = {}; set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, set.imr_interface.s_addr); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn set = {}; set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupNoIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(ENODEV)); } -TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupInvalidIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn group = {}; group.imr_address.s_addr = inet_addr("255.255.255"); group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(ENODEV)); } // Check that multiple memberships are not allowed on the same socket. -TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - auto fd = sockets->first_fd(); +TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto fd = socket1->get(); ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); @@ -1235,41 +1242,44 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) { } // Check that two sockets can join the same multicast group at the same time. -TEST_P(IPv4UDPUnboundSocketPairTest, TestTwoSocketsJoinSameMulticastGroup) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Drop the membership twice on each socket, the second call for each socket // should fail. - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(EADDRNOTAVAIL)); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } // Check that two sockets can join the same multicast group at the same time, // and both will receive data on it. -TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) { +TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { std::unique_ptr<SocketPair> socket_pairs[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())}; + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())), + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))}; ip_mreq iface = {}, group = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); @@ -1339,11 +1349,12 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) { // Check that on two sockets that joined a group and listen on ANY, dropping // memberships one by one will continue to deliver packets to both sockets until // both memberships have been dropped. -TEST_P(IPv4UDPUnboundSocketPairTest, - TestMcastReceptionWhenDroppingMemberships) { +TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { std::unique_ptr<SocketPair> socket_pairs[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())}; + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())), + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))}; ip_mreq iface = {}, group = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); @@ -1438,18 +1449,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, // Check that a receiving socket can bind to the multicast address before // joining the group and receive data once the group has been joined. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the multicast address. auto receiver_addr = V4Multicast(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Update receiver_addr with the correct port number. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1459,30 +1471,29 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet on the first socket out the loopback interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); auto sendto_addr = V4Multicast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -1490,18 +1501,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { // Check that a receiving socket can bind to the multicast address and won't // receive multicast data if it hasn't joined the group. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the multicast address. auto receiver_addr = V4Multicast(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Update receiver_addr with the correct port number. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1510,40 +1522,40 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { // Send a multicast packet on the first socket out the loopback interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); auto sendto_addr = V4Multicast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } // Check that a socket can bind to a multicast address and still send out // packets. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the ANY address. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1552,11 +1564,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { // Bind the first socket (sender) to the multicast address. auto sender_addr = V4Multicast(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); @@ -1568,15 +1580,14 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -1584,46 +1595,46 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { // Check that a receiving socket can bind to the broadcast address and receive // broadcast packets. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the broadcast address. auto receiver_addr = V4Broadcast(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); // Send a broadcast packet on the first socket out the loopback interface. - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, - &kSockOptOn, sizeof(kSockOptOn)), + EXPECT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); // Note: Binding to the loopback interface makes the broadcast go out of it. auto sender_bind_addr = V4Loopback(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), - sender_bind_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); auto sendto_addr = V4Broadcast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -1631,17 +1642,18 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { // Check that a socket can bind to the broadcast address and still send out // packets. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the ANY address. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1650,11 +1662,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { // Bind the first socket (sender) to the broadcast address. auto sender_addr = V4Broadcast(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); @@ -1666,19 +1678,476 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } +// Check that SO_REUSEADDR always delivers to the most recently bound socket. +TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + std::vector<std::unique_ptr<FileDescriptor>> sockets; + sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); + + ASSERT_THAT(setsockopt(sockets[0]->get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(sockets[0]->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + constexpr int kMessageSize = 200; + + for (int i = 0; i < 10; i++) { + // Add a new receiver. + sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); + auto& last = sockets.back(); + ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Send a new message to the SO_REUSEADDR group. We use a new socket each + // time so that a new ephemeral port will be used each time. This ensures + // that we aren't doing REUSEPORT-like hash load blancing. + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + char send_buf[kMessageSize]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Verify that the most recent socket got the message. We don't expect any + // of the other sockets to have received it, but we will check that later. + char recv_buf[sizeof(send_buf)] = {}; + EXPECT_THAT( + RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } + + // Verify that no other messages were received. + for (auto& socket : sockets) { + char recv_buf[kMessageSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. + socket2->reset(); + + // Bind socket3 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. + socket2->reset(); + + // Bind socket3 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, also with REUSEADDR and + // REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, also with REUSEADDR and + // REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +// Check that REUSEPORT takes precedence over REUSEADDR. +TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind receiver2 to the same address as socket1, also with REUSEADDR and + // REUSEPORT. + ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + constexpr int kMessageSize = 10; + + for (int i = 0; i < 100; ++i) { + // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket + // each time so that a new ephemerial port will be used each time. This + // ensures that we cycle through hashes. + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + char send_buf[kMessageSize] = {}; + EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + } + + // Check that both receivers got messages. This checks that we are using load + // balancing (REUSEPORT) instead of the most recently bound socket + // (REUSEADDR). + char recv_buf[kMessageSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(kMessageSize)); + EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(kMessageSize)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.h b/test/syscalls/linux/socket_ipv4_udp_unbound.h index 8e07bfbbf..f64c57645 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.h @@ -20,8 +20,8 @@ namespace gvisor { namespace testing { -// Test fixture for tests that apply to pairs of IPv4 UDP sockets. -using IPv4UDPUnboundSocketPairTest = SocketPairTest; +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketTest = SimpleSocketTest; } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 8b8993d3d..98ae414f3 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -27,7 +27,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc index 9d4e1ab97..8f47952b0 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h" + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc index cb0105471..f121c044d 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc @@ -22,14 +22,11 @@ namespace gvisor { namespace testing { -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - IPv4UDPUnboundSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P(IPv4UDPSockets, IPv4UDPUnboundSocketPairTest, - ::testing::ValuesIn(GetSocketPairs())); +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 765f8e0e4..405dbbd73 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -68,7 +68,8 @@ TEST(NetdeviceTest, Netmask) { // Use a netlink socket to get the netmask, which we'll then compare to the // netmask obtained via ioctl. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { diff --git a/test/syscalls/linux/socket_netlink.cc b/test/syscalls/linux/socket_netlink.cc new file mode 100644 index 000000000..4ec0fd4fa --- /dev/null +++ b/test/syscalls/linux/socket_netlink.cc @@ -0,0 +1,153 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/netlink.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Tests for all netlink socket protocols. + +namespace gvisor { +namespace testing { + +namespace { + +// NetlinkTest parameter is the protocol to test. +using NetlinkTest = ::testing::TestWithParam<int>; + +// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. +TEST_P(NetlinkTest, Types) { + const int protocol = GetParam(); + + EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + + int fd; + EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, protocol), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); + + EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, protocol), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + +TEST_P(NetlinkTest, AutomaticPort) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + + EXPECT_THAT( + bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + // This is the only netlink socket in the process, so it should get the PID as + // the port id. + // + // N.B. Another process could theoretically have explicitly reserved our pid + // as a port ID, but that is very unlikely. + EXPECT_EQ(addr.nl_pid, getpid()); +} + +// Calling connect automatically binds to an automatic port. +TEST_P(NetlinkTest, ConnectBinds) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + + EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + + // Each test is running in a pid namespace, so another process can explicitly + // reserve our pid as a port ID. In this case, a negative portid value will be + // set. + if (static_cast<pid_t>(addr.nl_pid) > 0) { + EXPECT_EQ(addr.nl_pid, getpid()); + } + + memset(&addr, 0, sizeof(addr)); + addr.nl_family = AF_NETLINK; + + // Connecting again is allowed, but keeps the same port. + EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + EXPECT_EQ(addr.nl_pid, getpid()); +} + +TEST_P(NetlinkTest, GetPeerName) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + socklen_t addrlen = sizeof(addr); + + EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, sizeof(addr)); + EXPECT_EQ(addr.nl_family, AF_NETLINK); + // Peer is the kernel if we didn't connect elsewhere. + EXPECT_EQ(addr.nl_pid, 0); +} + +INSTANTIATE_TEST_SUITE_P(ProtocolTest, NetlinkTest, + ::testing::Values(NETLINK_ROUTE, + NETLINK_KOBJECT_UEVENT)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index dd4a11655..ef567f512 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -41,112 +41,7 @@ namespace { using ::testing::AnyOf; using ::testing::Eq; -// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. -TEST(NetlinkRouteTest, Types) { - EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - - int fd; - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(NetlinkRouteTest, AutomaticPort) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT( - bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - // This is the only netlink socket in the process, so it should get the PID as - // the port id. - // - // N.B. Another process could theoretically have explicitly reserved our pid - // as a port ID, but that is very unlikely. - EXPECT_EQ(addr.nl_pid, getpid()); -} - -// Calling connect automatically binds to an automatic port. -TEST(NetlinkRouteTest, ConnectBinds) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - - // Each test is running in a pid namespace, so another process can explicitly - // reserve our pid as a port ID. In this case, a negative portid value will be - // set. - if (static_cast<pid_t>(addr.nl_pid) > 0) { - EXPECT_EQ(addr.nl_pid, getpid()); - } - - memset(&addr, 0, sizeof(addr)); - addr.nl_family = AF_NETLINK; - - // Connecting again is allowed, but keeps the same port. - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_pid, getpid()); -} - -TEST(NetlinkRouteTest, GetPeerName) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - socklen_t addrlen = sizeof(addr); - - EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_family, AF_NETLINK); - // Peer is the kernel if we didn't connect elsewhere. - EXPECT_EQ(addr.nl_pid, 0); -} - -// Parameters for GetSockOpt test. They are: +// Parameters for SockOptTest. They are: // 0: Socket option to query. // 1: A predicate to run on the returned sockopt value. Should return true if // the value is considered ok. @@ -195,7 +90,8 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK), absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)), std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE), - absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)))); + absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)), + std::make_tuple(SO_PASSCRED, IsEqual(0), "0"))); // Validates the reponses to RTM_GETLINK + NLM_F_DUMP. void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { @@ -218,7 +114,8 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { } TEST(NetlinkRouteTest, GetLinkDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -259,7 +156,8 @@ TEST(NetlinkRouteTest, GetLinkDump) { } TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -292,7 +190,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -331,7 +230,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { } TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -372,7 +272,8 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { } TEST(NetlinkRouteTest, ControlMessageIgnored) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -407,7 +308,8 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { } TEST(NetlinkRouteTest, GetAddrDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -467,7 +369,8 @@ TEST(NetlinkRouteTest, LookupAll) { // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. TEST(NetlinkRouteTest, GetRouteDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -543,7 +446,8 @@ TEST(NetlinkRouteTest, GetRouteDump) { // buffer. MSG_TRUNC with a zero length buffer should consume subsequent // messages off the socket. TEST(NetlinkRouteTest, RecvmsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -619,7 +523,8 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) { // it, so a properly sized buffer can be allocated to store the message. This // test tests that scenario. TEST(NetlinkRouteTest, RecvmsgTruncPeek) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -692,6 +597,115 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) { } while (type != NLMSG_DONE && type != NLMSG_ERROR); } +// No SCM_CREDENTIALS are received without SO_PASSCRED set. +TEST(NetlinkRouteTest, NoPasscredNoCreds) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // Note: This test assumes at least one message is returned by the + // RTM_GETADDR request. + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + // No control messages. + EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr); +} + +// SCM_CREDENTIALS are received with SO_PASSCRED set. +TEST(NetlinkRouteTest, PasscredCreds) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // Note: This test assumes at least one message is returned by the + // RTM_GETADDR request. + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + struct ucred creds; + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); + + memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds)); + + // The peer is the kernel, which is "PID" 0. + EXPECT_EQ(creds.pid, 0); + // The kernel identifies as root. Also allow nobody in case this test is + // running in a userns without root mapped. + EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534))); + EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534))); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_uevent.cc b/test/syscalls/linux/socket_netlink_uevent.cc new file mode 100644 index 000000000..da425bed4 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_uevent.cc @@ -0,0 +1,83 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/filter.h> +#include <linux/netlink.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Tests for NETLINK_KOBJECT_UEVENT sockets. +// +// gVisor never sends any messages on these sockets, so we don't test the events +// themselves. + +namespace gvisor { +namespace testing { + +namespace { + +// SO_PASSCRED can be enabled. Since no messages are sent in gVisor, we don't +// actually test receiving credentials. +TEST(NetlinkUeventTest, PassCred) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + EXPECT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); +} + +// SO_DETACH_FILTER fails without a filter already installed. +TEST(NetlinkUeventTest, DetachNoFilter) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + int opt; + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), + SyscallFailsWithErrno(ENOENT)); +} + +// We can attach a BPF filter. +TEST(NetlinkUeventTest, AttachFilter) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + // Minimal BPF program: a single ret. + struct sock_filter filter = {0x6, 0, 0, 0}; + struct sock_fprog prog = {}; + prog.len = 1; + prog.filter = &filter; + + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_ATTACH_FILTER, &prog, sizeof(prog)), + SyscallSucceeds()); + + int opt; + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), + SyscallSucceeds()); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index fcb8f8a88..723f5d728 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -12,24 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <sys/socket.h> +#include "test/syscalls/linux/socket_netlink_util.h" #include <linux/if_arp.h> #include <linux/netlink.h> -#include <linux/rtnetlink.h> +#include <sys/socket.h> #include <vector> #include "absl/strings/str_cat.h" -#include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> NetlinkBoundSocket() { +PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) { FileDescriptor fd; - ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); + ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); struct sockaddr_nl addr = {}; addr.nl_family = AF_NETLINK; diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index db8639a2f..76e772c48 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -15,9 +15,10 @@ #ifndef GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ #define GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ +#include <sys/socket.h> +// socket.h has to be included before if_arp.h. #include <linux/if_arp.h> #include <linux/netlink.h> -#include <linux/rtnetlink.h> #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -25,8 +26,8 @@ namespace gvisor { namespace testing { -// Returns a bound NETLINK_ROUTE socket. -PosixErrorOr<FileDescriptor> NetlinkBoundSocket(); +// Returns a bound netlink socket. +PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol); // Returns the port ID of the passed socket. PosixErrorOr<uint32_t> NetlinkPortID(int fd); diff --git a/test/syscalls/linux/socket_non_blocking.cc b/test/syscalls/linux/socket_non_blocking.cc index 73e6dc618..c3520cadd 100644 --- a/test/syscalls/linux/socket_non_blocking.cc +++ b/test/syscalls/linux/socket_non_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc index 3c599b6e8..d91c5ed39 100644 --- a/test/syscalls/linux/socket_non_stream.cc +++ b/test/syscalls/linux/socket_non_stream.cc @@ -19,7 +19,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc index 76127d181..62d87c1af 100644 --- a/test/syscalls/linux/socket_non_stream_blocking.cc +++ b/test/syscalls/linux/socket_non_stream_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc index 0417dd347..346443f96 100644 --- a/test/syscalls/linux/socket_stream.cc +++ b/test/syscalls/linux/socket_stream.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index 8367460d2..e9cc082bf 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream_nonblock.cc b/test/syscalls/linux/socket_stream_nonblock.cc index b00748b97..74d608741 100644 --- a/test/syscalls/linux/socket_stream_nonblock.cc +++ b/test/syscalls/linux/socket_stream_nonblock.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 70710195c..2dbb8bed3 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -30,7 +30,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_format.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -115,6 +114,9 @@ class FDSocketPair : public SocketPair { public: FDSocketPair(int first_fd, int second_fd) : first_(first_fd), second_(second_fd) {} + FDSocketPair(std::unique_ptr<FileDescriptor> first_fd, + std::unique_ptr<FileDescriptor> second_fd) + : first_(first_fd->release()), second_(second_fd->release()) {} int first_fd() const override { return first_.get(); } int second_fd() const override { return second_.get(); } diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 875f0391f..4cf1f76f1 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -25,7 +25,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -66,6 +65,21 @@ TEST_P(UnixSocketPairTest, BindToBadName) { SyscallFailsWithErrno(ENOENT)); } +TEST_P(UnixSocketPairTest, BindToBadFamily) { + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + constexpr char kBadName[] = "/some/path/that/does/not/exist"; + sockaddr_un sockaddr; + sockaddr.sun_family = AF_INET; + memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName)); + + EXPECT_THAT( + bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr), + sizeof(sockaddr)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[10]; diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc index 1994139e6..6f84221b2 100644 --- a/test/syscalls/linux/socket_unix_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_blocking_local.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_blocking.h" - #include <vector> +#include "test/syscalls/linux/socket_blocking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc index 1092e29b1..a16899493 100644 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -25,7 +25,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -150,6 +149,35 @@ TEST_P(UnixSocketPairCmsgTest, BadFDPass) { SyscallFailsWithErrno(EBADF)); } +TEST_P(UnixSocketPairCmsgTest, ShortCmsg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sent_fd = -1; + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(sent_fd))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = 1; + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EINVAL)); +} + // BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. // The difference is that when calling recvmsg, no space for FDs is provided, // only space for the cmsg header. diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc index 3e0f611d2..af0df4fb4 100644 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ b/test/syscalls/linux/socket_unix_dgram.cc @@ -16,7 +16,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc index 707052af8..2db8b68d3 100644 --- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc +++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc @@ -14,7 +14,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index dafe82494..276a94eb8 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -19,7 +19,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/memory_util.h" @@ -231,11 +230,21 @@ TEST_P(UnixNonStreamSocketPairTest, SendTimeout) { setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); - char buf[100] = {}; + const int buf_size = 5 * kPageSize; + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceeds()); + EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceeds()); + + // The buffer size should be big enough to avoid many iterations in the next + // loop. Otherwise, this will slow down cooperative_save tests. + std::vector<char> buf(kPageSize); for (;;) { int ret; ASSERT_THAT( - ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0), ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN))); if (ret == -1) { break; diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc index da762cd83..8855d5001 100644 --- a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_non_stream_blocking.h" - #include <vector> +#include "test/syscalls/linux/socket_non_stream_blocking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc index 6f6367dd5..84d3a569e 100644 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_seqpacket.cc @@ -16,7 +16,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 8f38ed92f..563467365 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc index fa0a9d367..08e579ba7 100644 --- a/test/syscalls/linux/socket_unix_stream_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_stream_blocking_local.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_stream_blocking.h" - #include <vector> +#include "test/syscalls/linux/socket_stream_blocking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc index ec777c59f..1936aa135 100644 --- a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc +++ b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc @@ -11,10 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_stream_nonblock.h" - #include <vector> +#include "test/syscalls/linux/socket_stream_nonblock.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc index 4b5832de8..8b1762000 100644 --- a/test/syscalls/linux/socket_unix_unbound_abstract.cc +++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc @@ -14,7 +14,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_dgram.cc b/test/syscalls/linux/socket_unix_unbound_dgram.cc index 52aef891f..907dca0f1 100644 --- a/test/syscalls/linux/socket_unix_unbound_dgram.cc +++ b/test/syscalls/linux/socket_unix_unbound_dgram.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc index 8cb03c450..cab912152 100644 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc @@ -14,7 +14,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc index 0575f2e1d..cb99030f5 100644 --- a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc @@ -14,7 +14,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc index e483d2777..f185dded3 100644 --- a/test/syscalls/linux/socket_unix_unbound_stream.cc +++ b/test/syscalls/linux/socket_unix_unbound_stream.cc @@ -14,7 +14,7 @@ #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 88ab90b5b..30de2f8ff 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -24,7 +24,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc index fe479390d..8aa2525a9 100644 --- a/test/syscalls/linux/sync.cc +++ b/test/syscalls/linux/sync.cc @@ -14,10 +14,9 @@ #include <fcntl.h> #include <stdio.h> -#include <unistd.h> - #include <sys/syscall.h> #include <unistd.h> + #include <string> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index bfa031bce..6b99c021d 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -130,6 +130,19 @@ void TcpSocketTest::TearDown() { } } +TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + ASSERT_THAT( + connect(s_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); + ASSERT_THAT( + connect(t_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; @@ -394,8 +407,15 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { sizeof(tcp_nodelay_flag)), SyscallSucceeds()); + // Set a 256KB send/receive buffer. + int buf_sz = 1 << 18; + EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + // Create a large buffer that will be used for sending. - std::vector<char> buf(10 * sendbuf_size_); + std::vector<char> buf(1 << 16); // Write until we receive an error. while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) { @@ -405,6 +425,11 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { } // The last error should have been EWOULDBLOCK. ASSERT_EQ(errno, EWOULDBLOCK); + + // Now polling on the FD with a timeout should return 0 corresponding to no + // FDs ready. + struct pollfd poll_fd = {s_, POLLOUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceedsWithValue(0)); } TEST_P(TcpSocketTest, MsgTrunc) { @@ -942,6 +967,78 @@ TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) { EXPECT_THAT(close(s.release()), SyscallSucceeds()); } +// Test that connecting to a non-listening port and thus receiving a RST is +// handled appropriately by the socket - the port that the socket was bound to +// is released and the expected error is returned. +TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) { + // Create a socket that is known to not be listening. As is it bound but not + // listening, when another socket connects to the port, it will refuse.. + FileDescriptor bound_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage bound_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t bound_addrlen = sizeof(bound_addr); + + ASSERT_THAT( + bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallSucceeds()); + + // Get the addresses the socket is bound to because the port is chosen by the + // stack. + ASSERT_THAT(getsockname(bound_s.get(), + reinterpret_cast<struct sockaddr*>(&bound_addr), + &bound_addrlen), + SyscallSucceeds()); + + // Create, initialize, and bind the socket that is used to test connecting to + // the non-listening port. + FileDescriptor client_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // Initialize client address to the loopback one. + sockaddr_storage client_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t client_addrlen = sizeof(client_addr); + + ASSERT_THAT( + bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), + client_addrlen), + SyscallSucceeds()); + + ASSERT_THAT(getsockname(client_s.get(), + reinterpret_cast<struct sockaddr*>(&client_addr), + &client_addrlen), + SyscallSucceeds()); + + // Now the test: connect to the bound but not listening socket with the + // client socket. The bound socket should return a RST and cause the client + // socket to return an error and clean itself up immediately. + // The error being ECONNREFUSED diverges with RFC 793, page 37, but does what + // Linux does. + ASSERT_THAT(connect(client_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); + + FileDescriptor new_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Test binding to the address from the client socket. This should be okay + // if it was dropped correctly. + ASSERT_THAT( + bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), + client_addrlen), + SyscallSucceeds()); + + // Attempt #2, with the new socket and reused addr our connect should fail in + // the same way as before, not with an EADDRINUSE. + ASSERT_THAT(connect(client_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); +} + // Test that we get an ECONNREFUSED with a nonblocking socket. TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( @@ -1150,6 +1247,31 @@ TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { } } +TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + { + constexpr int kTCPUserTimeout = -1; + EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallFailsWithErrno(EINVAL)); + } + + // kTCPUserTimeout is in milliseconds. + constexpr int kTCPUserTimeout = 100; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPUserTimeout); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index e5cc5d97c..c988c6380 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -19,6 +19,7 @@ #include <sys/vfs.h> #include <time.h> #include <unistd.h> + #include <iostream> #include <string> diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 111dbacdf..7a8ac30a4 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,1332 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/errqueue.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" +#include "test/syscalls/linux/udp_socket_test_cases.h" namespace gvisor { namespace testing { namespace { -// The initial port to be be used on gvisor. -constexpr int TestPort = 40000; - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> { - protected: - // Creates two sockets that will be used by test cases. - void SetUp() override; - - // Closes the sockets created by SetUp(). - void TearDown() override { - EXPECT_THAT(close(s_), SyscallSucceeds()); - EXPECT_THAT(close(t_), SyscallSucceeds()); - - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ASSERT_NO_ERRNO(FreeAvailablePort(ports_[i])); - } - } - - // First UDP socket. - int s_; - - // Second UDP socket. - int t_; - - // The length of the socket address. - socklen_t addrlen_; - - // Initialized address pointing to loopback and port TestPort+i. - struct sockaddr* addr_[3]; - - // Initialize "any" address. - struct sockaddr* anyaddr_; - - // Used ports. - int ports_[3]; - - private: - // Storage for the loopback addresses. - struct sockaddr_storage addr_storage_[3]; - - // Storage for the "any" address. - struct sockaddr_storage anyaddr_storage_; -}; - -// Gets a pointer to the port component of the given address. -uint16_t* Port(struct sockaddr_storage* addr) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - return &sin->sin_port; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - return &sin6->sin6_port; - } - } - - return nullptr; -} - -void UdpSocketTest::SetUp() { - int type; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - type = AF_INET6; - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - if (GetParam() == AddressFamily::kIpv6) { - sin6->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - sin6->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - - ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - - memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); - anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); - anyaddr_->sa_family = type; - - if (gvisor::testing::IsRunningOnGvisor()) { - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ports_[i] = TestPort + i; - } - } else { - // When not under gvisor, use utility function to pick port. Assert that - // all ports are different. - std::string error; - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - // Find an unused port, we specify port 0 to allow the kernel to provide - // the port. - bool unique = true; - do { - ports_[i] = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable( - 0, AddressFamily::kDualStack, SocketType::kUdp, false)); - ASSERT_GT(ports_[i], 0); - for (size_t j = 0; j < i; ++j) { - if (ports_[j] == ports_[i]) { - unique = false; - break; - } - } - } while (!unique); - } - } - - // Initialize the sockaddrs. - for (size_t i = 0; i < ABSL_ARRAYSIZE(addr_); ++i) { - memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); - - addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); - addr_[i]->sa_family = type; - - switch (type) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(ports_[i]); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr_[i]); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(ports_[i]); - break; - } - } - } -} - -TEST_P(UdpSocketTest, Creation) { - int type = AF_INET6; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - } - - int s_; - - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); -} - -TEST_P(UdpSocketTest, Getsockname) { - // Check that we're not bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, anyaddr_, addrlen_), 0); - - // Bind, then check that we get the right address. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, Getpeername) { - // Check that we're not connected. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Connect, then check that we get the right address. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, SendNotConnected) { - // Do send & write, they must fail. - char buf[512]; - EXPECT_THAT(send(s_, buf, sizeof(buf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); - - EXPECT_THAT(write(s_, buf, sizeof(buf)), SyscallFailsWithErrno(EDESTADDRREQ)); - - // Use sendto. - ASSERT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ConnectBinds) { - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ReceiveNotBound) { - char buf[512]; - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, Bind) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Try to bind again. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); - - // Check that we're still bound to the original address. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, BindInUse) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Try to bind again. - EXPECT_THAT(bind(t_, addr_[0], addrlen_), SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UdpSocketTest, ReceiveAfterConnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - // Send from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - for (int i = 0; i < 2; i++) { - // Send from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Disconnect s_. - struct sockaddr addr = {}; - addr.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); - // Connect s_ loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - } -} - -TEST_P(UdpSocketTest, Connect) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[0], addrlen_), 0); - - // Try to bind after connect. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); - - // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Check that peer name changed. - peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); -} - -void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { - struct sockaddr_storage addr = {}; - - // Precondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr any = IN6ADDR_ANY_INIT; - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); - } - - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } - - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } - - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - } - - // Postcondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback; - if (family == AddressFamily::kIpv6) { - loopback = IN6ADDR_LOOPBACK_INIT; - } else { - TestAddress const& v4_mapped_loopback = V4MappedLoopback(); - loopback = reinterpret_cast<const struct sockaddr_in6*>( - &v4_mapped_loopback.addr) - ->sin6_addr; - } - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } - - addrlen = sizeof(addr); - if (port == 0) { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } else { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - } - } -} - -TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } - -TEST_P(UdpSocketTest, ConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - ConnectAny(GetParam(), s_, port); -} - -void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { - struct sockaddr_storage addr = {}; - - socklen_t addrlen = sizeof(addr); - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } - - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - // Now the socket is bound to the loopback address. - - // Disconnect - addrlen = sizeof(addr); - addr.ss_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - DisconnectAfterConnectAny(GetParam(), s_, 0); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - DisconnectAfterConnectAny(GetParam(), s_, port); -} - -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - struct sockaddr_storage baddr = {}; - socklen_t addrlen; - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addr_in->sin_family = AF_INET; - addr_in->sin_port = port; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - addr_in->sin6_scope_id = 0; - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } - ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), - SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, Disconnect) { - for (int i = 0; i < 2; i++) { - // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); - - // Try to disconnect. - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), 0); - } -} - -TEST_P(UdpSocketTest, ConnectBadAddress) { - struct sockaddr addr = {}; - addr.sa_family = addr_[0]->sa_family; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send to a different destination than we're connected to. - char buf[512]; - EXPECT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[1], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Receive the packet. - char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Set t_ to non-blocking. - int opts = 0; - ASSERT_THAT(opts = fcntl(t_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(t_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Receive the packet. - char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send some data to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, SendAndReceiveConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that the data isn't_ received because it was sent from a different - // address than we're connected. - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Connect to loopback:TestPort+1. - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Receive the data. It works because it was sent before the connect. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Send again. This time it should not be received. - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveFrom) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data and sender address. - char received[sizeof(buf)]; - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); -} - -TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(s_, SOMAXCONN), SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(s_, nullptr, nullptr), SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test validates that a read shutdown with pending data allows the read -// to proceed with the data before returning EAGAIN. -TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - char received[512]; - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Verify that we get EWOULDBLOCK when there is nothing to read. - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - const char* buf = "abc"; - EXPECT_THAT(write(t_, buf, 3), SyscallSucceedsWithValue(3)); - - int opts = 0; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(s_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_NE(opts & O_NONBLOCK, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2)); - - // Because we read less than the entire packet length, since it's a packet - // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(s_, received, 1, 0), SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// This test is validating that even after a socket is shutdown if it's -// reconnected it will reset the shutdown state. -TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReadShutdown) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(this->s_, SHUT_RD), SyscallSucceeds()); - }); - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); - t.Join(); - - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, WriteShutdown) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SynchronousReceive) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send some data to s_ from another thread. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - // Receive the data prior to actually starting the other thread. - char received[512]; - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Start the thread. - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT( - sendto(this->t_, buf, sizeof(buf), 0, this->addr_[0], this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - }); - - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(512)); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); - } - - // Receive the data as 3 separate packets. - char received[6 * psize]; - for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(s_, received + i * psize, 3 * psize, 0), - SyscallSucceedsWithValue(psize)); - } - EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Direct writes from t_ to s_. - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(writev(t_, iov, 2), SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(readv(s_, iov, 3), SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_name = addr_[0]; - msg.msg_namelen = addrlen_; - msg.msg_iov = iov; - msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(t_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_iov = iov; - msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(s_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, FIONREADShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, FIONREADWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(s_, str, sizeof(str), 0), - SyscallSucceedsWithValue(sizeof(str))); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); -} - -TEST_P(UdpSocketTest, FIONREAD) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, psize); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(0)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(s_, str, 0, 0), SyscallSucceedsWithValue(0)); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, ErrorQueue) { - char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_ERRQUEUE), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestamp) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - int v = 1; - ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - struct timeval tv = {}; - memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); - - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, TimestampIoctl) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); -} - -TEST_P(UdpSocketTest, TimetstampIoctlNothingRead) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); -} - -// Test that the timestamp accessed via SIOCGSTAMP is still accessible after -// SO_TIMESTAMP is enabled and used to retrieve a timestamp. -TEST_P(UdpSocketTest, TimestampIoctlPersistence) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // Enable SO_TIMESTAMP and send a message. - int v = 1; - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - // There should be a message for SO_TIMESTAMP. - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg = {}; - iovec iov = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - - // The ioctl should return the exact same values as before. - struct timeval tv2 = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds()); - ASSERT_EQ(tv.tv_sec, tv2.tv_sec); - ASSERT_EQ(tv.tv_usec, tv2.tv_usec); -} - INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc new file mode 100644 index 000000000..147978f46 --- /dev/null +++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc @@ -0,0 +1,54 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/udp_socket_test_cases.h" + +#include <arpa/inet.h> +#include <fcntl.h> +#include <linux/errqueue.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +TEST_P(UdpSocketTest, ErrorQueue) { + char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + EXPECT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_ERRQUEUE), + SyscallFailsWithErrno(EAGAIN)); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc new file mode 100644 index 000000000..dc35c2f50 --- /dev/null +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -0,0 +1,1498 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/udp_socket_test_cases.h" + +#include <arpa/inet.h> +#include <fcntl.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +// Gets a pointer to the port component of the given address. +uint16_t* Port(struct sockaddr_storage* addr) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + return &sin->sin_port; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + return &sin6->sin6_port; + } + } + + return nullptr; +} + +void UdpSocketTest::SetUp() { + int type; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); + addrlen_ = sizeof(*sin); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + type = AF_INET6; + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); + addrlen_ = sizeof(*sin6); + if (GetParam() == AddressFamily::kIpv6) { + sin6->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + sin6->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); + + ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); + + memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); + anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); + anyaddr_->sa_family = type; + + if (gvisor::testing::IsRunningOnGvisor()) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { + ports_[i] = TestPort + i; + } + } else { + // When not under gvisor, use utility function to pick port. Assert that + // all ports are different. + std::string error; + for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { + // Find an unused port, we specify port 0 to allow the kernel to provide + // the port. + bool unique = true; + do { + ports_[i] = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable( + 0, AddressFamily::kDualStack, SocketType::kUdp, false)); + ASSERT_GT(ports_[i], 0); + for (size_t j = 0; j < i; ++j) { + if (ports_[j] == ports_[i]) { + unique = false; + break; + } + } + } while (!unique); + } + } + + // Initialize the sockaddrs. + for (size_t i = 0; i < ABSL_ARRAYSIZE(addr_); ++i) { + memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); + + addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); + addr_[i]->sa_family = type; + + switch (type) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + sin->sin_port = htons(ports_[i]); + break; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr_[i]); + sin6->sin6_addr = in6addr_loopback; + sin6->sin6_port = htons(ports_[i]); + break; + } + } + } +} + +TEST_P(UdpSocketTest, Creation) { + int type = AF_INET6; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + } + + int s_; + + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); + EXPECT_THAT(close(s_), SyscallSucceeds()); + + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); + EXPECT_THAT(close(s_), SyscallSucceeds()); + + ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); +} + +TEST_P(UdpSocketTest, Getsockname) { + // Check that we're not bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, anyaddr_, addrlen_), 0); + + // Bind, then check that we get the right address. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); +} + +TEST_P(UdpSocketTest, Getpeername) { + // Check that we're not connected. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Connect, then check that we get the right address. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); +} + +TEST_P(UdpSocketTest, SendNotConnected) { + // Do send & write, they must fail. + char buf[512]; + EXPECT_THAT(send(s_, buf, sizeof(buf), 0), + SyscallFailsWithErrno(EDESTADDRREQ)); + + EXPECT_THAT(write(s_, buf, sizeof(buf)), SyscallFailsWithErrno(EDESTADDRREQ)); + + // Use sendto. + ASSERT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ConnectBinds) { + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ReceiveNotBound) { + char buf[512]; + EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, Bind) { + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Try to bind again. + EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); + + // Check that we're still bound to the original address. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); +} + +TEST_P(UdpSocketTest, BindInUse) { + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Try to bind again. + EXPECT_THAT(bind(t_, addr_[0], addrlen_), SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(UdpSocketTest, ReceiveAfterConnect) { + // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); + + // Get the address s_ was bound to during connect. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + // Send from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Get the address s_ was bound to during connect. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + for (int i = 0; i < 2; i++) { + // Send from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect s_. + struct sockaddr addr = {}; + addr.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); + // Connect s_ loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + } +} + +TEST_P(UdpSocketTest, Connect) { + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, addr_[0], addrlen_), 0); + + // Try to bind after connect. + EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); + + // Try to connect again. + EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + + // Check that peer name changed. + peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); +} + +void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { + struct sockaddr_storage addr = {}; + + // Precondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr any = IN6ADDR_ANY_INIT; + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); + } + + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } + + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + } + + // Postcondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback; + if (family == AddressFamily::kIpv6) { + loopback = IN6ADDR_LOOPBACK_INIT; + } else { + TestAddress const& v4_mapped_loopback = V4MappedLoopback(); + loopback = reinterpret_cast<const struct sockaddr_in6*>( + &v4_mapped_loopback.addr) + ->sin6_addr; + } + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } + + addrlen = sizeof(addr); + if (port == 0) { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } else { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + } + } +} + +TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + ConnectAny(GetParam(), s_, port); +} + +void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { + struct sockaddr_storage addr = {}; + + socklen_t addrlen = sizeof(addr); + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + // Now the socket is bound to the loopback address. + + // Disconnect + addrlen = sizeof(addr); + addr.ss_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + DisconnectAfterConnectAny(GetParam(), s_, 0); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + DisconnectAfterConnectAny(GetParam(), s_, port); +} + +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { + struct sockaddr_storage baddr = {}; + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addr_in->sin_family = AF_INET; + addr_in->sin_port = port; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + addr_in->sin6_scope_id = 0; + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } + ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), + SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + // If the socket is bound to ANY and connected to a loopback address, + // getsockname() has to return the loopback address. + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + struct sockaddr_storage baddr = {}; + socklen_t addrlen; + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addr_in->sin_family = AF_INET; + addr_in->sin_port = port; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + addr_in->sin6_scope_id = 0; + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } + ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), + SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), 0); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = addr_[0]->sa_family; + ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Send to a different destination than we're connected to. + char buf[512]; + EXPECT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[1], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + + // Bind t_ to loopback:TestPort+1. + ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from s_ to t_. + ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); + // Receive the packet. + char received[3]; + EXPECT_THAT(read(t_, received, sizeof(received)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + + // Bind t_ to loopback:TestPort+1. + ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Set t_ to non-blocking. + int opts = 0; + ASSERT_THAT(opts = fcntl(t_, F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(t_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from s_ to t_. + ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); + // Receive the packet. + char received[3]; + EXPECT_THAT(read(t_, received, sizeof(received)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(read(t_, received, sizeof(received)), + SyscallFailsWithErrno(EAGAIN)); +} + +TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { + // Bind s_ to loopback. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Send some data to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, SendAndReceiveConnected) { + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + + // Bind t_ to loopback:TestPort+1. + ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Send some data from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveFromNotConnected) { + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + + // Bind t_ to loopback:TestPort+2. + ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); + + // Send some data from t_ to s_. + char buf[512]; + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that the data isn't_ received because it was sent from a different + // address than we're connected. + EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveBeforeConnect) { + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Bind t_ to loopback:TestPort+2. + ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); + + // Send some data from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Connect to loopback:TestPort+1. + ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + + // Receive the data. It works because it was sent before the connect. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Send again. This time it should not be received. + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveFrom) { + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + + // Bind t_ to loopback:TestPort+1. + ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Send some data from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data and sender address. + char received[sizeof(buf)]; + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, + reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); +} + +TEST_P(UdpSocketTest, Listen) { + ASSERT_THAT(listen(s_, SOMAXCONN), SyscallFailsWithErrno(EOPNOTSUPP)); +} + +TEST_P(UdpSocketTest, Accept) { + ASSERT_THAT(accept(s_, nullptr, nullptr), SyscallFailsWithErrno(EOPNOTSUPP)); +} + +// This test validates that a read shutdown with pending data allows the read +// to proceed with the data before returning EAGAIN. +TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { + char received[512]; + + // Bind t_ to loopback:TestPort+2. + ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + + // Verify that we get EWOULDBLOCK when there is nothing to read. + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + const char* buf = "abc"; + EXPECT_THAT(write(t_, buf, 3), SyscallSucceedsWithValue(3)); + + int opts = 0; + ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(s_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); + ASSERT_NE(opts & O_NONBLOCK, 0); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + // We should get the data even though read has been shutdown. + EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2)); + + // Because we read less than the entire packet length, since it's a packet + // based socket any subsequent reads should return EWOULDBLOCK. + EXPECT_THAT(recv(s_, received, 1, 0), SyscallFailsWithErrno(EWOULDBLOCK)); +} + +// This test is validating that even after a socket is shutdown if it's +// reconnected it will reset the shutdown state. +TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { + char received[512]; + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReadShutdown) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + + char received[512]; + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + + char received[512]; + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then shutdown from another thread. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + EXPECT_THAT(shutdown(this->s_, SHUT_RD), SyscallSucceeds()); + }); + EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); + t.Join(); + + EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, WriteShutdown) { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SynchronousReceive) { + // Bind s_ to loopback. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Send some data to s_ from another thread. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + // Receive the data prior to actually starting the other thread. + char received[512]; + EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Start the thread. + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + ASSERT_THAT( + sendto(this->t_, buf, sizeof(buf), 0, this->addr_[0], this->addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + }); + + EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(512)); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Send 3 packets from t_ to s_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(psize)); + } + + // Receive the data as 3 separate packets. + char received[6 * psize]; + for (int i = 0; i < 3; ++i) { + EXPECT_THAT(recv(s_, received + i * psize, 3 * psize, 0), + SyscallSucceedsWithValue(psize)); + } + EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Direct writes from t_ to s_. + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + // Send 2 packets from t_ to s_, where each packet's data consists of 2 + // discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(writev(t_, iov, 2), SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(readv(s_, iov, 3), SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Send 2 packets from t_ to s_, where each packet's data consists of 2 + // discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_name = addr_[0]; + msg.msg_namelen = addrlen_; + msg.msg_iov = iov; + msg.msg_iovlen = 2; + ASSERT_THAT(sendmsg(t_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 3; + ASSERT_THAT(recvmsg(s_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, FIONREADShutdown) { + int n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, FIONREADWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(s_, str, sizeof(str), 0), + SyscallSucceedsWithValue(sizeof(str))); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); +} + +TEST_P(UdpSocketTest, Fionread) { + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from t_ to s_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(psize)); + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, psize); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from t_ to s_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(0)); + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Bind s_ to loopback:TestPort. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(s_, str, 0, 0), SyscallSucceedsWithValue(0)); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoTimestamp) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + int v = 1; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from t_ to s_. + ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); + + struct timeval tv = {}; + memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); + + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // There should be nothing to get via ioctl. + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, WriteShutdownNotConnected) { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, TimestampIoctl) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from t_ to s_. + ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); +} + +TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + struct timeval tv = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); +} + +// Test that the timestamp accessed via SIOCGSTAMP is still accessible after +// SO_TIMESTAMP is enabled and used to retrieve a timestamp. +TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from t_ to s_. + ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // Enable SO_TIMESTAMP and send a message. + int v = 1; + EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + + // There should be a message for SO_TIMESTAMP. + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg = {}; + iovec iov = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + + // The ioctl should return the exact same values as before. + struct timeval tv2 = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_EQ(tv.tv_sec, tv2.tv_sec); + ASSERT_EQ(tv.tv_usec, tv2.tv_usec); +} + +// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on +// outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SetAndReceiveTOS) { + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + ASSERT_THAT( + setsockopt(s_, recv_level, recv_type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Set socket TOS. + int sent_level = recv_level; + int sent_type = IP_TOS; + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + } + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + ASSERT_THAT( + setsockopt(t_, sent_level, sent_type, &sent_tos, sizeof(sent_tos)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_type == IPV6_TCLASS) { + cmsg_data_len = sizeof(int); + } + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = received_cmsgbuf.size(); + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the +// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SendAndReceiveTOS) { + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + int recv_opt = kSockOptOn; + ASSERT_THAT( + setsockopt(s_, recv_level, recv_type, &recv_opt, sizeof(recv_opt)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + int sent_level = recv_level; + int sent_type = IP_TOS; + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + cmsg_data_len = sizeof(int); + } + std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + sent_msg.msg_control = &sent_cmsgbuf[0]; + sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + + // Manually add control message. + struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); + sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); + sent_cmsg->cmsg_level = sent_level; + sent_cmsg->cmsg_type = sent_type; + *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; + + ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h new file mode 100644 index 000000000..2fd79d99e --- /dev/null +++ b/test/syscalls/linux/udp_socket_test_cases.h @@ -0,0 +1,74 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ +#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// The initial port to be be used on gvisor. +constexpr int TestPort = 40000; + +// Fixture for tests parameterized by the address family to use (AF_INET and +// AF_INET6) when creating sockets. +class UdpSocketTest + : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { + protected: + // Creates two sockets that will be used by test cases. + void SetUp() override; + + // Closes the sockets created by SetUp(). + void TearDown() override { + EXPECT_THAT(close(s_), SyscallSucceeds()); + EXPECT_THAT(close(t_), SyscallSucceeds()); + + for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { + ASSERT_NO_ERRNO(FreeAvailablePort(ports_[i])); + } + } + + // First UDP socket. + int s_; + + // Second UDP socket. + int t_; + + // The length of the socket address. + socklen_t addrlen_; + + // Initialized address pointing to loopback and port TestPort+i. + struct sockaddr* addr_[3]; + + // Initialize "any" address. + struct sockaddr* anyaddr_; + + // Used ports. + int ports_[3]; + + private: + // Storage for the loopback addresses. + struct sockaddr_storage addr_storage_[3]; + + // Storage for the "any" address. + struct sockaddr_storage anyaddr_storage_; +}; + +} // namespace testing +} // namespace gvisor + +#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/unix_domain_socket_test_util.cc b/test/syscalls/linux/unix_domain_socket_test_util.cc index 7fb9eed8d..b05ab2900 100644 --- a/test/syscalls/linux/unix_domain_socket_test_util.cc +++ b/test/syscalls/linux/unix_domain_socket_test_util.cc @@ -15,6 +15,7 @@ #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include <sys/un.h> + #include <vector> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/unix_domain_socket_test_util.h b/test/syscalls/linux/unix_domain_socket_test_util.h index 5eca0b7f0..b8073db17 100644 --- a/test/syscalls/linux/unix_domain_socket_test_util.h +++ b/test/syscalls/linux/unix_domain_socket_test_util.h @@ -16,6 +16,7 @@ #define GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_ #include <string> + #include "test/syscalls/linux/socket_test_util.h" namespace gvisor { diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index 80716859a..12b925a51 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -20,6 +20,7 @@ #include <time.h> #include <unistd.h> #include <utime.h> + #include <string> #include "absl/time/time.h" diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc index 40c0014b9..ce1899f45 100644 --- a/test/syscalls/linux/vdso_clock_gettime.cc +++ b/test/syscalls/linux/vdso_clock_gettime.cc @@ -17,6 +17,7 @@ #include <syscall.h> #include <time.h> #include <unistd.h> + #include <map> #include <string> #include <utility> diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc new file mode 100644 index 000000000..75740238c --- /dev/null +++ b/test/syscalls/linux/xattr.cc @@ -0,0 +1,488 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <limits.h> +#include <sys/types.h> +#include <sys/xattr.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/file_base.h" +#include "test/util/capability_util.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class XattrTest : public FileTest {}; + +TEST_F(XattrTest, XattrNullName) { + const char* path = test_file_name_.c_str(); + + EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), + SyscallFailsWithErrno(EFAULT)); +} + +TEST_F(XattrTest, XattrEmptyName) { + const char* path = test_file_name_.c_str(); + + EXPECT_THAT(setxattr(path, "", nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(ERANGE)); + EXPECT_THAT(getxattr(path, "", nullptr, 0), SyscallFailsWithErrno(ERANGE)); +} + +TEST_F(XattrTest, XattrLargeName) { + const char* path = test_file_name_.c_str(); + std::string name = "user."; + name += std::string(XATTR_NAME_MAX - name.length(), 'a'); + + // TODO(b/127675828): Support setxattr and getxattr. + if (!IsRunningOnGvisor()) { + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallSucceedsWithValue(0)); + } + + name += "a"; + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(ERANGE)); + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallFailsWithErrno(ERANGE)); +} + +TEST_F(XattrTest, XattrInvalidPrefix) { + const char* path = test_file_name_.c_str(); + std::string name(XATTR_NAME_MAX, 'a'); + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EOPNOTSUPP)); + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +TEST_F(XattrTest, XattrReadOnly) { + // Drop capabilities that allow us to override file and directory permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + size_t size = sizeof(val); + + // TODO(b/127675828): Support setxattr and getxattr. + if (!IsRunningOnGvisor()) { + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallSucceeds()); + } + + ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR)); + + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EACCES)); + + // TODO(b/127675828): Support setxattr and getxattr. + if (!IsRunningOnGvisor()) { + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, size), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, val); + } +} + +TEST_F(XattrTest, XattrWriteOnly) { + // Drop capabilities that allow us to override file and directory permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR)); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + size_t size = sizeof(val); + + // TODO(b/127675828): Support setxattr and getxattr. + if (!IsRunningOnGvisor()) { + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallSucceeds()); + } + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES)); +} + +TEST_F(XattrTest, XattrTrustedWithNonadmin) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.abc"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, XattrOnDirectory) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + char name[] = "user.abc"; + EXPECT_THAT(setxattr(dir.path().c_str(), name, NULL, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(getxattr(dir.path().c_str(), name, NULL, 0), + SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, XattrOnSymlink) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(dir.path(), test_file_name_)); + char name[] = "user.abc"; + EXPECT_THAT(setxattr(link.path().c_str(), name, NULL, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(getxattr(link.path().c_str(), name, NULL, 0), + SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, XattrOnInvalidFileTypes) { + char name[] = "user.abc"; + + char char_device[] = "/dev/zero"; + EXPECT_THAT(setxattr(char_device, name, NULL, 0, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(getxattr(char_device, name, NULL, 0), + SyscallFailsWithErrno(ENODATA)); + + // Use tmpfs, where creation of named pipes is supported. + const std::string fifo = NewTempAbsPathInDir("/dev/shm"); + const char* path = fifo.c_str(); + EXPECT_THAT(mknod(path, S_IFIFO | S_IRUSR | S_IWUSR, 0), SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, NULL, 0, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(getxattr(path, name, NULL, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + std::vector<char> val = {'a', 'a'}; + size_t size = 1; + EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + std::vector<char> expected_buf = {'a', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, SetxattrZeroSize) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds()); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, XATTR_SIZE_MAX), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(buf, '-'); +} + +TEST_F(XattrTest, SetxattrSizeTooLarge) { + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + + // Note that each particular fs implementation may stipulate a lower size + // limit, in which case we actually may fail (e.g. error with ENOSPC) for + // some sizes under XATTR_SIZE_MAX. + size_t size = XATTR_SIZE_MAX + 1; + std::vector<char> val(size); + EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), + SyscallFailsWithErrno(E2BIG)); + + // TODO(b/127675828): Support setxattr and getxattr. + if (!IsRunningOnGvisor()) { + EXPECT_THAT(getxattr(path, name, nullptr, 0), + SyscallFailsWithErrno(ENODATA)); + } +} + +TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0), + SyscallFailsWithErrno(EFAULT)); + + // TODO(b/127675828): Support setxattr and getxattr. + if (!IsRunningOnGvisor()) { + EXPECT_THAT(getxattr(path, name, nullptr, 0), + SyscallFailsWithErrno(ENODATA)); + } +} + +TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + std::vector<char> val(XATTR_SIZE_MAX + 1); + std::fill(val.begin(), val.end(), 'a'); + size_t size = 1; + EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + std::vector<char> expected_buf = {'a', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), size), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, SetxattrReplaceWithSmaller) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + std::vector<char> val = {'a', 'a'}; + EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + std::vector<char> expected_buf = {'a', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(1)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, SetxattrReplaceWithLarger) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + std::vector<char> val = {'a', 'a'}; + EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(2)); + EXPECT_EQ(buf, val); +} + +TEST_F(XattrTest, SetxattrCreateFlag) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), + SyscallFailsWithErrno(EEXIST)); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, SetxattrReplaceFlag) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), + SyscallFailsWithErrno(ENODATA)); + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), + SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, SetxattrInvalidFlags) { + const char* path = test_file_name_.c_str(); + int invalid_flags = 0xff; + EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(XattrTest, Getxattr) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + int val = 1234; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + int buf = 0; + EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, val); +} + +TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + std::vector<char> val = {'a', 'a'}; + size_t size = val.size(); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, 1), SyscallFailsWithErrno(ERANGE)); + EXPECT_EQ(buf, '-'); +} + +TEST_F(XattrTest, GetxattrSizeLargerThanValue) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds()); + + std::vector<char> buf(XATTR_SIZE_MAX); + std::fill(buf.begin(), buf.end(), '-'); + std::vector<char> expected_buf = buf; + expected_buf[0] = 'a'; + EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), + SyscallSucceedsWithValue(1)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, GetxattrZeroSize) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, 0), + SyscallSucceedsWithValue(sizeof(val))); + EXPECT_EQ(buf, '-'); +} + +TEST_F(XattrTest, GetxattrSizeTooLarge) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf(XATTR_SIZE_MAX + 1); + std::fill(buf.begin(), buf.end(), '-'); + std::vector<char> expected_buf = buf; + expected_buf[0] = 'a'; + EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), + SyscallSucceedsWithValue(sizeof(val))); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, GetxattrNullValue) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, size), + SyscallFailsWithErrno(EFAULT)); +} + +TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { + // TODO(b/127675828): Support setxattr and getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + char name[] = "user.abc"; + char val = 'a'; + size_t size = sizeof(val); + // Set value with zero size. + EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds()); + // Get value with nonzero size. + EXPECT_THAT(getxattr(path, name, nullptr, size), SyscallSucceedsWithValue(0)); + + // Set value with nonzero size. + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + // Get value with zero size. + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size)); +} + +TEST_F(XattrTest, GetxattrNonexistentName) { + // TODO(b/127675828): Support getxattr. + SKIP_IF(IsRunningOnGvisor()); + + const char* path = test_file_name_.c_str(); + std::string name = "user.nonexistent"; + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallFailsWithErrno(ENODATA)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index c1e9ce22c..b9fd885ff 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" "gvisor.dev/gvisor/test/syscalls/gtest" + "gvisor.dev/gvisor/test/uds" ) // Location of syscall tests, relative to the repo root. @@ -45,11 +46,14 @@ var ( debug = flag.Bool("debug", false, "enable debug logs") strace = flag.Bool("strace", false, "enable strace logs") platform = flag.String("platform", "ptrace", "platform to run on") + network = flag.String("network", "none", "network stack to run on (sandbox, host, none)") useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") parallel = flag.Bool("parallel", false, "run tests in parallel") runscPath = flag.String("runsc", "", "path to runsc binary") + + addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") ) // runTestCaseNative runs the test case directly on the host machine. @@ -86,6 +90,19 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // intepret them. env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + if *addUDSTree { + socketDir, cleanup, err := uds.CreateSocketTree("/tmp") + if err != nil { + t.Fatalf("failed to create socket tree: %v", err) + } + defer cleanup() + + env = append(env, "TEST_UDS_TREE="+socketDir) + // On Linux, the concept of "attach" location doesn't exist. + // Just pass the same path to make these test identical. + env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) + } + cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName()) cmd.Env = env cmd.Stdout = os.Stdout @@ -96,101 +113,39 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { } } -// runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { - rootDir, err := testutil.SetupRootDir() +// runRunsc runs spec in runsc in a standard test configuration. +// +// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. +// +// Returns an error if the sandboxed application exits non-zero. +func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { + bundleDir, err := testutil.SetupBundleDir(spec) if err != nil { - t.Fatalf("SetupRootDir failed: %v", err) - } - defer os.RemoveAll(rootDir) - - // Run a new container with the test executable and filter for the - // given test suite and name. - spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) - - // Mark the root as writeable, as some tests attempt to - // write to the rootfs, and expect EACCES, not EROFS. - spec.Root.Readonly = false - - // Test spec comes with pre-defined mounts that we don't want. Reset it. - spec.Mounts = nil - if *useTmpfs { - // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on - // features only available in gVisor's internal tmpfs may fail. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp", - Type: "tmpfs", - }) - } else { - // Use a gofer-backed directory as '/tmp'. - // - // Tests might be running in parallel, so make sure each has a - // unique test temp dir. - // - // Some tests (e.g., sticky) access this mount from other - // users, so make sure it is world-accessible. - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") - if err != nil { - t.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - if err := os.Chmod(tmpDir, 0777); err != nil { - t.Fatalf("could not chmod temp dir: %v", err) - } - - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: "/tmp", - Source: tmpDir, - }) + return fmt.Errorf("SetupBundleDir failed: %v", err) } + defer os.RemoveAll(bundleDir) - // Set environment variable that indicates we are - // running in gVisor and with the given platform. - platformVar := "TEST_ON_GVISOR" - env := append(os.Environ(), platformVar+"="+*platform) - - // Remove env variables that cause the gunit binary to write output - // files, since they will stomp on eachother, and on the output files - // from this go test. - env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) - - // Remove shard env variables so that the gunit binary does not try to - // intepret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) - - // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to - // be backed by tmpfs. - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = "TEST_TMPDIR=/tmp" - break - } - } - - spec.Process.Env = env - - bundleDir, err := testutil.SetupBundleDir(spec) + rootDir, err := testutil.SetupRootDir() if err != nil { - t.Fatalf("SetupBundleDir failed: %v", err) + return fmt.Errorf("SetupRootDir failed: %v", err) } - defer os.RemoveAll(bundleDir) + defer os.RemoveAll(rootDir) + name := tc.FullName() id := testutil.UniqueContainerID() - log.Infof("Running test %q in container %q", tc.FullName(), id) + log.Infof("Running test %q in container %q", name, id) specutils.LogSpec(spec) args := []string{ - "-platform", *platform, "-root", rootDir, - "-file-access", *fileAccess, - "-network=none", + "-network", *network, "-log-format=text", "-TESTONLY-unsafe-nonroot=true", "-net-raw=true", fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), "-watchdog-action=panic", + "-platform", *platform, + "-file-access", *fileAccess, } if *overlay { args = append(args, "-overlay") @@ -201,14 +156,18 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { if *strace { args = append(args, "-strace") } + if *addUDSTree { + args = append(args, "-fsgofer-host-uds") + } + if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - tdir := filepath.Join(outDir, strings.Replace(tc.FullName(), "/", "_", -1)) + tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1)) if err := os.MkdirAll(tdir, 0755); err != nil { - t.Fatalf("could not create test dir: %v", err) + return fmt.Errorf("could not create test dir: %v", err) } debugLogDir, err := ioutil.TempDir(tdir, "runsc") if err != nil { - t.Fatalf("could not create temp dir: %v", err) + return fmt.Errorf("could not create temp dir: %v", err) } debugLogDir += "/" log.Infof("runsc logs: %s", debugLogDir) @@ -248,38 +207,172 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { if !ok { return } - t.Errorf("%s: Got signal: %v", tc.FullName(), s) + log.Warningf("%s: Got signal: %v", name, s) done := make(chan bool) - go func() { - dArgs := append(args, "-alsologtostderr=true", "debug", "--stacks", id) + dArgs := append([]string{}, args...) + dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) + go func(dArgs []string) { cmd := exec.Command(*runscPath, dArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Run() done <- true - }() + }(dArgs) - timeout := time.Tick(3 * time.Second) + timeout := time.After(3 * time.Second) select { case <-timeout: - t.Logf("runsc debug --stacks is timeouted") + log.Infof("runsc debug --stacks is timeouted") case <-done: } - t.Logf("Send SIGTERM to the sandbox process") - dArgs := append(args, "debug", + log.Warningf("Send SIGTERM to the sandbox process") + dArgs = append(args, "debug", fmt.Sprintf("--signal=%d", syscall.SIGTERM), id) - cmd = exec.Command(*runscPath, dArgs...) + cmd := exec.Command(*runscPath, dArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Run() }() - if err = cmd.Run(); err != nil { - t.Errorf("test %q exited with status %v, want 0", tc.FullName(), err) - } + + err = cmd.Run() + signal.Stop(sig) close(sig) + + return err +} + +// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing. +func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { + socketDir, cleanup, err := uds.CreateSocketTree("/tmp") + if err != nil { + return nil, fmt.Errorf("failed to create socket tree: %v", err) + } + + // Standard access to entire tree. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets", + Source: socketDir, + Type: "bind", + }) + + // Individial attach points for each socket to test mounts that attach + // directly to the sockets. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/stream/echo", + Source: filepath.Join(socketDir, "stream/echo"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/stream/nonlistening", + Source: filepath.Join(socketDir, "stream/nonlistening"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/seqpacket/echo", + Source: filepath.Join(socketDir, "seqpacket/echo"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/seqpacket/nonlistening", + Source: filepath.Join(socketDir, "seqpacket/nonlistening"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/dgram/null", + Source: filepath.Join(socketDir, "dgram/null"), + Type: "bind", + }) + + spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets") + spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach") + + return cleanup, nil +} + +// runsTestCaseRunsc runs the test case in runsc. +func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { + // Run a new container with the test executable and filter for the + // given test suite and name. + spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + + // Mark the root as writeable, as some tests attempt to + // write to the rootfs, and expect EACCES, not EROFS. + spec.Root.Readonly = false + + // Test spec comes with pre-defined mounts that we don't want. Reset it. + spec.Mounts = nil + if *useTmpfs { + // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on + // features only available in gVisor's internal tmpfs may fail. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp", + Type: "tmpfs", + }) + } else { + // Use a gofer-backed directory as '/tmp'. + // + // Tests might be running in parallel, so make sure each has a + // unique test temp dir. + // + // Some tests (e.g., sticky) access this mount from other + // users, so make sure it is world-accessible. + tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") + if err != nil { + t.Fatalf("could not create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + if err := os.Chmod(tmpDir, 0777); err != nil { + t.Fatalf("could not chmod temp dir: %v", err) + } + + spec.Mounts = append(spec.Mounts, specs.Mount{ + Type: "bind", + Destination: "/tmp", + Source: tmpDir, + }) + } + + // Set environment variables that indicate we are + // running in gVisor with the given platform and network. + platformVar := "TEST_ON_GVISOR" + networkVar := "GVISOR_NETWORK" + env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) + + // Remove env variables that cause the gunit binary to write output + // files, since they will stomp on eachother, and on the output files + // from this go test. + env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) + + // Remove shard env variables so that the gunit binary does not try to + // intepret them. + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + + // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to + // be backed by tmpfs. + for i, kv := range env { + if strings.HasPrefix(kv, "TEST_TMPDIR=") { + env[i] = "TEST_TMPDIR=/tmp" + break + } + } + + spec.Process.Env = env + + if *addUDSTree { + cleanup, err := setupUDSTree(spec) + if err != nil { + t.Fatalf("error creating UDS tree: %v", err) + } + defer cleanup() + } + + if err := runRunsc(tc, spec); err != nil { + t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) + } } // filterEnv returns an environment with the blacklisted variables removed. diff --git a/test/uds/BUILD b/test/uds/BUILD new file mode 100644 index 000000000..a3843e699 --- /dev/null +++ b/test/uds/BUILD @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +go_library( + name = "uds", + testonly = 1, + srcs = ["uds.go"], + importpath = "gvisor.dev/gvisor/test/uds", + deps = [ + "//pkg/log", + "//pkg/unet", + ], +) diff --git a/test/uds/uds.go b/test/uds/uds.go new file mode 100644 index 000000000..b714c61b0 --- /dev/null +++ b/test/uds/uds.go @@ -0,0 +1,228 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uds contains helpers for testing external UDS functionality. +package uds + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "syscall" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +// createEchoSocket creates a socket that echoes back anything received. +// +// Only works for stream, seqpacket sockets. +func createEchoSocket(path string, protocol int) (cleanup func(), err error) { + fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + if err != nil { + return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err) + } + + if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err) + } + + if err := syscall.Listen(fd, 0); err != nil { + return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err) + } + + server, err := unet.NewServerSocket(fd) + if err != nil { + return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err) + } + + acceptAndEchoOne := func() error { + s, err := server.Accept() + if err != nil { + return fmt.Errorf("failed to accept: %v", err) + } + defer s.Close() + + for { + buf := make([]byte, 512) + for { + n, err := s.Read(buf) + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("failed to read: %d, %v", n, err) + } + + n, err = s.Write(buf[:n]) + if err != nil { + return fmt.Errorf("failed to write: %d, %v", n, err) + } + } + } + } + + go func() { + for { + if err := acceptAndEchoOne(); err != nil { + log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err) + return + } + } + }() + + cleanup = func() { + if err := server.Close(); err != nil { + log.Warningf("Failed to close echo(%d) socket: %v", protocol, err) + } + } + + return cleanup, nil +} + +// createNonListeningSocket creates a socket that is bound but not listening. +// +// Only relevant for stream, seqpacket sockets. +func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) { + fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + if err != nil { + return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err) + } + + if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err) + } + + cleanup = func() { + if err := syscall.Close(fd); err != nil { + log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err) + } + } + + return cleanup, nil +} + +// createNullSocket creates a socket that reads anything received. +// +// Only works for dgram sockets. +func createNullSocket(path string, protocol int) (cleanup func(), err error) { + fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + if err != nil { + return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err) + } + + if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err) + } + + s, err := unet.NewSocket(fd) + if err != nil { + return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err) + } + + go func() { + buf := make([]byte, 512) + for { + n, err := s.Read(buf) + if err != nil { + log.Warningf("failed to read: %d, %v", n, err) + return + } + } + }() + + cleanup = func() { + if err := s.Close(); err != nil { + log.Warningf("Failed to close null(%d) socket: %v", protocol, err) + } + } + + return cleanup, nil +} + +type socketCreator func(path string, proto int) (cleanup func(), err error) + +// CreateSocketTree creates a local tree of unix domain sockets for use in +// testing: +// * /stream/echo +// * /stream/nonlistening +// * /seqpacket/echo +// * /seqpacket/nonlistening +// * /dgram/null +func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { + dir, err = ioutil.TempDir(baseDir, "sockets") + if err != nil { + return "", nil, fmt.Errorf("error creating temp dir: %v", err) + } + + var protocols = []struct { + protocol int + name string + sockets map[string]socketCreator + }{ + { + protocol: syscall.SOCK_STREAM, + name: "stream", + sockets: map[string]socketCreator{ + "echo": createEchoSocket, + "nonlistening": createNonListeningSocket, + }, + }, + { + protocol: syscall.SOCK_SEQPACKET, + name: "seqpacket", + sockets: map[string]socketCreator{ + "echo": createEchoSocket, + "nonlistening": createNonListeningSocket, + }, + }, + { + protocol: syscall.SOCK_DGRAM, + name: "dgram", + sockets: map[string]socketCreator{ + "null": createNullSocket, + }, + }, + } + + var cleanups []func() + for _, proto := range protocols { + protoDir := filepath.Join(dir, proto.name) + if err := os.Mkdir(protoDir, 0755); err != nil { + return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err) + } + + for name, fn := range proto.sockets { + path := filepath.Join(protoDir, name) + cleanup, err := fn(path, proto.protocol) + if err != nil { + return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err) + } + + cleanups = append(cleanups, cleanup) + } + } + + cleanup = func() { + for _, c := range cleanups { + c() + } + + os.RemoveAll(dir) + } + + return dir, cleanup, nil +} diff --git a/test/util/BUILD b/test/util/BUILD index 5d2a9cc2c..cbc728159 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -232,13 +232,21 @@ cc_library( cc_library( name = "test_util", testonly = 1, - srcs = ["test_util.cc"], + srcs = [ + "test_util.cc", + ] + select_for_linux( + [ + "test_util_impl.cc", + "test_util_runfiles.cc", + ], + ), hdrs = ["test_util.h"], deps = [ ":fs_util", ":logging", ":posix_error", ":save_util", + "@bazel_tools//tools/cpp/runfiles", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index f7d231b14..042cec94a 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -105,6 +105,15 @@ PosixErrorOr<struct stat> Stat(absl::string_view path) { return stat_buf; } +PosixErrorOr<struct stat> Lstat(absl::string_view path) { + struct stat stat_buf; + int res = lstat(std::string(path).c_str(), &stat_buf); + if (res < 0) { + return PosixError(errno, absl::StrCat("lstat ", path)); + } + return stat_buf; +} + PosixErrorOr<struct stat> Fstat(int fd) { struct stat stat_buf; int res = fstat(fd, &stat_buf); @@ -127,7 +136,7 @@ PosixErrorOr<bool> Exists(absl::string_view path) { } PosixErrorOr<bool> IsDirectory(absl::string_view path) { - ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Stat(path)); + ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Lstat(path)); if (S_ISDIR(stat_buf.st_mode)) { return true; } @@ -163,6 +172,26 @@ PosixError Chmod(absl::string_view path, int mode) { return NoError(); } +PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, + dev_t dev) { + int res = mknodat(dfd.get(), std::string(path).c_str(), mode, dev); + if (res < 0) { + return PosixError(errno, absl::StrCat("mknod ", path)); + } + + return NoError(); +} + +PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, + int flags) { + int res = unlinkat(dfd.get(), std::string(path).c_str(), flags); + if (res < 0) { + return PosixError(errno, absl::StrCat("unlink ", path)); + } + + return NoError(); +} + PosixError Mkdir(absl::string_view path, int mode) { int res = mkdir(std::string(path).c_str(), mode); if (res < 0) { diff --git a/test/util/fs_util.h b/test/util/fs_util.h index e5b555891..ee1b341d7 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -21,6 +21,7 @@ #include <unistd.h> #include "absl/strings/string_view.h" +#include "test/util/file_descriptor.h" #include "test/util/posix_error.h" namespace gvisor { @@ -44,6 +45,14 @@ PosixError Delete(absl::string_view path); // Changes the mode of a file or returns an error. PosixError Chmod(absl::string_view path, int mode); +// Create a special or ordinary file. +PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, + dev_t dev); + +// Unlink the file. +PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, + int flags); + // Truncates a file to the given length or returns an error. PosixError Truncate(absl::string_view path, int length); diff --git a/test/util/fs_util_test.cc b/test/util/fs_util_test.cc index 2a200320a..657b6a46e 100644 --- a/test/util/fs_util_test.cc +++ b/test/util/fs_util_test.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/util/fs_util.h" + #include <errno.h> + #include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/util/mount_util.h b/test/util/mount_util.h index 38ec6c8a1..484de560e 100644 --- a/test/util/mount_util.h +++ b/test/util/mount_util.h @@ -17,6 +17,7 @@ #include <errno.h> #include <sys/mount.h> + #include <functional> #include <string> diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc index 95f5f3b4f..8b676751b 100644 --- a/test/util/multiprocess_util.cc +++ b/test/util/multiprocess_util.cc @@ -14,6 +14,7 @@ #include "test/util/multiprocess_util.h" +#include <asm/unistd.h> #include <errno.h> #include <fcntl.h> #include <signal.h> @@ -30,11 +31,12 @@ namespace gvisor { namespace testing { -PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, - const ExecveArray& argv, - const ExecveArray& envv, - const std::function<void()>& fn, pid_t* child, - int* execve_errno) { +namespace { + +// exec_fn wraps a variant of the exec family, e.g. execve or execveat. +PosixErrorOr<Cleanup> ForkAndExecHelper(const std::function<void()>& exec_fn, + const std::function<void()>& fn, + pid_t* child, int* execve_errno) { int pfds[2]; int ret = pipe2(pfds, O_CLOEXEC); if (ret < 0) { @@ -76,7 +78,9 @@ PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, fn(); } - execve(filename.c_str(), argv.get(), envv.get()); + // Call variant of exec function. + exec_fn(); + int error = errno; if (WriteFd(pfds[1], &error, sizeof(error)) != sizeof(error)) { // We can't do much if the write fails, but we can at least exit with a @@ -116,6 +120,36 @@ PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, return std::move(cleanup); } +} // namespace + +PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, + const ExecveArray& argv, + const ExecveArray& envv, + const std::function<void()>& fn, pid_t* child, + int* execve_errno) { + char* const* argv_data = argv.get(); + char* const* envv_data = envv.get(); + const std::function<void()> exec_fn = [=] { + execve(filename.c_str(), argv_data, envv_data); + }; + return ForkAndExecHelper(exec_fn, fn, child, execve_errno); +} + +PosixErrorOr<Cleanup> ForkAndExecveat(const int32_t dirfd, + const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, const int flags, + const std::function<void()>& fn, + pid_t* child, int* execve_errno) { + char* const* argv_data = argv.get(); + char* const* envv_data = envv.get(); + const std::function<void()> exec_fn = [=] { + syscall(__NR_execveat, dirfd, pathname.c_str(), argv_data, envv_data, + flags); + }; + return ForkAndExecHelper(exec_fn, fn, child, execve_errno); +} + PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) { pid_t pid = fork(); if (pid == 0) { diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h index 0aecd3439..61526b4e7 100644 --- a/test/util/multiprocess_util.h +++ b/test/util/multiprocess_util.h @@ -102,6 +102,22 @@ inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, return ForkAndExec(filename, argv, envv, [] {}, child, execve_errno); } +// Equivalent to ForkAndExec, except using dirfd and flags with execveat. +PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, int flags, + const std::function<void()>& fn, + pid_t* child, int* execve_errno); + +inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, + const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, int flags, + pid_t* child, int* execve_errno) { + return ForkAndExecveat( + dirfd, pathname, argv, envv, flags, [] {}, child, execve_errno); +} + // Calls fn in a forked subprocess and returns the exit status of the // subprocess. // diff --git a/test/util/posix_error_test.cc b/test/util/posix_error_test.cc index d67270842..bf9465abb 100644 --- a/test/util/posix_error_test.cc +++ b/test/util/posix_error_test.cc @@ -15,6 +15,7 @@ #include "test/util/posix_error.h" #include <errno.h> + #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c0fd9a095..c01f916aa 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -24,6 +24,14 @@ namespace gvisor { namespace testing { PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { + PosixErrorOr<int> n = SlaveID(master); + if (!n.ok()) { + return PosixErrorOr<FileDescriptor>(n.error()); + } + return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); +} + +PosixErrorOr<int> SlaveID(const FileDescriptor& master) { // Get pty index. int n; int ret = ioctl(master.get(), TIOCGPTN, &n); @@ -38,7 +46,7 @@ PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { return PosixError(errno, "ioctl(TIOSPTLCK) failed"); } - return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); + return n; } } // namespace testing diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 367b14f15..0722da379 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -24,6 +24,9 @@ namespace testing { // Opens the slave end of the passed master as R/W and nonblocking. PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Get the number of the slave end of the master. +PosixErrorOr<int> SlaveID(const FileDescriptor& master); + } // namespace testing } // namespace gvisor diff --git a/test/util/rlimit_util.cc b/test/util/rlimit_util.cc index 684253f78..d7bfc1606 100644 --- a/test/util/rlimit_util.cc +++ b/test/util/rlimit_util.cc @@ -15,6 +15,7 @@ #include "test/util/rlimit_util.h" #include <sys/resource.h> + #include <cerrno> #include "test/util/cleanup.h" diff --git a/test/util/signal_util.cc b/test/util/signal_util.cc index 26738864f..5ee95ee80 100644 --- a/test/util/signal_util.cc +++ b/test/util/signal_util.cc @@ -15,6 +15,7 @@ #include "test/util/signal_util.h" #include <signal.h> + #include <ostream> #include "gtest/gtest.h" diff --git a/test/util/signal_util.h b/test/util/signal_util.h index 7fd2af015..bcf85c337 100644 --- a/test/util/signal_util.h +++ b/test/util/signal_util.h @@ -18,6 +18,7 @@ #include <signal.h> #include <sys/syscall.h> #include <unistd.h> + #include <ostream> #include "gmock/gmock.h" diff --git a/test/util/temp_path.h b/test/util/temp_path.h index 92d669503..9e5ac11f4 100644 --- a/test/util/temp_path.h +++ b/test/util/temp_path.h @@ -16,6 +16,7 @@ #define GVISOR_TEST_UTIL_TEMP_PATH_H_ #include <sys/stat.h> + #include <string> #include <utility> diff --git a/test/util/test_util.cc b/test/util/test_util.cc index ba0dcf7d0..848504c88 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -41,6 +41,7 @@ namespace gvisor { namespace testing { #define TEST_ON_GVISOR "TEST_ON_GVISOR" +#define GVISOR_NETWORK "GVISOR_NETWORK" bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; } @@ -60,6 +61,11 @@ Platform GvisorPlatform() { abort(); } +bool IsRunningWithHostinet() { + char* env = getenv(GVISOR_NETWORK); + return env && strcmp(env, "host") == 0; +} + // Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations // %ebx contains the address of the global offset table. %rbx is occasionally // used to address stack variables in presence of dynamic allocas. @@ -116,9 +122,6 @@ PosixErrorOr<KernelVersion> GetKernelVersion() { return ParseKernelVersion(buf.release); } -void SetupGvisorDeathTest() { -} - std::string CPUSetToString(const cpu_set_t& set, size_t cpus) { std::string str = "cpuset["; for (unsigned int n = 0; n < cpus; n++) { @@ -224,15 +227,5 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) { return abs_diff <= static_cast<uint64_t>(tolerance * target); } -void TestInit(int* argc, char*** argv) { - ::testing::InitGoogleTest(argc, *argv); - ::absl::ParseCommandLine(*argc, *argv); - - // Always mask SIGPIPE as it's common and tests aren't expected to handle it. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); -} - } // namespace testing } // namespace gvisor diff --git a/test/util/test_util.h b/test/util/test_util.h index b9d2dc2ba..b3235c7e3 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -220,8 +220,11 @@ enum class Platform { }; bool IsRunningOnGvisor(); Platform GvisorPlatform(); +bool IsRunningWithHostinet(); +#ifdef __linux__ void SetupGvisorDeathTest(); +#endif struct KernelVersion { int major; @@ -762,6 +765,12 @@ MATCHER_P2(EquivalentWithin, target, tolerance, return Equivalent(arg, target, tolerance); } +// Returns the absolute path to the a data dependency. 'path' is the runfile +// location relative to workspace root. +#ifdef __linux__ +std::string RunfilePath(std::string path); +#endif + void TestInit(int* argc, char*** argv); } // namespace testing diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc new file mode 100644 index 000000000..ba7c0a85b --- /dev/null +++ b/test/util/test_util_impl.cc @@ -0,0 +1,38 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> + +#include "gtest/gtest.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +void SetupGvisorDeathTest() {} + +void TestInit(int* argc, char*** argv) { + ::testing::InitGoogleTest(argc, *argv); + ::absl::ParseCommandLine(*argc, *argv); + + // Always mask SIGPIPE as it's common and tests aren't expected to handle it. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc new file mode 100644 index 000000000..7210094eb --- /dev/null +++ b/test/util/test_util_runfiles.cc @@ -0,0 +1,46 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <iostream> +#include <string> + +#include "test/util/fs_util.h" +#include "test/util/test_util.h" +#include "tools/cpp/runfiles/runfiles.h" + +namespace gvisor { +namespace testing { + +std::string RunfilePath(std::string path) { + static const bazel::tools::cpp::runfiles::Runfiles* const runfiles = [] { + std::string error; + auto* runfiles = + bazel::tools::cpp::runfiles::Runfiles::CreateForTest(&error); + if (runfiles == nullptr) { + std::cerr << "Unable to find runfiles: " << error << std::endl; + } + return runfiles; + }(); + + if (!runfiles) { + // Can't find runfiles? This probably won't work, but __main__/path is our + // best guess. + return JoinPath("__main__", path); + } + + return runfiles->Rlocation(JoinPath("__main__", path)); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/test_util_test.cc b/test/util/test_util_test.cc index b7300d9e5..f42100374 100644 --- a/test/util/test_util_test.cc +++ b/test/util/test_util_test.cc @@ -15,6 +15,7 @@ #include "test/util/test_util.h" #include <errno.h> + #include <vector> #include "gmock/gmock.h" diff --git a/third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go b/third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go deleted file mode 100644 index 855b2a2b1..000000000 --- a/third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.13 - -// TODO(b/133868570): Delete once Go 1.12 is no longer supported. - -package gvsync - -import _ "unsafe" - -//go:linkname runtimeSemrelease112 sync.runtime_Semrelease -func runtimeSemrelease112(s *uint32, handoff bool) - -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) { - // 'skipframes' is only available starting from 1.13. - runtimeSemrelease112(s, handoff) -} diff --git a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go b/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go deleted file mode 100644 index 8baec5458..000000000 --- a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.13 -// +build !go1.14 - -// Check go:linkname function signatures when updating Go version. - -package gvsync - -import _ "unsafe" - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) diff --git a/tools/go_branch.sh b/tools/go_branch.sh index ddb9b6e7b..f97a74aaf 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -17,9 +17,9 @@ set -eo pipefail # Discovery the package name from the go.mod file. -declare -r gomod="$(pwd)/go.mod" -declare -r module=$(cat "${gomod}" | grep -E "^module" | cut -d' ' -f2) -declare -r gosum="$(pwd)/go.sum" +declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2) +declare -r origpwd=$(pwd) +declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE") # Check that gopath has been built. declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}" @@ -65,10 +65,22 @@ git checkout -b go "${go_branch}" git merge --no-commit --strategy ours ${head} || \ git merge --allow-unrelated-histories --no-commit --strategy ours ${head} -# Sync the entire gopath_dir and go.mod. -rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" . -cp "${gomod}" . -cp "${gosum}" . +# Sync the entire gopath_dir. +rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" . + +# Add additional files. +for file in "${othersrc[@]}"; do + cp "${origpwd}"/"${file}" . +done + +# Construct a new README.md. +cat > README.md <<EOF +# gVisor + +This branch is a synthetic branch, containing only Go sources, that is +compatible with standard Go tools. See the master branch for authoritative +sources and tests. +EOF # There are a few solitary files that can get left behind due to the way bazel # constructs the gopath target. Note that we don't find all Go files here diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index fa82f8e9b..d412e1ccf 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_marshal:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_marshal:defs.bzl", "go_library") - package_group( name = "gomarshal_test", packages = [ diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD index 8fb43179b..9bb89e1da 100644 --- a/tools/go_marshal/test/external/BUILD +++ b/tools/go_marshal/test/external/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_marshal:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "external", testonly = 1, diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index a7961c2fb..33267c074 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -43,13 +43,13 @@ def _go_stateify_impl(ctx): # Run the stateify command. args = ["-output=%s" % output.path] - args += ["-pkg=%s" % ctx.attr.package] - args += ["-arch=%s" % ctx.attr.arch] + args.append("-pkg=%s" % ctx.attr.package) + args.append("-arch=%s" % ctx.attr.arch) if ctx.attr._statepkg: - args += ["-statepkg=%s" % ctx.attr._statepkg] + args.append("-statepkg=%s" % ctx.attr._statepkg) if ctx.attr.imports: - args += ["-imports=%s" % ",".join(ctx.attr.imports)] - args += ["--"] + args.append("-imports=%s" % ",".join(ctx.attr.imports)) + args.append("--") for src in ctx.attr.srcs: args += [f.path for f in src.files.to_list()] ctx.actions.run( @@ -124,8 +124,8 @@ def go_library(name, srcs, deps = [], imports = [], **kwargs): imports = imports, package = name, arch = select({ - "@bazel_tools//src/conditions:linux_aarch64": "arm64", - "//conditions:default": "amd64", + "@bazel_tools//src/conditions:linux_aarch64": "arm64", + "//conditions:default": "amd64", }), out = name + "_state_autogen.go", ) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 47c8ea1d7..7d5d291e6 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -90,8 +90,9 @@ func matchtag(tag string, goarch string) bool { return tag == goarch } -// canBuild reports whether we can build this file for target platform by checking file name and build tags. -// The code is derived from the Go source cmd.dist.build.shouldbuild. +// canBuild reports whether we can build this file for target platform by +// checking file name and build tags. The code is derived from the Go source +// cmd.dist.build.shouldbuild. func canBuild(file, goTargetArch string) bool { name := filepath.Base(file) excluded := func(list []string, ok string) bool { @@ -120,11 +121,6 @@ func canBuild(file, goTargetArch string) bool { if p == "" { continue } - code := p - i := strings.Index(code, "//") - if i > 0 { - code = strings.TrimSpace(code[:i]) - } if !strings.HasPrefix(p, "//") { break } diff --git a/tools/make_repository.sh b/tools/make_repository.sh index 071f72b74..27ffbc9f3 100755 --- a/tools/make_repository.sh +++ b/tools/make_repository.sh @@ -17,13 +17,13 @@ # Parse arguments. We require more than two arguments, which are the private # keyring, the e-mail associated with the signer, and the list of packages. if [ "$#" -le 3 ]; then - echo "usage: $0 <private-key> <signer-email> <component> <packages...>" + echo "usage: $0 <private-key> <signer-email> <component> <root> <packages...>" exit 1 fi -declare -r private_key=$(readlink -e "$1") -declare -r signer="$2" -declare -r component="$3" -shift; shift; shift +declare -r private_key=$(readlink -e "$1"); shift +declare -r signer="$1"; shift +declare -r component="$1"; shift +declare -r root="$1"; shift # Verbose from this point. set -xeo pipefail @@ -40,7 +40,7 @@ cleanup() { trap cleanup EXIT gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2 -# Copy the packages, and ensure permissions are correct. +# Copy the packages into the root. for pkg in "$@"; do name=$(basename "${pkg}" .deb) name=$(basename "${name}" .changes) @@ -48,32 +48,61 @@ for pkg in "$@"; do if [[ "${name}" == "${arch}" ]]; then continue # Not a regular package. fi - mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}" - cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}" + if [[ "${pkg}" =~ ^.*\.deb$ ]]; then + # Extract from the debian file. + version=$(dpkg --info "${pkg}" | grep -E 'Version:' | cut -d':' -f2) + elif [[ "${pkg}" =~ ^.*\.changes$ ]]; then + # Extract from the changes file. + version=$(grep -E 'Version:' "${pkg}" | cut -d':' -f2) + else + # Unsupported file type. + echo "Unknown file type: ${pkg}" + exit 1 + fi + version=${version// /} # Trim whitespace. + mkdir -p "${root}"/pool/"${version}"/binary-"${arch}" + cp -a "${pkg}" "${root}"/pool/"${version}"/binary-"${arch}" done -find "${tmpdir}" -type f -exec chmod 0644 {} \; -# Ensure there are no symlinks hanging around; these may be remnants of the -# build process. They may be useful for other things, but we are going to build -# an index of the actual packages here. -find "${tmpdir}" -type l -exec rm -f {} \; +# Ensure all permissions are correct. +find "${root}"/pool -type f -exec chmod 0644 {} \; # Sign all packages. -for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do +for file in "${root}"/pool/*/binary-*/*.deb; do dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2 done # Build the package list. -for dir in "${tmpdir}"/"${component}"/binary-*; do - (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz) +declare arches=() +for dir in "${root}"/pool/*/binary-*; do + name=$(basename "${dir}") + arch=${name##binary-} + arches+=("${arch}") + repo_packages="${tmpdir}"/"${component}"/"${name}" + mkdir -p "${repo_packages}" + (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages) + (cd "${repo_packages}" && cat Packages | gzip > Packages.gz) + (cd "${repo_packages}" && cat Packages | xz > Packages.xz) done # Build the release list. -(cd "${tmpdir}" && apt-ftparchive release . > Release) +cat > "${tmpdir}"/apt.conf <<EOF +APT { + FTPArchive { + Release { + Architectures "${arches[@]}"; + Components "${component}"; + }; + }; +}; +EOF +(cd "${tmpdir}" && apt-ftparchive -c=apt.conf release . > Release) +rm "${tmpdir}"/apt.conf # Sign the release. -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2) -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2) +declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512") +(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release >&2) +(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release >&2) # Show the results. echo "${tmpdir}" diff --git a/tools/tag_release.sh b/tools/tag_release.sh index 9d5a60583..f33b902d6 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -64,5 +64,6 @@ fi # Tag the given commit (annotated, to record the committer). declare -r tag="release-${release}" -(git tag -a "${tag}" "${commit}" && git push origin tag "${tag}") || \ +(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \ + git push origin tag "${tag}") || \ (git tag -d "${tag}" && false) diff --git a/vdso/BUILD b/vdso/BUILD index 7ceed349e..2b6744c26 100644 --- a/vdso/BUILD +++ b/vdso/BUILD @@ -68,14 +68,14 @@ genrule( "&& $(location :check_vdso) " + "--check-data " + "--vdso $(location vdso.so) ", + exec_tools = [ + ":check_vdso", + ], features = ["-pie"], toolchains = [ "@bazel_tools//tools/cpp:current_cc_toolchain", ":no_pie_cc_flags", ], - tools = [ - ":check_vdso", - ], visibility = ["//:sandbox"], ) @@ -87,6 +87,6 @@ cc_flags_supplier( py_binary( name = "check_vdso", srcs = ["check_vdso.py"], - python_version = "PY2", + python_version = "PY3", visibility = ["//:sandbox"], ) |