diff options
264 files changed, 12953 insertions, 1450 deletions
@@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Build with C++17. +build --cxxopt=-std=c++17 + # Display the current git revision in the info block. build --stamp --workspace_status_command tools/workspace_status.sh diff --git a/Dockerfile b/Dockerfile index 6e9d870db..5b95822f9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ FROM ubuntu:bionic -RUN apt-get update && apt-get install -y curl gnupg2 git python3 +RUN apt-get update && apt-get install -y curl gnupg2 git python3 python3-distutils python3-pip RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ - curl https://bazel.build/bazel-release.pub.gpg | apt-key add - + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - RUN apt-get update && apt-get install -y bazel && apt-get clean WORKDIR /gvisor @@ -48,9 +48,10 @@ Make sure the following dependencies are installed: * Linux 4.14.77+ ([older linux][old-linux]) * [git][git] -* [Bazel][bazel] 0.28.0+ +* [Bazel][bazel] 1.2+ * [Python][python] * [Docker version 17.09.0 or greater][docker] +* C++ toolchain supporting C++17 (GCC 7+, Clang 5+) * Gold linker (e.g. `binutils-gold` package on Ubuntu) ### Building @@ -1,6 +1,7 @@ -# Load go bazel rules and gazelle. load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") +# Load go bazel rules and gazelle. http_archive( name = "io_bazel_rules_go", sha256 = "b9aa86ec08a292b97ec4591cf578e020b35f98e12173bbd4a921f84f583aebd9", @@ -58,6 +59,26 @@ load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") protobuf_deps() +# Load python dependencies. +git_repository( + name = "rules_python", + commit = "94677401bc56ed5d756f50b441a6a5c7f735a6d4", + remote = "https://github.com/bazelbuild/rules_python.git", + shallow_since = "1573842889 -0500", +) + +load("@rules_python//python:pip.bzl", "pip_import") + +pip_import( + name = "pydeps", + python_interpreter = "python3", + requirements = "//benchmarks:requirements.txt", +) + +load("@pydeps//:requirements.bzl", "pip_install") + +pip_install() + # Load bazel_toolchain to support Remote Build Execution. # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( diff --git a/benchmarks/BUILD b/benchmarks/BUILD new file mode 100644 index 000000000..dbadeeaf2 --- /dev/null +++ b/benchmarks/BUILD @@ -0,0 +1,9 @@ +package(licenses = ["notice"]) + +py_binary( + name = "benchmarks", + srcs = ["run.py"], + main = "run.py", + python_version = "PY3", + deps = ["//benchmarks/runner"], +) diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..ad44cd6ac --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,172 @@ +# Benchmark tools + +These scripts are tools for collecting performance data for Docker-based tests. + +## Setup + +The scripts assume the following: + +* You have a local machine with bazel installed. +* You have some machine(s) with docker installed. These machines will be + refered to as the "Environment". +* Environment machines have the runtime(s) under test installed, such that you + can run docker with a command like: `docker run --runtime=$RUNTIME + your/image`. +* You are able to login to machines in the environment with the local machine + via ssh and the user for ssh can run docker commands without using `sudo`. +* The docker daemon on each of your environment machines is listening on + `unix:///var/run/docker.sock` (docker's default). + +For configuring the environment manually, consult the +[dockerd documentation][dockerd]. + +## Environment + +All benchmarks require a user defined yaml file describe the environment. These +files are of the form: + +```yaml +machine1: local +machine2: + hostname: 100.100.100.100 + username: username + key_path: ~/private_keyfile + key_password: passphrase +machine3: + hostname: 100.100.100.101 + username: username + key_path: ~/private_keyfile + key_password: passphrase +``` + +The yaml file defines an environment with three machines named `machine1`, +`machine2` and `machine3`. `machine1` is the local machine, `machine2` and +`machine3` are remote machines. Both `machine2` and `machine3` should be +reachable by `ssh`. For example, the command `ssh -i ~/private_keyfile +username@100.100.100.100` (using the passphrase `passphrase`) should connect to +`machine2`. + +The above is an example only. Machines should be uniform, since they are treated +as such by the tests. Machines must also be accessible to each other via their +default routes. Furthermore, some benchmarks will meaningless if running on the +local machine, such as density. + +For remote machines, `hostname`, `key_path`, and `username` are required and +others are optional. In addition key files must be generated +[using the instrcutions below](#generating-ssh-keys). + +The above yaml file can be checked for correctness with the `validate` command +in the top level perf.py script: + +`bazel run :benchmarks -- validate $PWD/examples/localhost.yaml` + +## Running benchmarks + +To list available benchmarks, use the `list` commmand: + +```bash +bazel run :benchmarks -- list + +... +Benchmark: sysbench.cpu +Metrics: events_per_second + Run sysbench CPU test. Additional arguments can be provided for sysbench. + + :param max_prime: The maximum prime number to search. +``` + +To run benchmarks, use the `run` command. For example, to run the sysbench +benchmark above: + +```bash +bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml sysbench.cpu +``` + +You can run parameterized benchmarks, for example to run with different +runtimes: + +```bash +bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --runtime=runc --runtime=runsc sysbench.cpu +``` + +Or with different parameters: + +```bash +bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --max_prime=10 --max_prime=100 sysbench.cpu +``` + +## Writing benchmarks + +To write new benchmarks, you should familiarize yourself with the structure of +the repository. There are three key components. + +## Harness + +The harness makes use of the [docker py SDK][docker-py]. It is advisable that +you familiarize yourself with that API when making changes, specifically: + +* clients +* containers +* images + +In general, benchmarks need only interact with the `Machine` objects provided to +the benchmark function, which are the machines defined in the environment. These +objects allow the benchmark to define the relationships between different +containers, and parse the output. + +## Workloads + +The harness requires workloads to run. These are all available in the +`workloads` directory. + +In general, a workload consists of a Dockerfile to build it (while these are not +hermetic, in general they should be as fixed and isolated as possible), some +parses for output if required, parser tests and sample data. Provided the test +is named after the workload package and contains a function named `sample`, this +variable will be used to automatically mock workload output when the `--mock` +flag is provided to the main tool. + +## Writing benchmarks + +Benchmarks define the tests themselves. All benchmarks have the following +function signature: + +```python +def my_func(output) -> float: + return float(output) + +@benchmark(metrics = my_func, machines = 1) +def my_benchmark(machine: machine.Machine, arg: str): + return "3.4432" +``` + +Each benchmark takes a variable amount of position arguments as +`harness.Machine` objects and some set of keyword arguments. It is recommended +that you accept arbitrary keyword arguments and pass them through when +constructing the container under test. + +To write a new benchmark, open a module in the `suites` directory and use the +above signature. You should add a descriptive doc string to describe what your +benchmark is and any test centric arguments. + +## Generating SSH Keys + +The scripts only support RSA Keys, and ssh library used in paramiko. Paramiko +only supports RSA keys that look like the following (PEM format): + +```bash +$ cat /path/to/ssh/key + +-----BEGIN RSA PRIVATE KEY----- +...private key text... +-----END RSA PRIVATE KEY----- + +``` + +To generate ssh keys in PEM format, use the [`-t rsa -m PEM -b 4096`][RSA-keys]. +option. + +[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/ +[docker-py]: https://docker-py.readthedocs.io/en/stable/ +[paramiko]: http://docs.paramiko.org/en/2.4/api/client.html +[RSA-keys]: https://serverfault.com/questions/939909/ssh-keygen-does-not-create-rsa-private-key diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl new file mode 100644 index 000000000..79e6cdbc8 --- /dev/null +++ b/benchmarks/defs.bzl @@ -0,0 +1,18 @@ +"""Provides python helper functions.""" + +load("@pydeps//:requirements.bzl", _requirement = "requirement") + +def filter_deps(deps = None): + if deps == None: + deps = [] + return [dep for dep in deps if dep] + +def py_library(deps = None, **kwargs): + return native.py_library(deps = filter_deps(deps), **kwargs) + +def py_test(deps = None, **kwargs): + return native.py_test(deps = filter_deps(deps), **kwargs) + +def requirement(name, direct = True): + """ requirement returns the required dependency. """ + return _requirement(name) diff --git a/benchmarks/examples/localhost.yaml b/benchmarks/examples/localhost.yaml new file mode 100644 index 000000000..f70fe0fb7 --- /dev/null +++ b/benchmarks/examples/localhost.yaml @@ -0,0 +1,2 @@ +client: localhost +server: localhost diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD new file mode 100644 index 000000000..9546220c4 --- /dev/null +++ b/benchmarks/harness/BUILD @@ -0,0 +1,89 @@ +load("//benchmarks:defs.bzl", "py_library", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "harness", + srcs = ["__init__.py"], +) + +py_library( + name = "benchmark_driver", + srcs = ["benchmark_driver.py"], + deps = [ + "//benchmarks/harness/machine_mocks", + "//benchmarks/harness/machine_producers:machine_producer", + "//benchmarks/suites", + ], +) + +py_library( + name = "container", + srcs = ["container.py"], + deps = [ + requirement("asn1crypto", False), + requirement("chardet", False), + requirement("certifi", False), + requirement("docker", True), + requirement("docker-pycreds", False), + requirement("idna", False), + requirement("ptyprocess", False), + requirement("requests", False), + requirement("urllib3", False), + requirement("websocket-client", False), + ], +) + +py_library( + name = "machine", + srcs = ["machine.py"], + deps = [ + "//benchmarks/harness", + "//benchmarks/harness:container", + "//benchmarks/harness:ssh_connection", + "//benchmarks/harness:tunnel_dispatcher", + requirement("asn1crypto", False), + requirement("chardet", False), + requirement("certifi", False), + requirement("docker", True), + requirement("docker-pycreds", False), + requirement("idna", False), + requirement("ptyprocess", False), + requirement("requests", False), + requirement("urllib3", False), + requirement("websocket-client", False), + ], +) + +py_library( + name = "ssh_connection", + srcs = ["ssh_connection.py"], + deps = [ + "//benchmarks/harness", + requirement("bcrypt", False), + requirement("cffi", False), + requirement("paramiko", True), + requirement("cryptography", False), + ], +) + +py_library( + name = "tunnel_dispatcher", + srcs = ["tunnel_dispatcher.py"], + deps = [ + requirement("asn1crypto", False), + requirement("chardet", False), + requirement("certifi", False), + requirement("docker", True), + requirement("docker-pycreds", False), + requirement("idna", False), + requirement("pexpect", True), + requirement("ptyprocess", False), + requirement("requests", False), + requirement("urllib3", False), + requirement("websocket-client", False), + ], +) diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py new file mode 100644 index 000000000..a7f34da9e --- /dev/null +++ b/benchmarks/harness/__init__.py @@ -0,0 +1,25 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core benchmark utilities.""" + +import os + +# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a +# format string that accepts a single string parameter. +LOCAL_WORKLOADS_PATH = os.path.join( + os.path.dirname(__file__), "../workloads/{}") + +# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the +# remote host. This is a format string that accepts a single string parameter. +REMOTE_WORKLOADS_PATH = "workloads/{}" diff --git a/benchmarks/harness/benchmark_driver.py b/benchmarks/harness/benchmark_driver.py new file mode 100644 index 000000000..9abc21b54 --- /dev/null +++ b/benchmarks/harness/benchmark_driver.py @@ -0,0 +1,85 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Main driver for benchmarks.""" + +import copy +import statistics +import threading +import types + +from benchmarks import suites +from benchmarks.harness.machine_producers import machine_producer + + +# pylint: disable=too-many-instance-attributes +class BenchmarkDriver: + """Allocates machines and invokes a benchmark method.""" + + def __init__(self, + producer: machine_producer.MachineProducer, + method: types.FunctionType, + runs: int = 1, + **kwargs): + + self._producer = producer + self._method = method + self._kwargs = copy.deepcopy(kwargs) + self._threads = [] + self.lock = threading.RLock() + self._runs = runs + self._metric_results = {} + + def start(self): + """Starts a benchmark thread.""" + for _ in range(self._runs): + thread = threading.Thread(target=self._run_method) + thread.start() + self._threads.append(thread) + + def join(self): + """Joins the thread.""" + # pylint: disable=expression-not-assigned + [t.join() for t in self._threads] + + def _run_method(self): + """Runs all benchmarks.""" + machines = self._producer.get_machines( + suites.benchmark_machines(self._method)) + try: + result = self._method(*machines, **self._kwargs) + for name, res in result: + with self.lock: + if name in self._metric_results: + self._metric_results[name].append(res) + else: + self._metric_results[name] = [res] + finally: + # Always release. + self._producer.release_machines(machines) + + def median(self): + """Returns the median result, after join is finished.""" + for key, value in self._metric_results.items(): + yield key, [statistics.median(value)] + + def all(self): + """Returns all results.""" + for key, value in self._metric_results.items(): + yield key, value + + def meanstd(self): + """Returns all results.""" + for key, value in self._metric_results.items(): + mean = statistics.mean(value) + yield key, [mean, statistics.stdev(value, xbar=mean)] diff --git a/benchmarks/harness/container.py b/benchmarks/harness/container.py new file mode 100644 index 000000000..585436e20 --- /dev/null +++ b/benchmarks/harness/container.py @@ -0,0 +1,181 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Container definitions.""" + +import contextlib +import logging +import pydoc +import types +from typing import Tuple + +import docker +import docker.errors + +from benchmarks import workloads + + +class Container: + """Abstract container. + + Must be a context manager. + + Usage: + + with Container(client, image, ...): + ... + """ + + def run(self, **env) -> str: + """Run the container synchronously.""" + raise NotImplementedError + + def detach(self, **env): + """Run the container asynchronously.""" + raise NotImplementedError + + def address(self) -> Tuple[str, int]: + """Return the bound address for the container.""" + raise NotImplementedError + + def get_names(self) -> types.GeneratorType: + """Return names of all containers.""" + raise NotImplementedError + + +# pylint: disable=too-many-instance-attributes +class DockerContainer(Container): + """Class that handles creating a docker container.""" + + # pylint: disable=too-many-arguments + def __init__(self, + client: docker.DockerClient, + host: str, + image: str, + count: int = 1, + runtime: str = "runc", + port: int = 0, + **kwargs): + """Trys to setup "count" containers. + + Args: + client: A docker client from dockerpy. + host: The host address the image is running on. + image: The name of the image to run. + count: The number of containers to setup. + runtime: The container runtime to use. + port: The port to reserve. + **kwargs: Additional container options. + """ + assert count >= 1 + assert port == 0 or count == 1 + self._client = client + self._host = host + self._containers = [] + self._count = count + self._image = image + self._runtime = runtime + self._port = port + self._kwargs = kwargs + if port != 0: + self._ports = {"%d/tcp" % port: None} + else: + self._ports = {} + + @contextlib.contextmanager + def detach(self, **env): + env = ["%s=%s" % (key, value) for (key, value) in env.items()] + # Start all containers. + for _ in range(self._count): + try: + # Start the container in a detached mode. + container = self._client.containers.run( + self._image, + detach=True, + remove=True, + runtime=self._runtime, + ports=self._ports, + environment=env, + **self._kwargs) + logging.info("Started detached container %s -> %s", self._image, + container.attrs["Id"]) + self._containers.append(container) + except Exception as exc: + self._clean_containers() + raise exc + try: + # Wait for all containers to be up. + for container in self._containers: + while not container.attrs["State"]["Running"]: + container = self._client.containers.get(container.attrs["Id"]) + yield self + finally: + self._clean_containers() + + def address(self) -> Tuple[str, int]: + assert self._count == 1 + assert self._port != 0 + container = self._client.containers.get(self._containers[0].attrs["Id"]) + port = container.attrs["NetworkSettings"]["Ports"][ + "%d/tcp" % self._port][0]["HostPort"] + return (self._host, port) + + def get_names(self) -> types.GeneratorType: + for container in self._containers: + yield container.name + + def run(self, **env) -> str: + env = ["%s=%s" % (key, value) for (key, value) in env.items()] + return self._client.containers.run( + self._image, + runtime=self._runtime, + ports=self._ports, + remove=True, + environment=env, + **self._kwargs).decode("utf-8") + + def _clean_containers(self): + """Kills all containers.""" + for container in self._containers: + try: + container.kill() + except docker.errors.NotFound: + pass + + +class MockContainer(Container): + """Mock of Container.""" + + def __init__(self, workload: str): + self._workload = workload + + def __enter__(self): + return self + + def run(self, **env): + # Lookup sample data if any exists for the workload module. We use a + # well-defined test locate and a well-defined sample function. + mod = pydoc.locate(workloads.__name__ + "." + self._workload) + if hasattr(mod, "sample"): + return mod.sample(**env) + return "" # No output. + + def address(self) -> Tuple[str, int]: + return ("example.com", 80) + + def get_names(self) -> types.GeneratorType: + yield "mock" + + @contextlib.contextmanager + def detach(self, **env): + yield self diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py new file mode 100644 index 000000000..2166d040a --- /dev/null +++ b/benchmarks/harness/machine.py @@ -0,0 +1,191 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Machine abstraction. This is the primary API for benchmarks.""" + +import logging +import re +import subprocess +import time +from typing import Tuple + +import docker + +from benchmarks import harness +from benchmarks.harness import container +from benchmarks.harness import machine_mocks +from benchmarks.harness import ssh_connection +from benchmarks.harness import tunnel_dispatcher + + +class Machine: + """The machine object is the primary object for benchmarks. + + Machine objects are passed to each metric function call and benchmarks use + machines to access real connections to those machines. + """ + + def run(self, cmd: str) -> Tuple[str, str]: + """Convenience method for running a bash command on a machine object. + + Some machines may point to the local machine, and thus, do not have ssh + connections. Run runs a command either local or over ssh and returns the + output stdout and stderr as strings. + + Args: + cmd: The command to run as a string. + + Returns: + The command output. + """ + raise NotImplementedError + + def read(self, path: str) -> str: + """Reads the contents of some file. + + This will be mocked. + + Args: + path: The path to the file to be read. + + Returns: + The file contents. + """ + raise NotImplementedError + + def pull(self, workload: str) -> str: + """Send the given workload to the machine, build and tag it. + + All images must be defined by the workloads directory. + + Args: + workload: The workload name. + + Returns: + The workload tag. + """ + raise NotImplementedError + + def container(self, image: str, **kwargs) -> container.Container: + """Returns a container object. + + Args: + image: The pulled image tag. + **kwargs: Additional container options. + + Returns: + :return: a container.Container object. + """ + raise NotImplementedError + + def sleep(self, amount: float): + """Sleeps the given amount of time.""" + raise NotImplementedError + + +class MockMachine(Machine): + """A mocked machine.""" + + def run(self, cmd: str) -> Tuple[str, str]: + return "", "" + + def read(self, path: str) -> str: + return machine_mocks.Readfile(path) + + def pull(self, workload: str) -> str: + return workload # Workload is the tag. + + def container(self, image: str, **kwargs) -> container.Container: + return container.MockContainer(image) + + def sleep(self, amount: float): + pass + + +def get_address(machine: Machine) -> str: + """Return a machine's default address.""" + default_route, _ = machine.run("ip route get 8.8.8.8") + return re.search(" src ([0-9.]+) ", default_route).group(1) + + +class LocalMachine(Machine): + """The local machine.""" + + def __init__(self, name): + self._name = name + self._docker_client = docker.from_env() + + def __str__(self): + return self._name + + def run(self, cmd: str) -> Tuple[str, str]: + process = subprocess.Popen( + cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + return stdout.decode("utf-8"), stderr.decode("utf-8") + + def read(self, path: str) -> str: + # Read the exact path locally. + return open(path, "r").read() + + def pull(self, workload: str) -> str: + # Run the docker build command locally. + logging.info("Building %s@%s locally...", workload, self._name) + self.run("docker build --tag={} {}".format( + workload, harness.LOCAL_WORKLOADS_PATH.format(workload))) + return workload # Workload is the tag. + + def container(self, image: str, **kwargs) -> container.Container: + # Return a local docker container directly. + return container.DockerContainer(self._docker_client, get_address(self), + image, **kwargs) + + def sleep(self, amount: float): + time.sleep(amount) + + +class RemoteMachine(Machine): + """Remote machine accessible via an SSH connection.""" + + def __init__(self, name, **kwargs): + self._name = name + self._ssh_connection = ssh_connection.SSHConnection(name, **kwargs) + self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs) + self._tunnel.connect() + self._docker_client = self._tunnel.get_docker_client() + + def __str__(self): + return self._name + + def run(self, cmd: str) -> Tuple[str, str]: + return self._ssh_connection.run(cmd) + + def read(self, path: str) -> str: + # Just cat remotely. + stdout, stderr = self._ssh_connection.run("cat '{}'".format(path)) + return stdout + stderr + + def pull(self, workload: str) -> str: + # Push to the remote machine and build. + logging.info("Building %s@%s remotely...", workload, self._name) + remote_path = self._ssh_connection.send_workload(workload) + self.run("docker build --tag={} {}".format(workload, remote_path)) + return workload # Workload is the tag. + + def container(self, image: str, **kwargs) -> container.Container: + # Return a remote docker container. + return container.DockerContainer(self._docker_client, get_address(self), + image, **kwargs) + + def sleep(self, amount: float): + time.sleep(amount) diff --git a/benchmarks/harness/machine_mocks/BUILD b/benchmarks/harness/machine_mocks/BUILD new file mode 100644 index 000000000..c8ec4bc79 --- /dev/null +++ b/benchmarks/harness/machine_mocks/BUILD @@ -0,0 +1,9 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "machine_mocks", + srcs = ["__init__.py"], +) diff --git a/benchmarks/harness/machine_mocks/__init__.py b/benchmarks/harness/machine_mocks/__init__.py new file mode 100644 index 000000000..00f0085d7 --- /dev/null +++ b/benchmarks/harness/machine_mocks/__init__.py @@ -0,0 +1,81 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Machine mock files.""" + +MEMINFO = """\ +MemTotal: 7652344 kB +MemFree: 7174724 kB +MemAvailable: 7152008 kB +Buffers: 7544 kB +Cached: 178856 kB +SwapCached: 0 kB +Active: 270928 kB +Inactive: 68436 kB +Active(anon): 153124 kB +Inactive(anon): 880 kB +Active(file): 117804 kB +Inactive(file): 67556 kB +Unevictable: 0 kB +Mlocked: 0 kB +SwapTotal: 0 kB +SwapFree: 0 kB +Dirty: 900 kB +Writeback: 0 kB +AnonPages: 153000 kB +Mapped: 129120 kB +Shmem: 1044 kB +Slab: 60864 kB +SReclaimable: 22792 kB +SUnreclaim: 38072 kB +KernelStack: 2672 kB +PageTables: 5756 kB +NFS_Unstable: 0 kB +Bounce: 0 kB +WritebackTmp: 0 kB +CommitLimit: 3826172 kB +Committed_AS: 663836 kB +VmallocTotal: 34359738367 kB +VmallocUsed: 0 kB +VmallocChunk: 0 kB +HardwareCorrupted: 0 kB +AnonHugePages: 0 kB +ShmemHugePages: 0 kB +ShmemPmdMapped: 0 kB +CmaTotal: 0 kB +CmaFree: 0 kB +HugePages_Total: 0 +HugePages_Free: 0 +HugePages_Rsvd: 0 +HugePages_Surp: 0 +Hugepagesize: 2048 kB +DirectMap4k: 94196 kB +DirectMap2M: 4624384 kB +DirectMap1G: 3145728 kB +""" + +CONTENTS = { + "/proc/meminfo": MEMINFO, +} + + +def Readfile(path: str) -> str: + """Reads a mock file. + + Args: + path: The target path. + + Returns: + Mocked file contents or None. + """ + return CONTENTS.get(path, None) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD new file mode 100644 index 000000000..5b2228e01 --- /dev/null +++ b/benchmarks/harness/machine_producers/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "harness", + srcs = ["__init__.py"], +) + +py_library( + name = "machine_producer", + srcs = ["machine_producer.py"], +) + +py_library( + name = "mock_producer", + srcs = ["mock_producer.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/harness/machine_producers:machine_producer", + ], +) + +py_library( + name = "yaml_producer", + srcs = ["yaml_producer.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/harness/machine_producers:machine_producer", + requirement("PyYAML", False), + ], +) diff --git a/benchmarks/harness/machine_producers/__init__.py b/benchmarks/harness/machine_producers/__init__.py new file mode 100644 index 000000000..634ef4843 --- /dev/null +++ b/benchmarks/harness/machine_producers/__init__.py @@ -0,0 +1,13 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py new file mode 100644 index 000000000..124ee14cc --- /dev/null +++ b/benchmarks/harness/machine_producers/machine_producer.py @@ -0,0 +1,30 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract types.""" + +from typing import List + +from benchmarks.harness import machine + + +class MachineProducer: + """Abstract Machine producer.""" + + def get_machines(self, num_machines: int) -> List[machine.Machine]: + """Returns the requested number of machines.""" + raise NotImplementedError + + def release_machines(self, machine_list: List[machine.Machine]): + """Releases the given set of machines.""" + raise NotImplementedError diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py new file mode 100644 index 000000000..4f29ad53f --- /dev/null +++ b/benchmarks/harness/machine_producers/mock_producer.py @@ -0,0 +1,31 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Producers of mocks.""" + +from typing import List + +from benchmarks.harness import machine +from benchmarks.harness.machine_producers import machine_producer + + +class MockMachineProducer(machine_producer.MachineProducer): + """Produces MockMachine objects.""" + + def get_machines(self, num_machines: int) -> List[machine.MockMachine]: + """Returns the request number of MockMachines.""" + return [machine.MockMachine() for i in range(num_machines)] + + def release_machines(self, machine_list: List[machine.MockMachine]): + """No-op.""" + return diff --git a/benchmarks/harness/machine_producers/yaml_producer.py b/benchmarks/harness/machine_producers/yaml_producer.py new file mode 100644 index 000000000..5d334e480 --- /dev/null +++ b/benchmarks/harness/machine_producers/yaml_producer.py @@ -0,0 +1,106 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Producers based on yaml files.""" + +import os +import threading +from typing import Dict +from typing import List + +import yaml + +from benchmarks.harness import machine +from benchmarks.harness.machine_producers import machine_producer + + +class YamlMachineProducer(machine_producer.MachineProducer): + """Loads machines from a yaml file.""" + + def __init__(self, path: str): + self.machines = build_machines(path) + self.max_machines = len(self.machines) + self.machine_condition = threading.Condition() + + def get_machines(self, num_machines: int) -> List[machine.Machine]: + if num_machines > self.max_machines: + raise ValueError( + "Insufficient Ammount of Machines. {ask} asked for and have {max_num} max." + .format(ask=num_machines, max_num=self.max_machines)) + + with self.machine_condition: + while not self._enough_machines(num_machines): + self.machine_condition.wait(timeout=1) + return [self.machines.pop(0) for _ in range(num_machines)] + + def release_machines(self, machine_list: List[machine.Machine]): + with self.machine_condition: + while machine_list: + next_machine = machine_list.pop() + self.machines.append(next_machine) + self.machine_condition.notify() + + def _enough_machines(self, ask: int): + return ask <= len(self.machines) + + +def build_machines(path: str, num_machines: str = -1) -> List[machine.Machine]: + """Builds machine objects defined by the yaml file "path". + + Args: + path: The path to a yaml file which defines machines. + num_machines: Optional limit on how many machine objects to build. + + Returns: + Machine objects in a list. + + If num_machines is set, len(machines) <= num_machines. + """ + data = parse_yaml(path) + machines = [] + for key, value in data.items(): + if len(machines) == num_machines: + return machines + if isinstance(value, dict): + machines.append(machine.RemoteMachine(key, **value)) + else: + machines.append(machine.LocalMachine(key)) + return machines + + +def parse_yaml(path: str) -> Dict[str, Dict[str, str]]: + """Parse the yaml file pointed by path. + + Args: + path: The path to yaml file. + + Returns: + The contents of the yaml file as a dictionary. + """ + data = get_file_contents(path) + return yaml.load(data, Loader=yaml.Loader) + + +def get_file_contents(path: str) -> str: + """Dumps the file contents to a string and returns them. + + Args: + path: The path to dump. + + Returns: + The file contents as a string. + """ + if not os.path.isabs(path): + path = os.path.abspath(path) + with open(path) as input_file: + return input_file.read() diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py new file mode 100644 index 000000000..fcbfbcdb2 --- /dev/null +++ b/benchmarks/harness/ssh_connection.py @@ -0,0 +1,111 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SSHConnection handles the details of SSH connections.""" + +import os +import warnings + +import paramiko + +from benchmarks import harness + +# Get rid of paramiko Cryptography Warnings. +warnings.filterwarnings(action="ignore", module=".*paramiko.*") + + +def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str): + """Sends a single file via an SSH client. + + Args: + client: The existing SSH client. + path: The local path. + remote_dir: The remote directory. + """ + filename = path.split("/").pop() + client.exec_command("mkdir -p " + remote_dir) + with client.open_sftp() as ftp_client: + ftp_client.put(path, os.path.join(remote_dir, filename)) + + +class SSHConnection: + """SSH connection to a remote machine.""" + + def __init__(self, name: str, hostname: str, key_path: str, username: str, + **kwargs): + """Sets up a paramiko ssh connection to the given hostname.""" + self._name = name # Unused. + self._hostname = hostname + self._username = username + self._key_path = key_path # RSA Key path + self._kwargs = kwargs + # SSHConnection wraps paramiko. paramiko supports RSA, ECDSA, and Ed25519 + # keys, and we've chosen to only suport and require RSA keys. paramiko + # supports RSA keys that begin with '----BEGIN RSAKEY----'. + # https://stackoverflow.com/questions/53600581/ssh-key-generated-by-ssh-keygen-is-not-recognized-by-paramiko + self.rsa_key = self._rsa() + self.run("true") # Validate. + + def _client(self) -> paramiko.SSHClient: + """Returns a connected SSH client.""" + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + hostname=self._hostname, + port=22, + username=self._username, + pkey=self.rsa_key, + allow_agent=False, + look_for_keys=False) + return client + + def _rsa(self): + if "key_password" in self._kwargs: + password = self._kwargs["key_password"] + else: + password = None + rsa = paramiko.RSAKey.from_private_key_file(self._key_path, password) + return rsa + + def run(self, cmd: str) -> (str, str): + """Runs a command via ssh. + + Args: + cmd: The shell command to run. + + Returns: + The contents of stdout and stderr. + """ + with self._client() as client: + _, stdout, stderr = client.exec_command(command=cmd) + stdout.channel.recv_exit_status() + stdout = stdout.read().decode("utf-8") + stderr = stderr.read().decode("utf-8") + return stdout, stderr + + def send_workload(self, name: str) -> str: + """Sends a workload to the remote machine. + + Args: + name: The workload name. + + Returns: + The remote path. + """ + with self._client() as client: + for dirpath, _, filenames in os.walk( + harness.LOCAL_WORKLOADS_PATH.format(name)): + for filename in filenames: + send_one_file(client, os.path.join(dirpath, filename), + harness.REMOTE_WORKLOADS_PATH.format(name)) + return harness.REMOTE_WORKLOADS_PATH.format(name) diff --git a/benchmarks/harness/tunnel_dispatcher.py b/benchmarks/harness/tunnel_dispatcher.py new file mode 100644 index 000000000..8dfe2862a --- /dev/null +++ b/benchmarks/harness/tunnel_dispatcher.py @@ -0,0 +1,82 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tunnel handles setting up connections to remote machines.""" + +import os +import tempfile +import time + +import docker +import pexpect + +SSH_TUNNEL_COMMAND = """ssh + -o GlobalKnownHostsFile=/dev/null + -o UserKnownHostsFile=/dev/null + -o StrictHostKeyChecking=no + -nNT -L {filename}:/var/run/docker.sock + -i {key_path} + {username}@{hostname}""" + + +class Tunnel: + """The tunnel object represents the tunnel via ssh. + + This connects a local unix domain socket with a remote socket. + """ + + def __init__(self, name, hostname: str, username: str, key_path: str, + **kwargs): + self._filename = tempfile.NamedTemporaryFile(prefix=name).name + self._hostname = hostname + self._username = username + self._key_path = key_path + self._kwargs = kwargs + self._process = None + + def connect(self): + """Connects the SSH tunnel.""" + cmd = SSH_TUNNEL_COMMAND.format( + filename=self._filename, + key_path=self._key_path, + username=self._username, + hostname=self._hostname) + self._process = pexpect.spawn(cmd, timeout=10) + + # If given a password, assume we'll be asked for it. + if "key_password" in self._kwargs: + self._process.expect(["Enter passphrase for key .*: "]) + self._process.sendline(self._kwargs["key_password"]) + + while True: + # Wait for the tunnel to appear. + if self._process.exitstatus is not None: + raise ConnectionError("Error in setting up ssh tunnel") + if os.path.exists(self._filename): + return + time.sleep(0.1) + + def path(self): + """Return the socket file.""" + return self._filename + + def get_docker_client(self): + """Returns a docker client for this Tunne0l.""" + return docker.DockerClient(base_url="unix:/" + self._filename) + + def __del__(self): + """Closes the ssh connection process and deletes the socket file.""" + if self._process: + self._process.close() + if os.path.exists(self._filename): + os.remove(self._filename) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000..577eb1a2e --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,32 @@ +asn1crypto==1.2.0 +atomicwrites==1.3.0 +attrs==19.3.0 +bcrypt==3.1.7 +certifi==2019.9.11 +cffi==1.13.2 +chardet==3.0.4 +Click==7.0 +cryptography==2.8 +docker==3.7.0 +docker-pycreds==0.4.0 +idna==2.8 +importlib-metadata==0.23 +more-itertools==7.2.0 +packaging==19.2 +paramiko==2.6.0 +pathlib2==2.3.5 +pexpect==4.7.0 +pluggy==0.9.0 +ptyprocess==0.6.0 +py==1.8.0 +pycparser==2.19 +PyNaCl==1.3.0 +pyparsing==2.4.5 +pytest==4.3.0 +PyYAML==5.1.2 +requests==2.22.0 +six==1.13.0 +urllib3==1.25.7 +wcwidth==0.1.7 +websocket-client==0.56.0 +zipp==0.6.0 diff --git a/benchmarks/run.py b/benchmarks/run.py new file mode 100644 index 000000000..a22eb8641 --- /dev/null +++ b/benchmarks/run.py @@ -0,0 +1,19 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Benchmark runner.""" + +from benchmarks import runner + +if __name__ == "__main__": + runner.runner() diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD new file mode 100644 index 000000000..de24824cc --- /dev/null +++ b/benchmarks/runner/BUILD @@ -0,0 +1,53 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package(licenses = ["notice"]) + +py_library( + name = "runner", + srcs = ["__init__.py"], + data = [ + "//benchmarks/workloads:files", + ], + visibility = ["//benchmarks:__pkg__"], + deps = [ + "//benchmarks/harness:benchmark_driver", + "//benchmarks/harness/machine_producers:mock_producer", + "//benchmarks/harness/machine_producers:yaml_producer", + "//benchmarks/suites", + "//benchmarks/suites:absl", + "//benchmarks/suites:density", + "//benchmarks/suites:fio", + "//benchmarks/suites:helpers", + "//benchmarks/suites:http", + "//benchmarks/suites:media", + "//benchmarks/suites:ml", + "//benchmarks/suites:network", + "//benchmarks/suites:redis", + "//benchmarks/suites:startup", + "//benchmarks/suites:sysbench", + "//benchmarks/suites:syscall", + requirement("click", True), + ], +) + +py_test( + name = "runner_test", + srcs = ["runner_test.py"], + python_version = "PY3", + tags = [ + "local", + "manual", + ], + deps = [ + ":runner", + requirement("click", True), + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py new file mode 100644 index 000000000..9bf9cfd65 --- /dev/null +++ b/benchmarks/runner/__init__.py @@ -0,0 +1,301 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""High-level benchmark utility.""" + +import copy +import csv +import logging +import pkgutil +import pydoc +import re +import sys +import types +from typing import List +from typing import Tuple + +import click + +from benchmarks import suites +from benchmarks.harness import benchmark_driver +from benchmarks.harness.machine_producers import mock_producer +from benchmarks.harness.machine_producers import yaml_producer + + +@click.group() +@click.option( + "--verbose/--no-verbose", default=False, help="Enable verbose logging.") +@click.option("--debug/--no-debug", default=False, help="Enable debug logging.") +def runner(verbose: bool = False, debug: bool = False): + """Run distributed benchmarks. + + See the run and list commands for details. + + Args: + verbose: Enable verbose logging. + debug: Enable debug logging (supercedes verbose). + """ + if debug: + logging.basicConfig(level=logging.DEBUG) + elif verbose: + logging.basicConfig(level=logging.INFO) + + +def find_benchmarks( + regex: str) -> List[Tuple[str, types.ModuleType, types.FunctionType]]: + """Finds all available benchmarks. + + Args: + regex: A regular expression to match. + + Returns: + A (short_name, module, function) tuple for each match. + """ + pkgs = pkgutil.walk_packages(suites.__path__, suites.__name__ + ".") + found = [] + for _, name, _ in pkgs: + mod = pydoc.locate(name) + funcs = [ + getattr(mod, x) + for x in dir(mod) + if suites.is_benchmark(getattr(mod, x)) + ] + for func in funcs: + # Use the short_name with the benchmarks. prefix stripped. + prefix_len = len(suites.__name__ + ".") + short_name = mod.__name__[prefix_len:] + "." + func.__name__ + # Add to the list if a pattern is provided. + if re.compile(regex).match(short_name): + found.append((short_name, mod, func)) + return found + + +@runner.command("list") +@click.argument("method", nargs=-1) +def list_all(method): + """Lists available benchmarks.""" + if not method: + method = ".*" + else: + method = "(" + ",".join(method) + ")" + for (short_name, _, func) in find_benchmarks(method): + print("Benchmark %s:" % short_name) + metrics = suites.benchmark_metrics(func) + if func.__doc__: + print(" " + func.__doc__.lstrip().rstrip()) + if metrics: + print("\n Metrics:") + for metric in metrics: + print("\t{name}: {doc}".format(name=metric[0], doc=metric[1])) + print("\n") + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-branches +# pylint: disable=too-many-locals +@runner.command( + context_settings=dict(ignore_unknown_options=True, allow_extra_args=True)) +@click.pass_context +@click.argument("method") +@click.option("--mock/--no-mock", default=False, help="Mock the machines.") +@click.option("--env", default=None, help="Specify a yaml file with machines.") +@click.option( + "--runtime", default=["runc"], help="The runtime to use.", multiple=True) +@click.option("--metric", help="The metric to extract.", multiple=True) +@click.option( + "--runs", default=1, help="The number of times to run each benchmark.") +@click.option( + "--stat", + default="median", + help="How to aggregate the data from all runs." + "\nmedian - returns the median of all runs (default)" + "\nall - returns all results comma separated" + "\nmeanstd - returns result as mean,std") +# pylint: disable=too-many-statements +def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str], + metric: List[str], stat: str, **kwargs): + """Runs arbitrary benchmarks. + + All unknown command line flags are passed through to the underlying benchmark + method. Flags may be specified multiple times, in which case it is considered + a "dimension" for the test, and a comma-separated table will be emitted + instead of a single result. + + See the output of list to see available metrics for any given benchmark + method. The method parameter is a regular expression that will match against + available benchmarks. If multiple benchmarks match, then that is considered a + distinct "dimension" for the test. + + All benchmarks are run in parallel where possible, but have exclusive + ownership over the individual machines. + + Exactly one of the --mock and --env flag must be specified. + + Every benchmark method will be run the times indicated by --runs. + + Args: + ctx: Click context. + method: A regular expression for methods to be run. + runs: Number of runs. + env: Environment to use. + mock: If true, use mocked environment (supercedes env). + runtime: A list of runtimes to test. + metric: A list of metrics to extract. + stat: The class of statistics to extract. + **kwargs: Dimensions to test. + """ + # First, calculate additional arguments. + # + # This essentially calculates any arguments that appear multiple times, and + # moves those to the "dimensions" dictionary, which maps to lists. These + # dimensions are then iterated over to generate the relevant csv output. + dimensions = {} + + if stat not in ["median", "all", "meanstd"]: + raise ValueError("Illegal value for --result, see help.") + + def squish(key: str, value: str): + """Collapse an argument into kwargs or dimensions.""" + if key in dimensions: + # Extend an existing dimension. + dimensions[key].append(value) + elif key in kwargs: + # Create a new dimension. + dimensions[key] = [kwargs[key], value] + del kwargs[key] + else: + # A single value. + kwargs[key] = value + + for item in ctx.args: + if "=" in method: + # This must be the method. The method is simply set to the first + # non-matching argument, which we're also parsing here. + item, method = method, item + if "=" not in item: + logging.error("illegal argument: %s", item) + sys.exit(1) + (key, value) = item.lstrip("-").split("=", 1) + squish(key, value) + + # Convert runtime and metric to dimensions. + # + # They exist only in the arguments above for documentation purposes. + # Essentially here we are treating them like anything else. Note however, + # that an empty set here will result in a dimension. This is important for + # metrics, where an empty set actually means all metrics. + def fold(key: str, value, allow_flatten=False): + """Collapse a list value into kwargs or dimensions.""" + if len(value) == 1 and allow_flatten: + kwargs[key] = value[0] + else: + dimensions[key] = value + + fold("runtime", runtime, allow_flatten=True) + fold("metric", metric) + + # Lookup the methods. + # + # We match the method parameter to a regular expression. This allows you to + # do things like `run --mock .*` for a broad test. Note that we track the + # short_names in the dimensions here, and look up again in the recursion. + methods = { + short_name: func for (short_name, _, func) in find_benchmarks(method) + } + if not methods: + # Must match at least one method. + logging.error("no matching benchmarks for %s: try list.", method) + sys.exit(1) + fold("method", list(methods.keys()), allow_flatten=True) + + # Construct the environment. + if mock and env: + # You can't provide both. + logging.error("both --mock and --env are set: which one is it?") + sys.exit(1) + elif mock: + producer = mock_producer.MockMachineProducer() + elif env: + producer = yaml_producer.YamlMachineProducer(env) + else: + # You must provide one of mock or env. + logging.error("no enviroment provided: use --mock or --env.") + sys.exit(1) + + # Spin up the drivers. + # + # We ensure that metric is the last entry, because we have special behavior. + # They actually run the test once and the benchmark is a generator that + # produces all viable metrics. + dimension_keys = list(dimensions.keys()) + if "metric" in dimension_keys: + dimension_keys.remove("metric") + dimension_keys.append("metric") + drivers = [] + + def _start(keywords, finished, left): + """Runs a test across dimensions recursively.""" + # Resolve the method fully, it starts as a string. + if "method" in keywords and isinstance(keywords["method"], str): + keywords["method"] = methods[keywords["method"]] + # Is this a non-recursive case? + if not left: + driver = benchmark_driver.BenchmarkDriver(producer, runs=runs, **keywords) + driver.start() + drivers.append((finished, driver)) + else: + # Recurse on the next dimension. + current, left = left[0], left[1:] + keywords = copy.deepcopy(keywords) + if current == "metric": + # We use a generator, popped below. Note that metric is + # guaranteed to be the last element here, and we will provide + # the value for 'done' below when generating the csv. + keywords[current] = dimensions[current] + _start(keywords, finished, left) + else: + # Generate manually. + for value in dimensions[current]: + keywords[current] = value + _start(keywords, finished + [value], left) + + # Start all the drivers, recursively. + _start(kwargs, [], dimension_keys) + + # Finish all tests, write results. + output = csv.writer(sys.stdout) + output.writerow(dimension_keys + ["result"]) + for (done, driver) in drivers: + driver.join() + for (metric_name, result) in getattr(driver, stat)(): + output.writerow([ # Collapse the method name. + hasattr(x, "__name__") and x.__name__ or x for x in done + ] + [metric_name] + result) + + +@runner.command() +@click.argument("env") +@click.option( + "--cmd", default="uname -a", help="command to run on all found machines") +@click.option( + "--workload", default="true", help="workload to run all found machines") +def validate(env, cmd, workload): + """Validates an environment described by yaml file.""" + producer = yaml_producer.YamlMachineProducer(env) + for machine in producer.machines: + print("Machine %s:" % machine) + stdout, _ = machine.run(cmd) + print(" Output of '%s': %s" % (cmd, stdout.lstrip().rstrip())) + image = machine.pull(workload) + stdout = machine.container(image).run() + print(" Container %s: %s" % (workload, stdout.lstrip().rstrip())) diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py new file mode 100644 index 000000000..5719c2838 --- /dev/null +++ b/benchmarks/runner/runner_test.py @@ -0,0 +1,59 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Top-level tests.""" + +import os +import subprocess +import sys + +from click import testing +import pytest + +from benchmarks import runner + + +def _get_locale(): + output = subprocess.check_output(["locale", "-a"]) + locales = output.split() + if b"en_US.utf8" in locales: + return "en_US.UTF-8" + else: + return "C.UTF-8" + + +def _set_locale(): + locale = _get_locale() + if os.getenv("LANG") != locale: + os.environ["LANG"] = locale + os.environ["LC_ALL"] = locale + os.execv("/proc/self/exe", ["python"] + sys.argv) + + +def test_list(): + cli_runner = testing.CliRunner() + result = cli_runner.invoke(runner.runner, ["list"]) + print(result.output) + assert result.exit_code == 0 + + +def test_run(): + cli_runner = testing.CliRunner() + result = cli_runner.invoke(runner.runner, ["run", "--mock", "."]) + print(result.output) + assert result.exit_code == 0 + + +if __name__ == "__main__": + _set_locale() + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/suites/BUILD b/benchmarks/suites/BUILD new file mode 100644 index 000000000..04fc23261 --- /dev/null +++ b/benchmarks/suites/BUILD @@ -0,0 +1,130 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "suites", + srcs = ["__init__.py"], +) + +py_library( + name = "absl", + srcs = ["absl.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/absl", + ], +) + +py_library( + name = "density", + srcs = ["density.py"], + deps = [ + "//benchmarks/harness:container", + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + ], +) + +py_library( + name = "fio", + srcs = ["fio.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + "//benchmarks/workloads/fio", + ], +) + +py_library( + name = "helpers", + srcs = ["helpers.py"], + deps = ["//benchmarks/harness:machine"], +) + +py_library( + name = "http", + srcs = ["http.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/ab", + ], +) + +py_library( + name = "media", + srcs = ["media.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + "//benchmarks/workloads/ffmpeg", + ], +) + +py_library( + name = "ml", + srcs = ["ml.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:startup", + "//benchmarks/workloads/tensorflow", + ], +) + +py_library( + name = "network", + srcs = ["network.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + "//benchmarks/workloads/iperf", + ], +) + +py_library( + name = "redis", + srcs = ["redis.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/redisbenchmark", + ], +) + +py_library( + name = "startup", + srcs = ["startup.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/suites:helpers", + ], +) + +py_library( + name = "sysbench", + srcs = ["sysbench.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/sysbench", + ], +) + +py_library( + name = "syscall", + srcs = ["syscall.py"], + deps = [ + "//benchmarks/harness:machine", + "//benchmarks/suites", + "//benchmarks/workloads/syscall", + ], +) diff --git a/benchmarks/suites/__init__.py b/benchmarks/suites/__init__.py new file mode 100644 index 000000000..360736cc3 --- /dev/null +++ b/benchmarks/suites/__init__.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core benchmark annotations.""" + +import functools +import inspect +import types +from typing import List +from typing import Tuple + +BENCHMARK_METRICS = '__benchmark_metrics__' +BENCHMARK_MACHINES = '__benchmark_machines__' + + +def is_benchmark(func: types.FunctionType) -> bool: + """Returns true if the given function is a benchmark.""" + return isinstance(func, types.FunctionType) and \ + hasattr(func, BENCHMARK_METRICS) and \ + hasattr(func, BENCHMARK_MACHINES) + + +def benchmark_metrics(func: types.FunctionType) -> List[Tuple[str, str]]: + """Returns the list of available metrics.""" + return [(metric.__name__, metric.__doc__) + for metric in getattr(func, BENCHMARK_METRICS)] + + +def benchmark_machines(func: types.FunctionType) -> int: + """Returns the number of machines required.""" + return getattr(func, BENCHMARK_MACHINES) + + +# pylint: disable=unused-argument +def default(value, **kwargs): + """Returns the passed value.""" + return value + + +def benchmark(metrics: List[types.FunctionType] = None, + machines: int = 1) -> types.FunctionType: + """Define a benchmark function with metrics. + + Args: + metrics: A list of metric functions. + machines: The number of machines required. + + Returns: + A function that accepts the given number of machines, and iteratively + returns a set of (metric_name, metric_value) pairs when called repeatedly. + """ + if not metrics: + # The default passes through. + metrics = [default] + + def decorator(func: types.FunctionType) -> types.FunctionType: + """Decorator function.""" + # Every benchmark should accept at least two parameters: + # runtime: The runtime to use for the benchmark (str, required). + # metrics: The metrics to use, if not the default (str, optional). + @functools.wraps(func) + def wrapper(*args, runtime: str, metric: list = None, **kwargs): + """Wrapper function.""" + # First -- ensure that we marshall all types appropriately. In + # general, we will call this with only strings. These strings will + # need to be converted to their underlying types/classes. + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.annotation != inspect.Parameter.empty and \ + param.name in kwargs and not isinstance(kwargs[param.name], param.annotation): + try: + # Marshall to the appropriate type. + kwargs[param.name] = param.annotation(kwargs[param.name]) + except Exception as exc: + raise ValueError( + 'illegal type for %s(%s=%s): %s' % + (func.__name__, param.name, kwargs[param.name], exc)) + elif param.default != inspect.Parameter.empty and \ + param.name not in kwargs: + # Ensure that we have the value set, because it will + # be passed to the metric function for evaluation. + kwargs[param.name] = param.default + + # Next, figure out how to apply a metric. We do this prior to + # running the underlying function to prevent having to wait a few + # minutes for a result just to see some error. + if not metric: + # Return all metrics in the iterator. + result = func(*args, runtime=runtime, **kwargs) + for metric_func in metrics: + yield (metric_func.__name__, metric_func(result, **kwargs)) + else: + result = None + for single_metric in metric: + for metric_func in metrics: + # Is this a function that matches the name? + # Apply this function to the result. + if metric_func.__name__ == single_metric: + if not result: + # Lazy evaluation: only if metric matches. + result = func(*args, runtime=runtime, **kwargs) + yield single_metric, metric_func(result, **kwargs) + + # Set metadata on the benchmark (used above). + setattr(wrapper, BENCHMARK_METRICS, metrics) + setattr(wrapper, BENCHMARK_MACHINES, machines) + return wrapper + + return decorator diff --git a/benchmarks/suites/absl.py b/benchmarks/suites/absl.py new file mode 100644 index 000000000..5d9b57a09 --- /dev/null +++ b/benchmarks/suites/absl.py @@ -0,0 +1,37 @@ +# python3 +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""absl build benchmark.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import absl + + +@suites.benchmark(metrics=[absl.elapsed_time], machines=1) +def build(target: machine.Machine, **kwargs) -> str: + """Runs the absl workload and report the absl build time. + + Runs the 'bazel build //absl/...' in a clean bazel directory and + monitors time elapsed. + + Args: + target: A machine object. + **kwargs: Additional container options. + + Returns: + Container output. + """ + image = target.pull("absl") + return target.container(image, **kwargs).run() diff --git a/benchmarks/suites/density.py b/benchmarks/suites/density.py new file mode 100644 index 000000000..89d29fb26 --- /dev/null +++ b/benchmarks/suites/density.py @@ -0,0 +1,121 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Density tests.""" + +import re +import types + +from benchmarks import suites +from benchmarks.harness import container +from benchmarks.harness import machine +from benchmarks.suites import helpers + + +# pylint: disable=unused-argument +def memory_usage(value, **kwargs): + """Returns the passed value.""" + return value + + +def density(target: machine.Machine, + workload: str, + count: int = 50, + wait: float = 0, + load_func: types.FunctionType = None, + **kwargs): + """Calculate the average memory usage per container. + + Args: + target: A machine object. + workload: The workload to run. + count: The number of containers to start. + wait: The time to wait after starting. + load_func: Callback that is called after count images have been started on + the given machine. + **kwargs: Additional container options. + + Returns: + The average usage in Kb per container. + """ + count = int(count) + + # Drop all caches. + helpers.drop_caches(target) + before = target.read("/proc/meminfo") + + # Load the workload. + image = target.pull(workload) + + with target.container( + image=image, count=count, **kwargs).detach() as containers: + # Call the optional load function callback if given. + if load_func: + load_func(target, containers) + # Wait 'wait' time before taking a measurement. + target.sleep(wait) + + # Drop caches again. + helpers.drop_caches(target) + after = target.read("/proc/meminfo") + + # Calculate the memory used. + available_re = re.compile(r"MemAvailable:\s*(\d+)\skB\n") + before_available = available_re.findall(before) + after_available = available_re.findall(after) + return 1024 * float(int(before_available[0]) - + int(after_available[0])) / float(count) + + +def load_redis(target: machine.Machine, containers: container.Container): + """Use redis-benchmark "LPUSH" to load each container with 1G of data. + + Args: + target: A machine object. + containers: A set of containers. + """ + target.pull("redisbenchmark") + for name in containers.get_names(): + flags = "-d 10000 -t LPUSH" + target.container( + "redisbenchmark", links={ + name: name + }).run( + host=name, flags=flags) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def empty(target: machine.Machine, **kwargs) -> float: + """Run trivial containers in a density test.""" + return density(target, workload="sleep", wait=1.0, **kwargs) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def node(target: machine.Machine, **kwargs) -> float: + """Run node containers in a density test.""" + return density(target, workload="node", wait=3.0, **kwargs) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def ruby(target: machine.Machine, **kwargs) -> float: + """Run ruby containers in a density test.""" + return density(target, workload="ruby", wait=3.0, **kwargs) + + +@suites.benchmark(metrics=[memory_usage], machines=1) +def redis(target: machine.Machine, **kwargs) -> float: + """Run redis containers in a density test.""" + if "count" not in kwargs: + kwargs["count"] = 5 + return density( + target, workload="redis", wait=3.0, load_func=load_redis, **kwargs) diff --git a/benchmarks/suites/fio.py b/benchmarks/suites/fio.py new file mode 100644 index 000000000..2171790c5 --- /dev/null +++ b/benchmarks/suites/fio.py @@ -0,0 +1,165 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File I/O tests.""" + +import os + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers +from benchmarks.workloads import fio + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +def run_fio(target: machine.Machine, + test: str, + ioengine: str = "sync", + size: int = 1024 * 1024 * 1024, + iodepth: int = 4, + blocksize: int = 1024 * 1024, + time: int = -1, + mount_dir: str = "", + filename: str = "file.dat", + tmpfs: bool = False, + ramp_time: int = 0, + **kwargs) -> str: + """FIO benchmarks. + + For more on fio see: + https://media.readthedocs.org/pdf/fio/latest/fio.pdf + + Args: + target: A machine object. + test: The test to run (read, write, randread, randwrite, etc.) + ioengine: The engine for I/O. + size: The size of the generated file in bytes (if an integer) or 5g, 16k, + etc. + iodepth: The I/O for certain engines. + blocksize: The blocksize for reads and writes in bytes (if an integer) or + 4k, etc. + time: If test is time based, how long to run in seconds. + mount_dir: The absolute path on the host to mount a bind mount. + filename: The name of the file to creat inside container. For a path of + /dir/dir/file, the script setup a volume like 'docker run -v + mount_dir:/dir/dir fio' and fio will create (and delete) the file + /dir/dir/file. If tmpfs is set, this /dir/dir will be a tmpfs. + tmpfs: If true, mount on tmpfs. + ramp_time: The time to run before recording statistics + **kwargs: Additional container options. + + Returns: + The output of fio as a string. + """ + # Pull the image before dropping caches. + image = target.pull("fio") + + if not mount_dir: + stdout, _ = target.run("pwd") + mount_dir = stdout.rstrip() + + # Setup the volumes. + volumes = {mount_dir: {"bind": "/disk", "mode": "rw"}} if not tmpfs else None + tmpfs = {"/disk": ""} if tmpfs else None + + # Construct a file in the volume. + filepath = os.path.join("/disk", filename) + + # If we are running a read test, us fio to write a file and then flush file + # data from memory. + if "read" in test: + target.container( + image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( + test="write", + ioengine="sync", + size=size, + iodepth=iodepth, + blocksize=blocksize, + path=filepath) + helpers.drop_caches(target) + + # Run the test. + time_str = "--time_base --runtime={time}".format( + time=time) if int(time) > 0 else "" + res = target.container( + image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( + test=test, + ioengine=ioengine, + size=size, + iodepth=iodepth, + blocksize=blocksize, + time=time_str, + path=filepath, + ramp_time=ramp_time) + + target.run( + "rm {path}".format(path=os.path.join(mount_dir.rstrip(), filename))) + + return res + + +@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) +def read(*args, **kwargs): + """Read test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="read", **kwargs) + + +@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) +def randread(*args, **kwargs): + """Random read test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="randread", **kwargs) + + +@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) +def write(*args, **kwargs): + """Write test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="write", **kwargs) + + +@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) +def randwrite(*args, **kwargs): + """Random write test. + + Args: + *args: None. + **kwargs: Additional container options. + + Returns: + The output of fio. + """ + return run_fio(*args, test="randwrite", **kwargs) diff --git a/benchmarks/suites/helpers.py b/benchmarks/suites/helpers.py new file mode 100644 index 000000000..b3c7360ab --- /dev/null +++ b/benchmarks/suites/helpers.py @@ -0,0 +1,57 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Benchmark helpers.""" + +import datetime +from benchmarks.harness import machine + + +class Timer: + """Helper to time runtime of some call. + + Usage: + + with Timer as t: + # do something. + t.get_time_in_seconds() + """ + + def __init__(self): + self._start = datetime.datetime.now() + + def __enter__(self): + self.start() + return self + + def start(self): + """Starts the timer.""" + self._start = datetime.datetime.now() + + def elapsed(self) -> float: + """Returns the elapsed time in seconds.""" + return (datetime.datetime.now() - self._start).total_seconds() + + def __exit__(self, exception_type, exception_value, exception_traceback): + pass + + +def drop_caches(target: machine.Machine): + """Drops caches on the machine. + + Args: + target: A machine object. + """ + target.run("sudo sync") + target.run("sudo sysctl vm.drop_caches=3") + target.run("sudo sysctl vm.drop_caches=3") diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py new file mode 100644 index 000000000..ea9024e43 --- /dev/null +++ b/benchmarks/suites/http.py @@ -0,0 +1,138 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""HTTP benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import ab + + +# pylint: disable=too-many-arguments +def http(server: machine.Machine, + client: machine.Machine, + workload: str, + requests: int = 5000, + connections: int = 10, + port: int = 80, + path: str = "notfound", + **kwargs) -> str: + """Run apachebench (ab) against an http server. + + Args: + server: A machine object. + client: A machine object. + workload: The http-serving workload. + requests: Number of requests to send the server. Default is 5000. + connections: Number of concurent connections to use. Default is 10. + port: The port to access in benchmarking. + path: File to download, generally workload-specific. + **kwargs: Additional container options. + + Returns: + The full apachebench output. + """ + # Pull the client & server. + apachebench = client.pull("ab") + netcat = client.pull("netcat") + image = server.pull(workload) + + with server.container(image, port=port, **kwargs).detach() as container: + (host, port) = container.address() + # Wait for the server to come up. + client.container(netcat).run(host=host, port=port) + # Run the benchmark, no arguments. + return client.container(apachebench).run( + host=host, + port=port, + requests=requests, + connections=connections, + path=path) + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +def http_app(server: machine.Machine, + client: machine.Machine, + workload: str, + requests: int = 5000, + connections: int = 10, + port: int = 80, + path: str = "notfound", + **kwargs) -> str: + """Run apachebench (ab) against an http application. + + Args: + server: A machine object. + client: A machine object. + workload: The http-serving workload. + requests: Number of requests to send the server. Default is 5000. + connections: Number of concurent connections to use. Default is 10. + port: The port to use for benchmarking. + path: File to download, generally workload-specific. + **kwargs: Additional container options. + + Returns: + The full apachebench output. + """ + # Pull the client & server. + apachebench = client.pull("ab") + netcat = client.pull("netcat") + server_netcat = server.pull("netcat") + redis = server.pull("redis") + image = server.pull(workload) + redis_port = 6379 + redis_name = "redis_server" + + with server.container(redis, name=redis_name).detach(): + server.container(server_netcat, links={redis_name: redis_name})\ + .run(host=redis_name, port=redis_port) + with server.container(image, port=port, links={redis_name: redis_name}, **kwargs)\ + .detach(host=redis_name) as container: + (host, port) = container.address() + # Wait for the server to come up. + client.container(netcat).run(host=host, port=port) + # Run the benchmark, no arguments. + return client.container(apachebench).run( + host=host, + port=port, + requests=requests, + connections=connections, + path=path) + + +@suites.benchmark(metrics=[ab.transfer_rate, ab.latency], machines=2) +def httpd(*args, **kwargs) -> str: + """Apache2 benchmark.""" + return http(*args, workload="httpd", port=80, **kwargs) + + +@suites.benchmark( + metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) +def nginx(*args, **kwargs) -> str: + """Nginx benchmark.""" + return http(*args, workload="nginx", port=80, **kwargs) + + +@suites.benchmark( + metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) +def node(*args, **kwargs) -> str: + """Node benchmark.""" + return http_app(*args, workload="node_template", path="", port=8080, **kwargs) + + +@suites.benchmark( + metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) +def ruby(*args, **kwargs) -> str: + """Ruby benchmark.""" + return http_app(*args, workload="ruby_template", path="", port=9292, **kwargs) diff --git a/benchmarks/suites/media.py b/benchmarks/suites/media.py new file mode 100644 index 000000000..9cbffdaa1 --- /dev/null +++ b/benchmarks/suites/media.py @@ -0,0 +1,42 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Media processing benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers +from benchmarks.workloads import ffmpeg + + +@suites.benchmark(metrics=[ffmpeg.run_time], machines=1) +def transcode(target: machine.Machine, **kwargs) -> float: + """Runs a video transcoding workload and times it. + + Args: + target: A machine object. + **kwargs: Additional container options. + + Returns: + Total workload runtime. + """ + # Load before timing. + image = target.pull("ffmpeg") + + # Drop caches. + helpers.drop_caches(target) + + # Time startup + transcoding. + with helpers.Timer() as timer: + target.container(image, **kwargs).run() + return timer.elapsed() diff --git a/benchmarks/suites/ml.py b/benchmarks/suites/ml.py new file mode 100644 index 000000000..a394d1f69 --- /dev/null +++ b/benchmarks/suites/ml.py @@ -0,0 +1,33 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Machine Learning tests.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import startup +from benchmarks.workloads import tensorflow + + +@suites.benchmark(metrics=[tensorflow.run_time], machines=1) +def train(target: machine.Machine, **kwargs): + """Run the tensorflow benchmark and return the runtime in seconds of workload. + + Args: + target: A machine object. + **kwargs: Additional container options. + + Returns: + The total runtime. + """ + return startup.startup(target, workload="tensorflow", count=1, **kwargs) diff --git a/benchmarks/suites/network.py b/benchmarks/suites/network.py new file mode 100644 index 000000000..f973cf3f1 --- /dev/null +++ b/benchmarks/suites/network.py @@ -0,0 +1,101 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Network microbenchmarks.""" + +from typing import Dict + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers +from benchmarks.workloads import iperf + + +def run_iperf(client: machine.Machine, + server: machine.Machine, + client_kwargs: Dict[str, str] = None, + server_kwargs: Dict[str, str] = None) -> str: + """Measure iperf performance. + + Args: + client: A machine object. + server: A machine object. + client_kwargs: Additional client container options. + server_kwargs: Additional server container options. + + Returns: + The output of iperf. + """ + if not client_kwargs: + client_kwargs = dict() + if not server_kwargs: + server_kwargs = dict() + + # Pull images. + netcat = client.pull("netcat") + iperf_client_image = client.pull("iperf") + iperf_server_image = server.pull("iperf") + + # Set this due to a bug in the kernel that resets connections. + client.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") + server.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") + + with server.container( + iperf_server_image, port=5001, **server_kwargs).detach() as iperf_server: + (host, port) = iperf_server.address() + # Wait until the service is available. + client.container(netcat).run(host=host, port=port) + # Run a warm-up run. + client.container( + iperf_client_image, stderr=True, **client_kwargs).run( + host=host, port=port) + # Run the client with relevant arguments. + res = client.container(iperf_client_image, stderr=True, **client_kwargs)\ + .run(host=host, port=port) + helpers.drop_caches(client) + return res + + +@suites.benchmark(metrics=[iperf.bandwidth], machines=2) +def upload(client: machine.Machine, server: machine.Machine, **kwargs) -> str: + """Measure upload performance. + + Args: + client: A machine object. + server: A machine object. + **kwargs: Client container options. + + Returns: + The output of iperf. + """ + if kwargs["runtime"] == "runc": + kwargs["network_mode"] = "host" + return run_iperf(client, server, client_kwargs=kwargs) + + +@suites.benchmark(metrics=[iperf.bandwidth], machines=2) +def download(client: machine.Machine, server: machine.Machine, **kwargs) -> str: + """Measure download performance. + + Args: + client: A machine object. + server: A machine object. + **kwargs: Server container options. + + Returns: + The output of iperf. + """ + + client_kwargs = {"network_mode": "host"} + return run_iperf( + client, server, client_kwargs=client_kwargs, server_kwargs=kwargs) diff --git a/benchmarks/suites/redis.py b/benchmarks/suites/redis.py new file mode 100644 index 000000000..b84dd073d --- /dev/null +++ b/benchmarks/suites/redis.py @@ -0,0 +1,46 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Redis benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import redisbenchmark + + +@suites.benchmark(metrics=list(redisbenchmark.METRICS.values()), machines=2) +def redis(server: machine.Machine, + client: machine.Machine, + flags: str = "", + **kwargs) -> str: + """Run redis-benchmark on client pointing at server machine. + + Args: + server: A machine object. + client: A machine object. + flags: Flags to pass redis-benchmark. + **kwargs: Additional container options. + + Returns: + Output from redis-benchmark. + """ + redis_server = server.pull("redis") + redis_client = client.pull("redisbenchmark") + netcat = client.pull("netcat") + with server.container( + redis_server, port=6379, **kwargs).detach() as container: + (host, port) = container.address() + # Wait for the container to be up. + client.container(netcat).run(host=host, port=port) + # Run all redis benchmarks. + return client.container(redis_client).run(host=host, port=port, flags=flags) diff --git a/benchmarks/suites/startup.py b/benchmarks/suites/startup.py new file mode 100644 index 000000000..a1b6c5753 --- /dev/null +++ b/benchmarks/suites/startup.py @@ -0,0 +1,110 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Start-up benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.suites import helpers + + +# pylint: disable=unused-argument +def startup_time_ms(value, **kwargs): + """Returns average startup time per container in milliseconds. + + Args: + value: The floating point time in seconds. + **kwargs: Ignored. + + Returns: + The time given in milliseconds. + """ + return value * 1000 + + +def startup(target: machine.Machine, + workload: str, + count: int = 5, + port: int = 0, + **kwargs): + """Time the startup of some workload. + + Args: + target: A machine object. + workload: The workload to run. + count: Number of containers to start. + port: The port to check for liveness, if provided. + **kwargs: Additional container options. + + Returns: + The mean start-up time in seconds. + """ + # Load before timing. + image = target.pull(workload) + netcat = target.pull("netcat") + count = int(count) + port = int(port) + + with helpers.Timer() as timer: + for _ in range(count): + if not port: + # Run the container synchronously. + target.container(image, **kwargs).run() + else: + # Run a detached container until httpd available. + with target.container(image, port=port, **kwargs).detach() as server: + (server_host, server_port) = server.address() + target.container(netcat).run(host=server_host, port=server_port) + return timer.elapsed() / float(count) + + +@suites.benchmark(metrics=[startup_time_ms], machines=1) +def empty(target: machine.Machine, **kwargs) -> float: + """Time the startup of a trivial container. + + Args: + target: A machine object. + **kwargs: Additional startup options. + + Returns: + The time to run the container. + """ + return startup(target, workload="true", **kwargs) + + +@suites.benchmark(metrics=[startup_time_ms], machines=1) +def node(target: machine.Machine, **kwargs) -> float: + """Time the startup of the node container. + + Args: + target: A machine object. + **kwargs: Additional statup options. + + Returns: + The time to run the container. + """ + return startup(target, workload="node", port=8080, **kwargs) + + +@suites.benchmark(metrics=[startup_time_ms], machines=1) +def ruby(target: machine.Machine, **kwargs) -> float: + """Time the startup of the ruby container. + + Args: + target: A machine object. + **kwargs: Additional startup options. + + Returns: + The time to run the container. + """ + return startup(target, workload="ruby", port=3000, **kwargs) diff --git a/benchmarks/suites/sysbench.py b/benchmarks/suites/sysbench.py new file mode 100644 index 000000000..2a6e2126c --- /dev/null +++ b/benchmarks/suites/sysbench.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sysbench-based benchmarks.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads import sysbench + + +def run_sysbench(target: machine.Machine, + test: str = "cpu", + threads: int = 8, + time: int = 5, + options: str = "", + **kwargs) -> str: + """Run sysbench container with arguments. + + Args: + target: A machine object. + test: Relevant sysbench test to run (e.g. cpu, memory). + threads: The number of threads to use for tests. + time: The time to run tests. + options: Additional sysbench options. + **kwargs: Additional container options. + + Returns: + The output of the command as a string. + """ + image = target.pull("sysbench") + return target.container(image, **kwargs).run( + test=test, threads=threads, time=time, options=options) + + +@suites.benchmark(metrics=[sysbench.cpu_events_per_second], machines=1) +def cpu(target: machine.Machine, max_prime: int = 5000, **kwargs) -> str: + """Run sysbench CPU test. + + Additional arguments can be provided for sysbench. + + Args: + target: A machine object. + max_prime: The maximum prime number to search. + **kwargs: + - threads: The number of threads to use for tests. + - time: The time to run tests. + - options: Additional sysbench options. See sysbench tool: + https://github.com/akopytov/sysbench + + Returns: + Sysbench output. + """ + options = kwargs.pop("options", "") + options += " --cpu-max-prime={}".format(max_prime) + return run_sysbench(target, test="cpu", options=options, **kwargs) + + +@suites.benchmark(metrics=[sysbench.memory_ops_per_second], machines=1) +def memory(target: machine.Machine, **kwargs) -> str: + """Run sysbench memory test. + + Additional arguments can be provided per sysbench. + + Args: + target: A machine object. + **kwargs: + - threads: The number of threads to use for tests. + - time: The time to run tests. + - options: Additional sysbench options. See sysbench tool: + https://github.com/akopytov/sysbench + + Returns: + Sysbench output. + """ + return run_sysbench(target, test="memory", **kwargs) + + +@suites.benchmark( + metrics=[ + sysbench.mutex_time, sysbench.mutex_latency, sysbench.mutex_deviation + ], + machines=1) +def mutex(target: machine.Machine, + locks: int = 4, + count: int = 10000000, + threads: int = 8, + **kwargs) -> str: + """Run sysbench mutex test. + + Additional arguments can be provided per sysbench. + + Args: + target: A machine object. + locks: The number of locks to use. + count: The number of mutexes. + threads: The number of threads to use for tests. + **kwargs: + - time: The time to run tests. + - options: Additional sysbench options. See sysbench tool: + https://github.com/akopytov/sysbench + + Returns: + Sysbench output. + """ + options = kwargs.pop("options", "") + options += " --mutex-loops=1 --mutex-locks={} --mutex-num={}".format( + count, locks) + return run_sysbench( + target, test="mutex", options=options, threads=threads, **kwargs) diff --git a/benchmarks/suites/syscall.py b/benchmarks/suites/syscall.py new file mode 100644 index 000000000..fa7665b00 --- /dev/null +++ b/benchmarks/suites/syscall.py @@ -0,0 +1,37 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Syscall microbenchmark.""" + +from benchmarks import suites +from benchmarks.harness import machine +from benchmarks.workloads.syscall import syscall_time_ns + + +@suites.benchmark(metrics=[syscall_time_ns], machines=1) +def syscall(target: machine.Machine, count: int = 1000000, **kwargs) -> str: + """Runs the syscall workload and report the syscall time. + + Runs the syscall 'SYS_gettimeofday(0,0)' 'count' times and monitors time + elapsed based on the runtime's MONOTONIC clock. + + Args: + target: A machine object. + count: The number of syscalls to execute. + **kwargs: Additional container options. + + Returns: + Container output. + """ + image = target.pull("syscall") + return target.container(image, **kwargs).run(count=count) diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD new file mode 100644 index 000000000..735d7127f --- /dev/null +++ b/benchmarks/tcp/BUILD @@ -0,0 +1,41 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("@rules_cc//cc:defs.bzl", "cc_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "tcp_proxy", + srcs = ["tcp_proxy.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +# nsjoin is a trivial replacement for nsenter. This is used because nsenter is +# not available on all systems where this benchmark is run (and we aim to +# minimize external dependencies.) + +cc_binary( + name = "nsjoin", + srcs = ["nsjoin.c"], + visibility = ["//:sandbox"], +) + +sh_binary( + name = "tcp_benchmark", + srcs = ["tcp_benchmark.sh"], + data = [ + ":nsjoin", + ":tcp_proxy", + ], + visibility = ["//:sandbox"], +) diff --git a/benchmarks/tcp/README.md b/benchmarks/tcp/README.md new file mode 100644 index 000000000..38e6e69f0 --- /dev/null +++ b/benchmarks/tcp/README.md @@ -0,0 +1,87 @@ +# TCP Benchmarks + +This directory contains a standardized TCP benchmark. This helps to evaluate the +performance of netstack and native networking stacks under various conditions. + +## `tcp_benchmark` + +This benchmark allows TCP throughput testing under various conditions. The setup +consists of an iperf client, a client proxy, a server proxy and an iperf server. +The client proxy and server proxy abstract the network mechanism used to +communicate between the iperf client and server. + +The setup looks like the following: + +``` + +--------------+ (native) +--------------+ + | iperf client |[lo @ 10.0.0.1]------>| client proxy | + +--------------+ +--------------+ + [client.0 @ 10.0.0.2] + (netstack) | | (native) + +------+-----+ + | + [br0] + | + Network emulation applied ---> [wan.0:wan.1] + | + [br1] + | + +------+-----+ + (netstack) | | (native) + [server.0 @ 10.0.0.3] + +--------------+ +--------------+ + | iperf server |<------[lo @ 10.0.0.4]| server proxy | + +--------------+ (native) +--------------+ +``` + +Different configurations can be run using different arguments. For example: + +* Native test under normal internet conditions: `tcp_benchmark` +* Native test under ideal conditions: `tcp_benchmark --ideal` +* Netstack client under ideal conditions: `tcp_benchmark --client --ideal` +* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss + 5` + +Use `tcp_benchmark --help` for full arguments. + +This tool may be used to easily generate data for graphing. For example, to +generate a CSV for various latencies, you might do: + +``` +rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv +latencies=$(seq 0 5 50; + seq 60 10 100; + seq 125 25 250; + seq 300 50 500) +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv +done +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv +done +``` + +Similarly, to generate a CSV for various levels of packet loss, the following +would be appropriate: + +``` +rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv +losses=$(seq 0 0.1 1.0; + seq 1.2 0.2 2.0; + seq 2.5 0.5 5.0; + seq 6.0 1.0 10.0) +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv +done +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv +done +``` diff --git a/benchmarks/tcp/nsjoin.c b/benchmarks/tcp/nsjoin.c new file mode 100644 index 000000000..524b4d549 --- /dev/null +++ b/benchmarks/tcp/nsjoin.c @@ -0,0 +1,47 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include <errno.h> +#include <fcntl.h> +#include <sched.h> +#include <stdio.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +int main(int argc, char** argv) { + if (argc <= 2) { + fprintf(stderr, "error: must provide a namespace file.\n"); + fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]); + return 1; + } + + int fd = open(argv[1], O_RDONLY); + if (fd < 0) { + fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno)); + return 1; + } + if (setns(fd, 0) < 0) { + fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno)); + return 1; + } + + execvp(argv[2], &argv[2]); + return 1; +} diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh new file mode 100755 index 000000000..69344c9c3 --- /dev/null +++ b/benchmarks/tcp/tcp_benchmark.sh @@ -0,0 +1,369 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TCP benchmark; see README.md for documentation. + +# Fixed parameters. +iperf_port=45201 # Not likely to be privileged. +proxy_port=44000 # Ditto. +client_addr=10.0.0.1 +client_proxy_addr=10.0.0.2 +server_proxy_addr=10.0.0.3 +server_addr=10.0.0.4 +mask=8 + +# Defaults; this provides a reasonable approximation of a decent internet link. +# Parameters can be varied independently from this set to see response to +# various changes in the kind of link available. +client=false +server=false +verbose=false +gso=0 +swgso=false +mtu=1280 # 1280 is a reasonable lowest-common-denominator. +latency=10 # 10ms approximates a fast, dedicated connection. +latency_variation=1 # +/- 1ms is a relatively low amount of jitter. +loss=0.1 # 0.1% loss is non-zero, but not extremely high. +duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses. +duration=30 # 30s is enough time to consistent results (experimentally). +helper_dir=$(dirname $0) +netstack_opts= + +# Check for netem support. +lsmod_output=$(lsmod | grep sch_netem) +if [ "$?" != "0" ]; then + echo "warning: sch_netem may not be installed." >&2 +fi + +while [ $# -gt 0 ]; do + case "$1" in + --client) + client=true + ;; + --client_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -client_tcp_probe_file=$1" + ;; + --server) + server=true + ;; + --verbose) + verbose=true + ;; + --gso) + shift + gso=$1 + ;; + --swgso) + swgso=true + ;; + --server_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -server_tcp_probe_file=$1" + ;; + --ideal) + mtu=1500 # Standard ethernet. + latency=0 # No latency. + latency_variation=0 # No jitter. + loss=0 # No loss. + duplicate=0 # No duplicates. + ;; + --mtu) + shift + [ "$#" -le 0 ] && echo "no mtu provided" && exit 1 + mtu=$1 + ;; + --sack) + netstack_opts="${netstack_opts} -sack" + ;; + --cubic) + netstack_opts="${netstack_opts} -cubic" + ;; + --duration) + shift + [ "$#" -le 0 ] && echo "no duration provided" && exit 1 + duration=$1 + ;; + --latency) + shift + [ "$#" -le 0 ] && echo "no latency provided" && exit 1 + latency=$1 + ;; + --latency-variation) + shift + [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1 + latency_variation=$1 + ;; + --loss) + shift + [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1 + loss=$1 + ;; + --duplicate) + shift + [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1 + duplicate=$1 + ;; + --cpuprofile) + shift + netstack_opts="${netstack_opts} -cpuprofile=$1" + ;; + --memprofile) + shift + netstack_opts="${netstack_opts} -memprofile=$1" + ;; + --helpers) + shift + [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1 + helper_dir=$1 + ;; + *) + echo "usage: $0 [options]" + echo "options:" + echo " --help show this message" + echo " --verbose verbose output" + echo " --client use netstack as the client" + echo " --ideal reset all network emulation" + echo " --server use netstack as the server" + echo " --mtu set the mtu (bytes)" + echo " --sack enable SACK support" + echo " --cubic enable CUBIC congestion control for Netstack" + echo " --duration set the test duration (s)" + echo " --latency set the latency (ms)" + echo " --latency-variation set the latency variation" + echo " --loss set the loss probability (%)" + echo " --duplicate set the duplicate probability (%)" + echo " --helpers set the helper directory" + echo "" + echo "The output will of the script will be:" + echo " <throughput> <client-cpu-usage> <server-cpu-usage>" + exit 1 + esac + shift +done + +if [ ${verbose} == "true" ]; then + set -x +fi + +# Latency needs to be halved, since it's applied on both ways. +half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}') +half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}') +half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}') +helper_dir=${helper_dir#$(pwd)/} # Use relative paths. +proxy_binary=${helper_dir}/tcp_proxy +nsjoin_binary=${helper_dir}/nsjoin + +if [ ! -e ${proxy_binary} ]; then + echo "Could not locate ${proxy_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ ! -e ${nsjoin_binary} ]; then + echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then + # As long as there's some jitter, then we use the paretonormal distribution. + # This will preserve the minimum RTT, but add a realistic amount of jitter to + # the connection and cause re-ordering, etc. The regular pareto distribution + # appears to an unreasonable level of delay (we want only small spikes.) + distribution="distribution paretonormal" +else + distribution="" +fi + +# Client proxy that will listen on the client's iperf target forward traffic +# using the host networking stack. +client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}" +if ${client}; then + # Client proxy that will listen on the client's iperf target + # and forward traffic using netstack. + client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\ + -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\ + -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}" +fi + +# Server proxy that will listen on the proxy port and forward to the server's +# iperf server using the host networking stack. +server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}" +if ${server}; then + # Server proxy that will listen on the proxy port and forward to the servers' + # iperf server using netstack. + server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\ + -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\ + -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}" +fi + +# Specify loss and duplicate parameters only if they are non-zero +loss_opt="" +if [ "$(echo $half_loss | bc -q)" != "0" ]; then + loss_opt="loss random ${half_loss}%" +fi +duplicate_opt="" +if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then + duplicate_opt="duplicate ${half_duplicate}%" +fi + +exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF +set -e -m + +if [ ${verbose} == "true" ]; then + set -x +fi + +mount -t tmpfs netstack-bench /tmp + +# We may have reset the path in the unshare if the shell loaded some public +# profiles. Ensure that tools are discoverable via the parent's PATH. +export PATH=${PATH} + +# Add client, server interfaces. +ip link add client.0 type veth peer name client.1 +ip link add server.0 type veth peer name server.1 + +# Add network emulation devices. +ip link add wan.0 type veth peer name wan.1 +ip link set wan.0 up +ip link set wan.1 up + +# Enroll on the bridge. +ip link add name br0 type bridge +ip link add name br1 type bridge +ip link set client.1 master br0 +ip link set server.1 master br1 +ip link set wan.0 master br0 +ip link set wan.1 master br1 +ip link set br0 up +ip link set br1 up + +# Set the MTU appropriately. +ip link set client.0 mtu ${mtu} +ip link set server.0 mtu ${mtu} +ip link set wan.0 mtu ${mtu} +ip link set wan.1 mtu ${mtu} + +# Add appropriate latency, loss and duplication. +# +# This is added in at the point of bridge connection. +for device in wan.0 wan.1; do + # NOTE: We don't support a loss correlation as testing has shown that it + # actually doesn't work. The man page actually has a small comment about this + # "It is also possible to add a correlation, but this option is now deprecated + # due to the noticed bad behavior." For more information see netem(8). + tc qdisc add dev \$device root netem \\ + delay ${half_latency}ms ${latency_variation}ms ${distribution} \\ + ${loss_opt} ${duplicate_opt} +done + +# Start a client proxy. +touch /tmp/client.netns +unshare -n mount --bind /proc/self/ns/net /tmp/client.netns + +# Move the endpoint into the namespace. +while ip link | grep client.0 > /dev/null; do + ip link set dev client.0 netns /tmp/client.netns +done + +if ! ${client}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0 +fi + +# Start a server proxy. +touch /tmp/server.netns +unshare -n mount --bind /proc/self/ns/net /tmp/server.netns +# Move the endpoint into the namespace. +while ip link | grep server.0 > /dev/null; do + ip link set dev server.0 netns /tmp/server.netns +done +if ! ${server}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0 +fi + +# Add client and server addresses, and bring everything up. +${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0 +${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0 +${nsjoin_binary} /tmp/client.netns ip link set client.0 up +${nsjoin_binary} /tmp/client.netns ip link set lo up +${nsjoin_binary} /tmp/server.netns ip link set server.0 up +${nsjoin_binary} /tmp/server.netns ip link set lo up +ip link set dev client.1 up +ip link set dev server.1 up + +${nsjoin_binary} /tmp/client.netns ${client_args} & +client_pid=\$! +${nsjoin_binary} /tmp/server.netns ${server_args} & +server_pid=\$! + +# Start the iperf server. +${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 & +iperf_pid=\$! + +# Show traffic information. +if ! ${client} && ! ${server}; then + ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true +fi + +results_file=\$(mktemp) +function cleanup { + rm -f \$results_file + kill -TERM \$client_pid + kill -TERM \$server_pid + wait \$client_pid + wait \$server_pid + kill -9 \$iperf_pid 2>/dev/null +} + +# Allow failure from this point. +set +e +trap cleanup EXIT + +# Run the benchmark, recording the results file. +while ${nsjoin_binary} /tmp/client.netns iperf \\ + -p ${proxy_port} -c ${client_addr} -t ${duration} -f m 2>&1 \\ + | tee \$results_file \\ + | grep "connect failed" >/dev/null; do + sleep 0.1 # Wait for all services. +done + +# Unlink all relevant devices from the bridge. This is because when the bridge +# is deleted, the kernel may hang. It appears that this problem is fixed in +# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a. +ip link set client.1 nomaster +ip link set server.1 nomaster +ip link set wan.0 nomaster +ip link set wan.1 nomaster + +# Emit raw results. +cat \$results_file >&2 + +# Emit a useful result (final throughput). +mbits=\$(grep Mbits/sec \$results_file \\ + | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p') +client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\ + | awk '{print (\$14+\$15);}') +server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\ + | awk '{print (\$14+\$15);}') +ticks_per_sec=\$(getconf CLK_TCK) +client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration}) +server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration}) +echo \$mbits \$client_cpu_load \$server_cpu_load +EOF diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go new file mode 100644 index 000000000..361a56755 --- /dev/null +++ b/benchmarks/tcp/tcp_proxy.go @@ -0,0 +1,436 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary tcp_proxy is a simple TCP proxy. +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "io" + "log" + "math/rand" + "net" + "os" + "os/signal" + "regexp" + "runtime" + "runtime/pprof" + "strconv" + "syscall" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +var ( + port = flag.Int("port", 0, "bind port (all addresses)") + forward = flag.String("forward", "", "forwarding target") + client = flag.Bool("client", false, "use netstack for listen") + server = flag.Bool("server", false, "use netstack for dial") + + // Netstack-specific options. + mtu = flag.Int("mtu", 1280, "mtu for network stack") + addr = flag.String("addr", "", "address for tap-based netstack") + mask = flag.Int("mask", 8, "mask size for address") + iface = flag.String("iface", "", "network interface name to bind for netstack") + sack = flag.Bool("sack", false, "enable SACK support for netstack") + cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack") + gso = flag.Int("gso", 0, "GSO maximum size") + swgso = flag.Bool("swgso", false, "software-level GSO") + clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.") + memprofile = flag.String("memprofile", "", "write memory profile to the specified file.") +) + +type impl interface { + dial(address string) (net.Conn, error) + listen(port int) (net.Listener, error) + printStats() +} + +type netImpl struct{} + +func (netImpl) dial(address string) (net.Conn, error) { + return net.Dial("tcp", address) +} + +func (netImpl) listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + +func (netImpl) printStats() { +} + +const ( + nicID = 1 // Fixed. + rcvBufSize = 1 << 20 // 1MB. +) + +type netstackImpl struct { + s *stack.Stack + addr tcpip.Address + mode string +} + +func setupNetwork(ifaceName string) (fd int, err error) { + // Get all interfaces in the namespace. + ifaces, err := net.Interfaces() + if err != nil { + return -1, fmt.Errorf("querying interfaces: %v", err) + } + + for _, iface := range ifaces { + if iface.Name != ifaceName { + continue + } + // Create the socket. + const protocol = 0x0300 // htons(ETH_P_ALL) + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + if err != nil { + return -1, fmt.Errorf("unable to create raw socket: %v", err) + } + + // Bind to the appropriate device. + ll := syscall.SockaddrLinklayer{ + Protocol: protocol, + Ifindex: iface.Index, + Pkttype: syscall.PACKET_HOST, + } + if err := syscall.Bind(fd, &ll); err != nil { + return -1, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) + } + + // RAW Sockets by default have a very small SO_RCVBUF of 256KB, + // up it to at least 1MB to reduce packet drops. + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil { + return -1, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } + + if !*swgso && *gso != 0 { + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + return -1, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + } + } + return fd, nil + } + return -1, fmt.Errorf("failed to find interface: %v", ifaceName) +} + +func newNetstackImpl(mode string) (impl, error) { + fd, err := setupNetwork(*iface) + if err != nil { + return nil, err + } + + // Parse details. + parsedAddr := tcpip.Address(net.ParseIP(*addr).To4()) + parsedDest := tcpip.Address("") // Filled in below. + parsedMask := tcpip.AddressMask("") // Filled in below. + switch *mask { + case 8: + parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0}) + case 16: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0}) + case 24: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0}) + default: + // This is just laziness; we don't expect a different mask. + return nil, fmt.Errorf("mask %d not supported", mask) + } + + // Create a new network stack. + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} + s := stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + }) + + // Generate a new mac for the eth device. + mac := make(net.HardwareAddr, 6) + rand.Read(mac) // Fill with random data. + mac[0] &^= 0x1 // Clear multicast bit. + mac[0] |= 0x2 // Set local assignment bit (IEEE802). + ep, err := fdbased.New(&fdbased.Options{ + FDs: []int{fd}, + MTU: uint32(*mtu), + EthernetHeader: true, + Address: tcpip.LinkAddress(mac), + // Enable checksum generation as we need to generate valid + // checksums for the veth device to deliver our packets to the + // peer. But we do want to disable checksum verification as veth + // devices do perform GRO and the linux host kernel may not + // regenerate valid checksums after GRO. + TXChecksumOffload: false, + RXChecksumOffload: true, + PacketDispatchMode: fdbased.RecvMMsg, + GSOMaxSize: uint32(*gso), + SoftwareGSOEnabled: *swgso, + }) + if err != nil { + return nil, fmt.Errorf("failed to create FD endpoint: %v", err) + } + if err := s.CreateNIC(nicID, ep); err != nil { + return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil { + return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err) + } + + subnet, err := tcpip.NewSubnet(parsedDest, parsedMask) + if err != nil { + return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err) + } + // Add default route; we only support + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + + // Set protocol options. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %v", err) + } + + // Set Congestion Control to cubic if requested. + if *cubic { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %v", err) + } + } + + return netstackImpl{ + s: s, + addr: parsedAddr, + mode: mode, + }, nil +} + +func (n netstackImpl) dial(address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + if host == "" { + // A host must be provided for the dial. + return nil, fmt.Errorf("no host provided") + } + portNumber, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + addr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(net.ParseIP(host).To4()), + Port: uint16(portNumber), + } + conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return conn, nil +} + +func (n netstackImpl) listen(port int) (net.Listener, error) { + addr := tcpip.FullAddress{ + NIC: nicID, + Port: uint16(port), + } + listener, err := gonet.NewListener(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return listener, nil +} + +var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`) + +func (n netstackImpl) printStats() { + // Don't show zero fields. + stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "") + log.Printf("netstack %s Stats: %+v\n", n.mode, stats) +} + +// installProbe installs a TCP Probe function that will dump endpoint +// state to the specified file. It also returns a close func() that +// can be used to close the probeFile. +func (n netstackImpl) installProbe(probeFileName string) (close func()) { + // Install Probe to dump out end point state. + probeFile, err := os.Create(probeFileName) + if err != nil { + log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err) + } + probeEncoder := gob.NewEncoder(probeFile) + // Install a TCP Probe. + n.s.AddTCPProbe(func(state stack.TCPEndpointState) { + probeEncoder.Encode(state) + }) + return func() { probeFile.Close() } +} + +func main() { + flag.Parse() + if *port == 0 { + log.Fatalf("no port provided") + } + if *forward == "" { + log.Fatalf("no forward provided") + } + // Seed the random number generator to ensure that we are given MAC addresses that don't + // for the case of the client and server stack. + rand.Seed(time.Now().UTC().UnixNano()) + + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + log.Fatal("could not create CPU profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing CPU profile: ", err) + } + }() + if err := pprof.StartCPUProfile(f); err != nil { + log.Fatal("could not start CPU profile: ", err) + } + defer pprof.StopCPUProfile() + } + + var ( + in impl + out impl + err error + ) + if *server { + in, err = newNetstackImpl("server") + if *serverTCPProbeFile != "" { + defer in.(netstackImpl).installProbe(*serverTCPProbeFile)() + } + + } else { + in = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + if *client { + out, err = newNetstackImpl("client") + if *clientTCPProbeFile != "" { + defer out.(netstackImpl).installProbe(*clientTCPProbeFile)() + } + } else { + out = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + + // Dial forward before binding. + var next net.Conn + for { + next, err = out.dial(*forward) + if err == nil { + break + } + time.Sleep(50 * time.Millisecond) + log.Printf("connect failed retrying: %v", err) + } + + // Bind once to the server socket. + listener, err := in.listen(*port) + if err != nil { + // Should not happen, everything must be bound by this time + // this proxy is started. + log.Fatalf("unable to listen: %v", err) + } + log.Printf("client=%v, server=%v, ready.", *client, *server) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + go func() { + <-sigs + if *cpuprofile != "" { + pprof.StopCPUProfile() + } + if *memprofile != "" { + f, err := os.Create(*memprofile) + if err != nil { + log.Fatal("could not create memory profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing memory profile: ", err) + } + }() + runtime.GC() // get up-to-date statistics + if err := pprof.WriteHeapProfile(f); err != nil { + log.Fatalf("Unable to write heap profile: %v", err) + } + } + os.Exit(0) + }() + + for { + // Forward all connections. + inConn, err := listener.Accept() + if err != nil { + // This should not happen; we are listening + // successfully. Exhausted all available FDs? + log.Fatalf("accept error: %v", err) + } + log.Printf("incoming connection established.") + + // Copy both ways. + go io.Copy(inConn, next) + go io.Copy(next, inConn) + + // Print stats every second. + go func() { + t := time.NewTicker(time.Second) + defer t.Stop() + for { + <-t.C + in.printStats() + out.printStats() + } + }() + + for { + // Dial again. + next, err = out.dial(*forward) + if err == nil { + break + } + } + } +} diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD new file mode 100644 index 000000000..643806105 --- /dev/null +++ b/benchmarks/workloads/BUILD @@ -0,0 +1,35 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "workloads", + srcs = ["__init__.py"], +) + +filegroup( + name = "files", + srcs = [ + "//benchmarks/workloads/ab:files", + "//benchmarks/workloads/absl:files", + "//benchmarks/workloads/curl:files", + "//benchmarks/workloads/ffmpeg:files", + "//benchmarks/workloads/fio:files", + "//benchmarks/workloads/httpd:files", + "//benchmarks/workloads/iperf:files", + "//benchmarks/workloads/netcat:files", + "//benchmarks/workloads/nginx:files", + "//benchmarks/workloads/node:files", + "//benchmarks/workloads/node_template:files", + "//benchmarks/workloads/redis:files", + "//benchmarks/workloads/redisbenchmark:files", + "//benchmarks/workloads/ruby:files", + "//benchmarks/workloads/ruby_template:files", + "//benchmarks/workloads/sleep:files", + "//benchmarks/workloads/sysbench:files", + "//benchmarks/workloads/syscall:files", + "//benchmarks/workloads/tensorflow:files", + "//benchmarks/workloads/true:files", + ], +) diff --git a/benchmarks/workloads/__init__.py b/benchmarks/workloads/__init__.py new file mode 100644 index 000000000..e12651e76 --- /dev/null +++ b/benchmarks/workloads/__init__.py @@ -0,0 +1,14 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Workloads, parsers and test data.""" diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD new file mode 100644 index 000000000..e99a8d674 --- /dev/null +++ b/benchmarks/workloads/ab/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "ab", + srcs = ["__init__.py"], +) + +py_test( + name = "ab_test", + srcs = ["ab_test.py"], + python_version = "PY3", + deps = [ + ":ab", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/ab/Dockerfile b/benchmarks/workloads/ab/Dockerfile new file mode 100644 index 000000000..0d0b6e2eb --- /dev/null +++ b/benchmarks/workloads/ab/Dockerfile @@ -0,0 +1,15 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2-utils \ + && rm -rf /var/lib/apt/lists/* + +# Parameterized workload. +ENV requests 5000 +ENV connections 10 +ENV host localhost +ENV port 8080 +ENV path notfound +CMD ["sh", "-c", "ab -n ${requests} -c ${connections} http://${host}:${port}/${path}"] diff --git a/benchmarks/workloads/ab/__init__.py b/benchmarks/workloads/ab/__init__.py new file mode 100644 index 000000000..eedf8e083 --- /dev/null +++ b/benchmarks/workloads/ab/__init__.py @@ -0,0 +1,88 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Apachebench tool.""" + +import re + +SAMPLE_DATA = """This is ApacheBench, Version 2.3 <$Revision: 1826891 $> +Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ +Licensed to The Apache Software Foundation, http://www.apache.org/ + +Benchmarking 10.10.10.10 (be patient).....done + + +Server Software: Apache/2.4.38 +Server Hostname: 10.10.10.10 +Server Port: 80 + +Document Path: /latin10k.txt +Document Length: 210 bytes + +Concurrency Level: 1 +Time taken for tests: 0.180 seconds +Complete requests: 100 +Failed requests: 0 +Non-2xx responses: 100 +Total transferred: 38800 bytes +HTML transferred: 21000 bytes +Requests per second: 556.44 [#/sec] (mean) +Time per request: 1.797 [ms] (mean) +Time per request: 1.797 [ms] (mean, across all concurrent requests) +Transfer rate: 210.84 [Kbytes/sec] received + +Connection Times (ms) + min mean[+/-sd] median max +Connect: 0 0 0.2 0 2 +Processing: 1 2 1.0 1 8 +Waiting: 1 1 1.0 1 7 +Total: 1 2 1.2 1 10 + +Percentage of the requests served within a certain time (ms) + 50% 1 + 66% 2 + 75% 2 + 80% 2 + 90% 2 + 95% 3 + 98% 7 + 99% 10 + 100% 10 (longest request)""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def transfer_rate(data: str, **kwargs) -> float: + """Mean transfer rate in Kbytes/sec.""" + regex = r"Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received" + return float(re.compile(regex).search(data).group(1)) + + +# pylint: disable=unused-argument +def latency(data: str, **kwargs) -> float: + """Mean latency in milliseconds.""" + regex = r"Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s" + res = re.compile(regex).search(data) + return float(res.group(1)) + + +# pylint: disable=unused-argument +def requests_per_second(data: str, **kwargs) -> float: + """Requests per second.""" + regex = r"Requests per second:\s+(\d+\.?\d+?)\s+" + res = re.compile(regex).search(data) + return float(res.group(1)) diff --git a/benchmarks/workloads/ab/ab_test.py b/benchmarks/workloads/ab/ab_test.py new file mode 100644 index 000000000..4afac2996 --- /dev/null +++ b/benchmarks/workloads/ab/ab_test.py @@ -0,0 +1,42 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser test.""" + +import sys + +import pytest + +from benchmarks.workloads import ab + + +def test_transfer_rate_parser(): + """Test transfer rate parser.""" + res = ab.transfer_rate(ab.sample()) + assert res == 210.84 + + +def test_latency_parser(): + """Test latency parser.""" + res = ab.latency(ab.sample()) + assert res == 2 + + +def test_requests_per_second(): + """Test requests per second parser.""" + res = ab.requests_per_second(ab.sample()) + assert res == 556.44 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD new file mode 100644 index 000000000..bb499620e --- /dev/null +++ b/benchmarks/workloads/absl/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "absl", + srcs = ["__init__.py"], +) + +py_test( + name = "absl_test", + srcs = ["absl_test.py"], + python_version = "PY3", + deps = [ + ":absl", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/absl/Dockerfile b/benchmarks/workloads/absl/Dockerfile new file mode 100644 index 000000000..e935c5ddc --- /dev/null +++ b/benchmarks/workloads/absl/Dockerfile @@ -0,0 +1,24 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + git \ + pkg-config \ + zip \ + g++ \ + zlib1g-dev \ + unzip \ + python3 \ + && rm -rf /var/lib/apt/lists/* +RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh +RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh +RUN ./bazel-0.27.0-installer-linux-x86_64.sh + +RUN git clone https://github.com/abseil/abseil-cpp.git +WORKDIR abseil-cpp +RUN git checkout 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f +RUN bazel clean +ENV path "absl/base/..." +CMD bazel build ${path} 2>&1 diff --git a/benchmarks/workloads/absl/__init__.py b/benchmarks/workloads/absl/__init__.py new file mode 100644 index 000000000..b40e3f915 --- /dev/null +++ b/benchmarks/workloads/absl/__init__.py @@ -0,0 +1,63 @@ +# python3 +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ABSL build benchmark.""" + +import re + +SAMPLE_BAZEL_OUTPUT = """Extracting Bazel installation... +Starting local Bazel server and connecting to it... +Loading: +Loading: 0 packages loaded +Loading: 0 packages loaded + currently loading: absl/algorithm ... (11 packages) +Analyzing: 241 targets (16 packages loaded, 0 targets configured) +Analyzing: 241 targets (21 packages loaded, 617 targets configured) +Analyzing: 241 targets (27 packages loaded, 687 targets configured) +Analyzing: 241 targets (32 packages loaded, 1105 targets configured) +Analyzing: 241 targets (32 packages loaded, 1294 targets configured) +Analyzing: 241 targets (35 packages loaded, 1575 targets configured) +Analyzing: 241 targets (35 packages loaded, 1575 targets configured) +Analyzing: 241 targets (36 packages loaded, 1603 targets configured) +Analyzing: 241 targets (36 packages loaded, 1603 targets configured) +INFO: Analyzed 241 targets (37 packages loaded, 1864 targets configured). +INFO: Found 241 targets... +[0 / 5] [Prepa] BazelWorkspaceStatusAction stable-status.txt +[16 / 50] [Analy] Compiling absl/base/dynamic_annotations.cc ... (20 actions, 10 running) +[60 / 77] Compiling external/com_google_googletest/googletest/src/gtest.cc; 5s processwrapper-sandbox ... (12 actions, 11 running) +[158 / 174] Compiling absl/container/internal/raw_hash_set_test.cc; 2s processwrapper-sandbox ... (12 actions, 11 running) +[278 / 302] Compiling absl/container/internal/raw_hash_set_test.cc; 6s processwrapper-sandbox ... (12 actions, 11 running) +[384 / 406] Compiling absl/container/internal/raw_hash_set_test.cc; 10s processwrapper-sandbox ... (12 actions, 11 running) +[581 / 604] Compiling absl/container/flat_hash_set_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) +[722 / 745] Compiling absl/container/node_hash_set_test.cc; 9s processwrapper-sandbox ... (12 actions, 11 running) +[846 / 867] Compiling absl/hash/hash_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) +INFO: From Compiling absl/debugging/symbolize_test.cc: +/tmp/cclCVipU.s: Assembler messages: +/tmp/cclCVipU.s:1662: Warning: ignoring changed section attributes for .text +[999 / 1,022] Compiling absl/hash/hash_test.cc; 19s processwrapper-sandbox ... (12 actions, 11 running) +[1,082 / 1,084] Compiling absl/container/flat_hash_map_test.cc; 7s processwrapper-sandbox +INFO: Elapsed time: 81.861s, Critical Path: 23.81s +INFO: 515 processes: 515 processwrapper-sandbox. +INFO: Build completed successfully, 1084 total actions +INFO: Build completed successfully, 1084 total actions""" + + +def sample(): + return SAMPLE_BAZEL_OUTPUT + + +# pylint: disable=unused-argument +def elapsed_time(data: str, **kwargs) -> float: + """Returns the elapsed time for running an absl build.""" + return float(re.compile(r"Elapsed time: (\d*.?\d*)s").search(data).group(1)) diff --git a/benchmarks/workloads/absl/absl_test.py b/benchmarks/workloads/absl/absl_test.py new file mode 100644 index 000000000..41f216999 --- /dev/null +++ b/benchmarks/workloads/absl/absl_test.py @@ -0,0 +1,31 @@ +# python3 +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ABSL build test.""" + +import sys + +import pytest + +from benchmarks.workloads import absl + + +def test_elapsed_time(): + """Test elapsed_time.""" + res = absl.elapsed_time(absl.sample()) + assert res == 81.861 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/curl/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/curl/Dockerfile b/benchmarks/workloads/curl/Dockerfile new file mode 100644 index 000000000..336cb088a --- /dev/null +++ b/benchmarks/workloads/curl/Dockerfile @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Accept a host and port parameter. +ENV host localhost +ENV port 8080 + +# Spin until we make a successful request. +CMD ["sh", "-c", "while ! curl -v -i http://$host:$port; do true; done"] diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD new file mode 100644 index 000000000..c1f2afc40 --- /dev/null +++ b/benchmarks/workloads/ffmpeg/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "ffmpeg", + srcs = ["__init__.py"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/ffmpeg/Dockerfile b/benchmarks/workloads/ffmpeg/Dockerfile new file mode 100644 index 000000000..f2f530d7c --- /dev/null +++ b/benchmarks/workloads/ffmpeg/Dockerfile @@ -0,0 +1,10 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* +WORKDIR /media +ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4 +CMD ["ffmpeg", "-i", "video.mp4", "-c:v", "libx264", "-preset", "veryslow", "output.mp4"] diff --git a/benchmarks/workloads/ffmpeg/__init__.py b/benchmarks/workloads/ffmpeg/__init__.py new file mode 100644 index 000000000..7578a443b --- /dev/null +++ b/benchmarks/workloads/ffmpeg/__init__.py @@ -0,0 +1,20 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple ffmpeg workload.""" + + +# pylint: disable=unused-argument +def run_time(value, **kwargs): + """Returns the startup and runtime of the ffmpeg workload in seconds.""" + return value diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD new file mode 100644 index 000000000..7fc96cfa5 --- /dev/null +++ b/benchmarks/workloads/fio/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "fio", + srcs = ["__init__.py"], +) + +py_test( + name = "fio_test", + srcs = ["fio_test.py"], + python_version = "PY3", + deps = [ + ":fio", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/fio/Dockerfile b/benchmarks/workloads/fio/Dockerfile new file mode 100644 index 000000000..b3cf864eb --- /dev/null +++ b/benchmarks/workloads/fio/Dockerfile @@ -0,0 +1,23 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + fio \ + && rm -rf /var/lib/apt/lists/* + +# Parameterized test. +ENV test write +ENV ioengine sync +ENV size 5000000 +ENV iodepth 4 +ENV blocksize "1m" +ENV time "" +ENV path "/disk/file.dat" +ENV ramp_time 0 + +CMD ["sh", "-c", "fio --output-format=json --name=test --ramp_time=${ramp_time} --ioengine=${ioengine} --size=${size} \ +--filename=${path} --iodepth=${iodepth} --bs=${blocksize} --rw=${test} ${time}"] + + + diff --git a/benchmarks/workloads/fio/__init__.py b/benchmarks/workloads/fio/__init__.py new file mode 100644 index 000000000..52711e956 --- /dev/null +++ b/benchmarks/workloads/fio/__init__.py @@ -0,0 +1,369 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FIO benchmark tool.""" + +import json + +SAMPLE_DATA = """ +{ + "fio version" : "fio-3.1", + "timestamp" : 1554837456, + "timestamp_ms" : 1554837456621, + "time" : "Tue Apr 9 19:17:36 2019", + "jobs" : [ + { + "jobname" : "test", + "groupid" : 0, + "error" : 0, + "eta" : 2147483647, + "elapsed" : 1, + "job options" : { + "name" : "test", + "ioengine" : "sync", + "size" : "1073741824", + "filename" : "/disk/file.dat", + "iodepth" : "4", + "bs" : "4096", + "rw" : "write" + }, + "read" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 0, + "iops" : 0.000000, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "drop_ios" : 0, + "slat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "clat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000, + "percentile" : { + "1.000000" : 0, + "5.000000" : 0, + "10.000000" : 0, + "20.000000" : 0, + "30.000000" : 0, + "40.000000" : 0, + "50.000000" : 0, + "60.000000" : 0, + "70.000000" : 0, + "80.000000" : 0, + "90.000000" : 0, + "95.000000" : 0, + "99.000000" : 0, + "99.500000" : 0, + "99.900000" : 0, + "99.950000" : 0, + "99.990000" : 0, + "0.00" : 0, + "0.00" : 0, + "0.00" : 0 + } + }, + "lat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "write" : { + "io_bytes" : 1073741824, + "io_kbytes" : 1048576, + "bw" : 1753471, + "iops" : 438367.892977, + "runtime" : 598, + "total_ios" : 262144, + "short_ios" : 0, + "drop_ios" : 0, + "slat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "clat_ns" : { + "min" : 1693, + "max" : 754733, + "mean" : 2076.404373, + "stddev" : 1724.195529, + "percentile" : { + "1.000000" : 1736, + "5.000000" : 1752, + "10.000000" : 1768, + "20.000000" : 1784, + "30.000000" : 1800, + "40.000000" : 1800, + "50.000000" : 1816, + "60.000000" : 1816, + "70.000000" : 1848, + "80.000000" : 1928, + "90.000000" : 2512, + "95.000000" : 2992, + "99.000000" : 6176, + "99.500000" : 6304, + "99.900000" : 11328, + "99.950000" : 15168, + "99.990000" : 17792, + "0.00" : 0, + "0.00" : 0, + "0.00" : 0 + } + }, + "lat_ns" : { + "min" : 1731, + "max" : 754770, + "mean" : 2117.878979, + "stddev" : 1730.290512 + }, + "bw_min" : 1731120, + "bw_max" : 1731120, + "bw_agg" : 98.725328, + "bw_mean" : 1731120.000000, + "bw_dev" : 0.000000, + "bw_samples" : 1, + "iops_min" : 432780, + "iops_max" : 432780, + "iops_mean" : 432780.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 1 + }, + "trim" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 0, + "iops" : 0.000000, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "drop_ios" : 0, + "slat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "clat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000, + "percentile" : { + "1.000000" : 0, + "5.000000" : 0, + "10.000000" : 0, + "20.000000" : 0, + "30.000000" : 0, + "40.000000" : 0, + "50.000000" : 0, + "60.000000" : 0, + "70.000000" : 0, + "80.000000" : 0, + "90.000000" : 0, + "95.000000" : 0, + "99.000000" : 0, + "99.500000" : 0, + "99.900000" : 0, + "99.950000" : 0, + "99.990000" : 0, + "0.00" : 0, + "0.00" : 0, + "0.00" : 0 + } + }, + "lat_ns" : { + "min" : 0, + "max" : 0, + "mean" : 0.000000, + "stddev" : 0.000000 + }, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "usr_cpu" : 17.922948, + "sys_cpu" : 81.574539, + "ctx" : 3, + "majf" : 0, + "minf" : 10, + "iodepth_level" : { + "1" : 100.000000, + "2" : 0.000000, + "4" : 0.000000, + "8" : 0.000000, + "16" : 0.000000, + "32" : 0.000000, + ">=64" : 0.000000 + }, + "latency_ns" : { + "2" : 0.000000, + "4" : 0.000000, + "10" : 0.000000, + "20" : 0.000000, + "50" : 0.000000, + "100" : 0.000000, + "250" : 0.000000, + "500" : 0.000000, + "750" : 0.000000, + "1000" : 0.000000 + }, + "latency_us" : { + "2" : 82.737350, + "4" : 12.605286, + "10" : 4.543686, + "20" : 0.107956, + "50" : 0.010000, + "100" : 0.000000, + "250" : 0.000000, + "500" : 0.000000, + "750" : 0.000000, + "1000" : 0.010000 + }, + "latency_ms" : { + "2" : 0.000000, + "4" : 0.000000, + "10" : 0.000000, + "20" : 0.000000, + "50" : 0.000000, + "100" : 0.000000, + "250" : 0.000000, + "500" : 0.000000, + "750" : 0.000000, + "1000" : 0.000000, + "2000" : 0.000000, + ">=2000" : 0.000000 + }, + "latency_depth" : 4, + "latency_target" : 0, + "latency_percentile" : 100.000000, + "latency_window" : 0 + } + ], + "disk_util" : [ + { + "name" : "dm-1", + "read_ios" : 0, + "write_ios" : 3, + "read_merges" : 0, + "write_merges" : 0, + "read_ticks" : 0, + "write_ticks" : 0, + "in_queue" : 0, + "util" : 0.000000, + "aggr_read_ios" : 0, + "aggr_write_ios" : 3, + "aggr_read_merges" : 0, + "aggr_write_merge" : 0, + "aggr_read_ticks" : 0, + "aggr_write_ticks" : 0, + "aggr_in_queue" : 0, + "aggr_util" : 0.000000 + }, + { + "name" : "dm-0", + "read_ios" : 0, + "write_ios" : 3, + "read_merges" : 0, + "write_merges" : 0, + "read_ticks" : 0, + "write_ticks" : 0, + "in_queue" : 0, + "util" : 0.000000, + "aggr_read_ios" : 0, + "aggr_write_ios" : 3, + "aggr_read_merges" : 0, + "aggr_write_merge" : 0, + "aggr_read_ticks" : 0, + "aggr_write_ticks" : 2, + "aggr_in_queue" : 0, + "aggr_util" : 0.000000 + }, + { + "name" : "nvme0n1", + "read_ios" : 0, + "write_ios" : 3, + "read_merges" : 0, + "write_merges" : 0, + "read_ticks" : 0, + "write_ticks" : 2, + "in_queue" : 0, + "util" : 0.000000 + } + ] +} +""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def read_bandwidth(data: str, **kwargs) -> int: + """File I/O bandwidth.""" + return json.loads(data)["jobs"][0]["read"]["bw"] * 1024 + + +# pylint: disable=unused-argument +def write_bandwidth(data: str, **kwargs) -> int: + """File I/O bandwidth.""" + return json.loads(data)["jobs"][0]["write"]["bw"] * 1024 + + +# pylint: disable=unused-argument +def read_io_ops(data: str, **kwargs) -> float: + """File I/O operations per second.""" + return float(json.loads(data)["jobs"][0]["read"]["iops"]) + + +# pylint: disable=unused-argument +def write_io_ops(data: str, **kwargs) -> float: + """File I/O operations per second.""" + return float(json.loads(data)["jobs"][0]["write"]["iops"]) + + +# Change function names so we just print "bandwidth" and "io_ops". +read_bandwidth.__name__ = "bandwidth" +write_bandwidth.__name__ = "bandwidth" +read_io_ops.__name__ = "io_ops" +write_io_ops.__name__ = "io_ops" diff --git a/benchmarks/workloads/fio/fio_test.py b/benchmarks/workloads/fio/fio_test.py new file mode 100644 index 000000000..04a6eeb7e --- /dev/null +++ b/benchmarks/workloads/fio/fio_test.py @@ -0,0 +1,44 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser tests.""" + +import sys + +import pytest + +from benchmarks.workloads import fio + + +def test_read_io_ops(): + """Test read ops parser.""" + assert fio.read_io_ops(fio.sample()) == 0.0 + + +def test_write_io_ops(): + """Test write ops parser.""" + assert fio.write_io_ops(fio.sample()) == 438367.892977 + + +def test_read_bandwidth(): + """Test read bandwidth parser.""" + assert fio.read_bandwidth(fio.sample()) == 0.0 + + +def test_write_bandwith(): + """Test write bandwidth parser.""" + assert fio.write_bandwidth(fio.sample()) == 1753471 * 1024 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/httpd/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/httpd/Dockerfile b/benchmarks/workloads/httpd/Dockerfile new file mode 100644 index 000000000..5259c8f4f --- /dev/null +++ b/benchmarks/workloads/httpd/Dockerfile @@ -0,0 +1,27 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2 \ + && rm -rf /var/lib/apt/lists/* + +# Link the htdoc directory to tmp. +RUN mkdir -p /usr/local/apache2/htdocs && \ + cd /usr/local/apache2 && ln -s /tmp htdocs + +# Generate a bunch of relevant files. +RUN mkdir -p /local && \ + for size in 1 10 100 1000 1024 10240; do \ + dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ + done + +# Standard settings. +ENV APACHE_RUN_DIR /tmp +ENV APACHE_RUN_USER nobody +ENV APACHE_RUN_GROUP nogroup +ENV APACHE_LOG_DIR /tmp +ENV APACHE_PID_FILE /tmp/apache.pid + +# Copy on start-up; serve everything from /tmp (including the configuration). +CMD ["sh", "-c", "cp -a /local/* /tmp && apache2 -c \"ServerName localhost\" -c \"DocumentRoot /tmp\" -X"] diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD new file mode 100644 index 000000000..fe0acbfce --- /dev/null +++ b/benchmarks/workloads/iperf/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "iperf", + srcs = ["__init__.py"], +) + +py_test( + name = "iperf_test", + srcs = ["iperf_test.py"], + python_version = "PY3", + deps = [ + ":iperf", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/iperf/Dockerfile b/benchmarks/workloads/iperf/Dockerfile new file mode 100644 index 000000000..9704c506c --- /dev/null +++ b/benchmarks/workloads/iperf/Dockerfile @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + iperf \ + && rm -rf /var/lib/apt/lists/* + +# Accept a host parameter. +ENV host "" +ENV port 5001 + +# Start as client if the host is provided. +CMD ["sh", "-c", "test -z \"${host}\" && iperf -s || iperf -f K --realtime -c ${host} -p ${port}"] diff --git a/benchmarks/workloads/iperf/__init__.py b/benchmarks/workloads/iperf/__init__.py new file mode 100644 index 000000000..3817a7ade --- /dev/null +++ b/benchmarks/workloads/iperf/__init__.py @@ -0,0 +1,40 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""iperf.""" + +import re + +SAMPLE_DATA = """ +------------------------------------------------------------ +Client connecting to 10.138.15.215, TCP port 32779 +TCP window size: 45.0 KByte (default) +------------------------------------------------------------ +[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 +[ ID] Interval Transfer Bandwidth +[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec + +""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def bandwidth(data: str, **kwargs) -> float: + """Calculate the bandwidth.""" + regex = r"\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec" + res = re.compile(regex).search(data) + return float(res.group(1)) * 1000 diff --git a/benchmarks/workloads/iperf/iperf_test.py b/benchmarks/workloads/iperf/iperf_test.py new file mode 100644 index 000000000..6959b7e8a --- /dev/null +++ b/benchmarks/workloads/iperf/iperf_test.py @@ -0,0 +1,28 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for iperf.""" + +import sys + +import pytest + +from benchmarks.workloads import iperf + + +def test_bandwidth(): + assert iperf.bandwidth(iperf.sample()) == 45900 * 1000 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/netcat/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/netcat/Dockerfile b/benchmarks/workloads/netcat/Dockerfile new file mode 100644 index 000000000..d8548d89a --- /dev/null +++ b/benchmarks/workloads/netcat/Dockerfile @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + netcat \ + && rm -rf /var/lib/apt/lists/* + +# Accept a host and port parameter. +ENV host localhost +ENV port 8080 + +# Spin until we make a successful request. +CMD ["sh", "-c", "while ! nc -zv $host $port; do true; done"] diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/nginx/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/nginx/Dockerfile b/benchmarks/workloads/nginx/Dockerfile new file mode 100644 index 000000000..b64eb52ae --- /dev/null +++ b/benchmarks/workloads/nginx/Dockerfile @@ -0,0 +1 @@ +FROM nginx:1.15.10 diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD new file mode 100644 index 000000000..59460d02f --- /dev/null +++ b/benchmarks/workloads/node/BUILD @@ -0,0 +1,13 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "index.js", + "package.json", + ], +) diff --git a/benchmarks/workloads/node/Dockerfile b/benchmarks/workloads/node/Dockerfile new file mode 100644 index 000000000..139a38bf5 --- /dev/null +++ b/benchmarks/workloads/node/Dockerfile @@ -0,0 +1,2 @@ +FROM node:onbuild +CMD ["node", "index.js"] diff --git a/benchmarks/workloads/node/index.js b/benchmarks/workloads/node/index.js new file mode 100644 index 000000000..584158462 --- /dev/null +++ b/benchmarks/workloads/node/index.js @@ -0,0 +1,28 @@ +'use strict'; + +var start = new Date().getTime(); + +// Load dependencies to simulate an average nodejs app. +var req_0 = require('async'); +var req_1 = require('bluebird'); +var req_2 = require('firebase'); +var req_3 = require('firebase-admin'); +var req_4 = require('@google-cloud/container'); +var req_5 = require('@google-cloud/logging'); +var req_6 = require('@google-cloud/monitoring'); +var req_7 = require('@google-cloud/spanner'); +var req_8 = require('lodash'); +var req_9 = require('mailgun-js'); +var req_10 = require('request'); +var express = require('express'); +var app = express(); + +var loaded = new Date().getTime() - start; +app.get('/', function(req, res) { + res.send('Hello World!<br>Loaded in ' + loaded + 'ms'); +}); + +console.log('Loaded in ' + loaded + ' ms'); +app.listen(8080, function() { + console.log('Listening on port 8080...'); +}); diff --git a/benchmarks/workloads/node/package.json b/benchmarks/workloads/node/package.json new file mode 100644 index 000000000..c00b9b3cb --- /dev/null +++ b/benchmarks/workloads/node/package.json @@ -0,0 +1,19 @@ +{ + "name": "node", + "version": "1.0.0", + "main": "index.js", + "dependencies": { + "@google-cloud/container": "^0.3.0", + "@google-cloud/logging": "^4.2.0", + "@google-cloud/monitoring": "^0.6.0", + "@google-cloud/spanner": "^2.2.1", + "async": "^2.6.1", + "bluebird": "^3.5.3", + "express": "^4.16.4", + "firebase": "^5.7.2", + "firebase-admin": "^6.4.0", + "lodash": "^4.17.11", + "mailgun-js": "^0.22.0", + "request": "^2.88.0" + } +} diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD new file mode 100644 index 000000000..ae7f121d3 --- /dev/null +++ b/benchmarks/workloads/node_template/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "index.hbs", + "index.js", + "package.json", + "package-lock.json", + ], +) diff --git a/benchmarks/workloads/node_template/Dockerfile b/benchmarks/workloads/node_template/Dockerfile new file mode 100644 index 000000000..7eb065d54 --- /dev/null +++ b/benchmarks/workloads/node_template/Dockerfile @@ -0,0 +1,5 @@ +FROM node:onbuild + +ENV host "127.0.0.1" + +CMD ["sh", "-c", "node index.js ${host}"] diff --git a/benchmarks/workloads/node_template/index.hbs b/benchmarks/workloads/node_template/index.hbs new file mode 100644 index 000000000..03feceb75 --- /dev/null +++ b/benchmarks/workloads/node_template/index.hbs @@ -0,0 +1,8 @@ +<!DOCTYPE html> +<html> +<body> + {{#each text}} + <p>{{this}}</p> + {{/each}} +</body> +</html> diff --git a/benchmarks/workloads/node_template/index.js b/benchmarks/workloads/node_template/index.js new file mode 100644 index 000000000..04a27f356 --- /dev/null +++ b/benchmarks/workloads/node_template/index.js @@ -0,0 +1,43 @@ +const app = require('express')(); +const path = require('path'); +const redis = require('redis'); +const srs = require('secure-random-string'); + +// The hostname is the first argument. +const host_name = process.argv[2]; + +var client = redis.createClient({host: host_name, detect_buffers: true}); + +app.set('views', __dirname); +app.set('view engine', 'hbs'); + +app.get('/', (req, res) => { + var tmp = []; + /* Pull four random keys from the redis server. */ + for (i = 0; i < 4; i++) { + client.get(Math.floor(Math.random() * (100)), function(err, reply) { + tmp.push(reply.toString()); + }); + } + + res.render('index', {text: tmp}); +}); + +/** + * Securely generate a random string. + * @param {number} len + * @return {string} + */ +function randomBody(len) { + return srs({alphanumeric: true, length: len}); +} + +/** Mutates one hundred keys randomly. */ +function generateText() { + for (i = 0; i < 100; i++) { + client.set(i, randomBody(1024)); + } +} + +generateText(); +app.listen(8080); diff --git a/benchmarks/workloads/node_template/package-lock.json b/benchmarks/workloads/node_template/package-lock.json new file mode 100644 index 000000000..580e68aa5 --- /dev/null +++ b/benchmarks/workloads/node_template/package-lock.json @@ -0,0 +1,486 @@ +{ + "name": "nodedum", + "version": "1.0.0", + "lockfileVersion": 1, + "requires": true, + "dependencies": { + "accepts": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz", + "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=", + "requires": { + "mime-types": "~2.1.18", + "negotiator": "0.6.1" + } + }, + "array-flatten": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI=" + }, + "async": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz", + "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==", + "requires": { + "lodash": "^4.17.11" + } + }, + "body-parser": { + "version": "1.18.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz", + "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=", + "requires": { + "bytes": "3.0.0", + "content-type": "~1.0.4", + "debug": "2.6.9", + "depd": "~1.1.2", + "http-errors": "~1.6.3", + "iconv-lite": "0.4.23", + "on-finished": "~2.3.0", + "qs": "6.5.2", + "raw-body": "2.3.3", + "type-is": "~1.6.16" + } + }, + "bytes": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", + "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=" + }, + "commander": { + "version": "2.20.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz", + "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==", + "optional": true + }, + "content-disposition": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz", + "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ=" + }, + "content-type": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz", + "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==" + }, + "cookie": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz", + "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s=" + }, + "cookie-signature": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", + "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw=" + }, + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "requires": { + "ms": "2.0.0" + } + }, + "depd": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", + "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=" + }, + "destroy": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz", + "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=" + }, + "double-ended-queue": { + "version": "2.1.0-0", + "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz", + "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw=" + }, + "ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=" + }, + "encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=" + }, + "escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=" + }, + "etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc=" + }, + "express": { + "version": "4.16.4", + "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz", + "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==", + "requires": { + "accepts": "~1.3.5", + "array-flatten": "1.1.1", + "body-parser": "1.18.3", + "content-disposition": "0.5.2", + "content-type": "~1.0.4", + "cookie": "0.3.1", + "cookie-signature": "1.0.6", + "debug": "2.6.9", + "depd": "~1.1.2", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "finalhandler": "1.1.1", + "fresh": "0.5.2", + "merge-descriptors": "1.0.1", + "methods": "~1.1.2", + "on-finished": "~2.3.0", + "parseurl": "~1.3.2", + "path-to-regexp": "0.1.7", + "proxy-addr": "~2.0.4", + "qs": "6.5.2", + "range-parser": "~1.2.0", + "safe-buffer": "5.1.2", + "send": "0.16.2", + "serve-static": "1.13.2", + "setprototypeof": "1.1.0", + "statuses": "~1.4.0", + "type-is": "~1.6.16", + "utils-merge": "1.0.1", + "vary": "~1.1.2" + } + }, + "finalhandler": { + "version": "1.1.1", + "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz", + "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==", + "requires": { + "debug": "2.6.9", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "on-finished": "~2.3.0", + "parseurl": "~1.3.2", + "statuses": "~1.4.0", + "unpipe": "~1.0.0" + } + }, + "foreachasync": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz", + "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY=" + }, + "forwarded": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz", + "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ=" + }, + "fresh": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", + "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=" + }, + "handlebars": { + "version": "4.0.14", + "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz", + "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==", + "requires": { + "async": "^2.5.0", + "optimist": "^0.6.1", + "source-map": "^0.6.1", + "uglify-js": "^3.1.4" + } + }, + "hbs": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz", + "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==", + "requires": { + "handlebars": "4.0.14", + "walk": "2.3.9" + } + }, + "http-errors": { + "version": "1.6.3", + "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", + "requires": { + "depd": "~1.1.2", + "inherits": "2.0.3", + "setprototypeof": "1.1.0", + "statuses": ">= 1.4.0 < 2" + } + }, + "iconv-lite": { + "version": "0.4.23", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz", + "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==", + "requires": { + "safer-buffer": ">= 2.1.2 < 3" + } + }, + "inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=" + }, + "ipaddr.js": { + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz", + "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4=" + }, + "lodash": { + "version": "4.17.15", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz", + "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A==" + }, + "media-typer": { + "version": "0.3.0", + "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", + "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=" + }, + "merge-descriptors": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", + "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E=" + }, + "methods": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", + "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=" + }, + "mime": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz", + "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ==" + }, + "mime-db": { + "version": "1.37.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz", + "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg==" + }, + "mime-types": { + "version": "2.1.21", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz", + "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==", + "requires": { + "mime-db": "~1.37.0" + } + }, + "minimist": { + "version": "0.0.10", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz", + "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8=" + }, + "ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" + }, + "negotiator": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz", + "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk=" + }, + "on-finished": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz", + "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=", + "requires": { + "ee-first": "1.1.1" + } + }, + "optimist": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz", + "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=", + "requires": { + "minimist": "~0.0.1", + "wordwrap": "~0.0.2" + } + }, + "parseurl": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz", + "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M=" + }, + "path-to-regexp": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", + "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w=" + }, + "proxy-addr": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz", + "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==", + "requires": { + "forwarded": "~0.1.2", + "ipaddr.js": "1.8.0" + } + }, + "qs": { + "version": "6.5.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz", + "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA==" + }, + "range-parser": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz", + "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4=" + }, + "raw-body": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz", + "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==", + "requires": { + "bytes": "3.0.0", + "http-errors": "1.6.3", + "iconv-lite": "0.4.23", + "unpipe": "1.0.0" + } + }, + "redis": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz", + "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==", + "requires": { + "double-ended-queue": "^2.1.0-0", + "redis-commands": "^1.2.0", + "redis-parser": "^2.6.0" + }, + "dependencies": { + "redis-commands": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz", + "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw==" + }, + "redis-parser": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", + "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" + } + } + }, + "redis-commands": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz", + "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg==" + }, + "redis-parser": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", + "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" + }, + "safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" + }, + "safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" + }, + "secure-random-string": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz", + "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg==" + }, + "send": { + "version": "0.16.2", + "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz", + "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==", + "requires": { + "debug": "2.6.9", + "depd": "~1.1.2", + "destroy": "~1.0.4", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "fresh": "0.5.2", + "http-errors": "~1.6.2", + "mime": "1.4.1", + "ms": "2.0.0", + "on-finished": "~2.3.0", + "range-parser": "~1.2.0", + "statuses": "~1.4.0" + } + }, + "serve-static": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz", + "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==", + "requires": { + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "parseurl": "~1.3.2", + "send": "0.16.2" + } + }, + "setprototypeof": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz", + "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==" + }, + "source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==" + }, + "statuses": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz", + "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew==" + }, + "type-is": { + "version": "1.6.16", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz", + "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==", + "requires": { + "media-typer": "0.3.0", + "mime-types": "~2.1.18" + } + }, + "uglify-js": { + "version": "3.5.9", + "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz", + "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==", + "optional": true, + "requires": { + "commander": "~2.20.0", + "source-map": "~0.6.1" + } + }, + "unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw=" + }, + "utils-merge": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", + "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM=" + }, + "vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=" + }, + "walk": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz", + "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=", + "requires": { + "foreachasync": "^3.0.0" + } + }, + "wordwrap": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz", + "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc=" + } + } +} diff --git a/benchmarks/workloads/node_template/package.json b/benchmarks/workloads/node_template/package.json new file mode 100644 index 000000000..7dcadd523 --- /dev/null +++ b/benchmarks/workloads/node_template/package.json @@ -0,0 +1,19 @@ +{ + "name": "nodedum", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "ISC", + "dependencies": { + "express": "^4.16.4", + "hbs": "^4.0.4", + "redis": "^2.8.0", + "redis-commands": "^1.2.0", + "redis-parser": "^2.6.0", + "secure-random-string": "^1.1.0" + } +} diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/redis/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/redis/Dockerfile b/benchmarks/workloads/redis/Dockerfile new file mode 100644 index 000000000..0f17249af --- /dev/null +++ b/benchmarks/workloads/redis/Dockerfile @@ -0,0 +1 @@ +FROM redis:5.0.4 diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD new file mode 100644 index 000000000..d40e75a3a --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "redisbenchmark", + srcs = ["__init__.py"], +) + +py_test( + name = "redisbenchmark_test", + srcs = ["redisbenchmark_test.py"], + python_version = "PY3", + deps = [ + ":redisbenchmark", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/redisbenchmark/Dockerfile b/benchmarks/workloads/redisbenchmark/Dockerfile new file mode 100644 index 000000000..f94f6442e --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/Dockerfile @@ -0,0 +1,4 @@ +FROM redis:5.0.4 +ENV host localhost +ENV port 6379 +CMD ["sh", "-c", "redis-benchmark --csv -h ${host} -p ${port} ${flags}"] diff --git a/benchmarks/workloads/redisbenchmark/__init__.py b/benchmarks/workloads/redisbenchmark/__init__.py new file mode 100644 index 000000000..229cef5fa --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/__init__.py @@ -0,0 +1,85 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Redis-benchmark tool.""" + +import re + +OPERATIONS = [ + "PING_INLINE", + "PING_BULK", + "SET", + "GET", + "INCR", + "LPUSH", + "RPUSH", + "LPOP", + "RPOP", + "SADD", + "HSET", + "SPOP", + "LRANGE_100", + "LRANGE_300", + "LRANGE_500", + "LRANGE_600", + "MSET", +] + +METRICS = dict() + +SAMPLE_DATA = """ +"PING_INLINE","48661.80" +"PING_BULK","50301.81" +"SET","48923.68" +"GET","49382.71" +"INCR","49975.02" +"LPUSH","49875.31" +"RPUSH","50276.52" +"LPOP","50327.12" +"RPOP","50556.12" +"SADD","49504.95" +"HSET","49504.95" +"SPOP","50025.02" +"LPUSH (needed to benchmark LRANGE)","48875.86" +"LRANGE_100 (first 100 elements)","33955.86" +"LRANGE_300 (first 300 elements)","16550.81" +"LRANGE_500 (first 450 elements)","13653.74" +"LRANGE_600 (first 600 elements)","11219.57" +"MSET (10 keys)","44682.75" +""" + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# Bind a metric for each operation noted above. +for op in OPERATIONS: + + def bind(metric): + """Bind op to a new scope.""" + + # pylint: disable=unused-argument + def parse(data: str, **kwargs) -> float: + """Operation throughput in requests/sec.""" + regex = r"\"" + metric + r"( .*)?\",\"(\d*.\d*)" + res = re.compile(regex).search(data) + if res: + return float(res.group(2)) + return 0.0 + + parse.__name__ = metric + return parse + + METRICS[op] = bind(op) diff --git a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py new file mode 100644 index 000000000..419ced059 --- /dev/null +++ b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py @@ -0,0 +1,51 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser test.""" + +import sys + +import pytest + +from benchmarks.workloads import redisbenchmark + +RESULTS = { + "PING_INLINE": 48661.80, + "PING_BULK": 50301.81, + "SET": 48923.68, + "GET": 49382.71, + "INCR": 49975.02, + "LPUSH": 49875.31, + "RPUSH": 50276.52, + "LPOP": 50327.12, + "RPOP": 50556.12, + "SADD": 49504.95, + "HSET": 49504.95, + "SPOP": 50025.02, + "LRANGE_100": 33955.86, + "LRANGE_300": 16550.81, + "LRANGE_500": 13653.74, + "LRANGE_600": 11219.57, + "MSET": 44682.75 +} + + +def test_metrics(): + """Test all metrics.""" + for (metric, func) in redisbenchmark.METRICS.items(): + res = func(redisbenchmark.sample()) + assert float(res) == RESULTS[metric] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD new file mode 100644 index 000000000..9846c7e70 --- /dev/null +++ b/benchmarks/workloads/ruby/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "Gemfile", + "Gemfile.lock", + "config.ru", + "index.rb", + ], +) diff --git a/benchmarks/workloads/ruby/Dockerfile b/benchmarks/workloads/ruby/Dockerfile new file mode 100644 index 000000000..a9a7a7086 --- /dev/null +++ b/benchmarks/workloads/ruby/Dockerfile @@ -0,0 +1,28 @@ +# example based on https://github.com/errm/fib + +FROM ruby:2.5 + +RUN apt-get update -qq && apt-get install -y build-essential libpq-dev nodejs libsodium-dev + +# Set an environment variable where the Rails app is installed to inside of Docker image +ENV RAILS_ROOT /var/www/app_name +RUN mkdir -p $RAILS_ROOT + +# Set working directory +WORKDIR $RAILS_ROOT + +# Setting env up +ENV RAILS_ENV='production' +ENV RACK_ENV='production' + +# Adding gems +COPY Gemfile Gemfile +COPY Gemfile.lock Gemfile.lock +RUN bundle install --jobs 20 --retry 5 --without development test + +# Adding project files +COPY . . + +EXPOSE $PORT +STOPSIGNAL SIGINT +CMD ["bundle", "exec", "puma", "config.ru"] diff --git a/benchmarks/workloads/ruby/Gemfile b/benchmarks/workloads/ruby/Gemfile new file mode 100644 index 000000000..8f1bdad6e --- /dev/null +++ b/benchmarks/workloads/ruby/Gemfile @@ -0,0 +1,12 @@ +source "https://rubygems.org" +# load a bunch of dependencies to take up memory +gem "sinatra" +gem "puma" +gem "redis" +gem 'rake' +gem 'squid', '~> 1.4' +gem 'cassandra-driver' +gem 'ruby-fann' +gem 'rbnacl' +gem 'bcrypt' +gem "activemerchant"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock new file mode 100644 index 000000000..b44817bd3 --- /dev/null +++ b/benchmarks/workloads/ruby/Gemfile.lock @@ -0,0 +1,55 @@ +GEM + remote: https://rubygems.org/ + specs: + activesupport (5.2.3) + concurrent-ruby (~> 1.0, >= 1.0.2) + i18n (>= 0.7, < 2) + minitest (~> 5.1) + tzinfo (~> 1.1) + cassandra-driver (3.2.3) + ione (~> 1.2) + concurrent-ruby (1.1.5) + i18n (1.6.0) + concurrent-ruby (~> 1.0) + ione (1.2.4) + minitest (5.11.3) + mustermann (1.0.3) + pdf-core (0.7.0) + prawn (2.2.2) + pdf-core (~> 0.7.0) + ttfunk (~> 1.5) + puma (3.12.1) + rack (2.0.7) + rack-protection (2.0.5) + rack + rake (12.3.2) + redis (4.1.1) + ruby-fann (1.2.6) + sinatra (2.0.5) + mustermann (~> 1.0) + rack (~> 2.0) + rack-protection (= 2.0.5) + tilt (~> 2.0) + squid (1.4.1) + activesupport (>= 4.0) + prawn (~> 2.2) + thread_safe (0.3.6) + tilt (2.0.9) + ttfunk (1.5.1) + tzinfo (1.2.5) + thread_safe (~> 0.1) + +PLATFORMS + ruby + +DEPENDENCIES + cassandra-driver + puma + rake + redis + ruby-fann + sinatra + squid (~> 1.4) + +BUNDLED WITH + 1.17.1 diff --git a/benchmarks/workloads/ruby/config.ru b/benchmarks/workloads/ruby/config.ru new file mode 100755 index 000000000..fbd5acc82 --- /dev/null +++ b/benchmarks/workloads/ruby/config.ru @@ -0,0 +1,2 @@ +require './index' +run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/index.rb b/benchmarks/workloads/ruby/index.rb new file mode 100755 index 000000000..5fa85af93 --- /dev/null +++ b/benchmarks/workloads/ruby/index.rb @@ -0,0 +1,14 @@ +require "sinatra" +require "puma" +require "redis" +require "rake" +require "squid" +require "cassandra" +require "ruby-fann" +require "rbnacl" +require "bcrypt" +require "activemerchant" + +get "/" do + "Hello World!" +end
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD new file mode 100644 index 000000000..2b99892af --- /dev/null +++ b/benchmarks/workloads/ruby_template/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "Gemfile", + "Gemfile.lock", + "config.ru", + "index.erb", + "main.rb", + ], +) diff --git a/benchmarks/workloads/ruby_template/Dockerfile b/benchmarks/workloads/ruby_template/Dockerfile new file mode 100755 index 000000000..a06d68bf4 --- /dev/null +++ b/benchmarks/workloads/ruby_template/Dockerfile @@ -0,0 +1,38 @@ +# example based on https://github.com/errm/fib + +FROM alpine:3.9 as build + +COPY Gemfile Gemfile.lock ./ + +RUN apk add --no-cache ruby ruby-dev ruby-bundler ruby-json build-base bash \ + && bundle install --frozen -j4 -r3 --no-cache --without development \ + && apk del --no-cache ruby-bundler \ + && rm -rf /usr/lib/ruby/gems/*/cache + +FROM alpine:3.9 as prod + +COPY --from=build /usr/lib/ruby/gems /usr/lib/ruby/gems +RUN apk add --no-cache ruby ruby-json ruby-etc redis apache2-utils \ + && ruby -e "Gem::Specification.map.each do |spec| \ + Gem::Installer.for_spec( \ + spec, \ + wrappers: true, \ + force: true, \ + install_dir: spec.base_dir, \ + build_args: spec.build_args, \ + ).generate_bin \ + end" + +WORKDIR /app +COPY . /app/. + +ENV PORT=9292 \ + WEB_CONCURRENCY=20 \ + WEB_MAX_THREADS=20 \ + RACK_ENV=production + +ENV host localhost +EXPOSE $PORT +USER nobody +STOPSIGNAL SIGINT +CMD ["sh", "-c", "/usr/bin/puma", "${host}"] diff --git a/benchmarks/workloads/ruby_template/Gemfile b/benchmarks/workloads/ruby_template/Gemfile new file mode 100755 index 000000000..ac521b32c --- /dev/null +++ b/benchmarks/workloads/ruby_template/Gemfile @@ -0,0 +1,5 @@ +source "https://rubygems.org" + +gem "sinatra" +gem "puma" +gem "redis"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock new file mode 100644 index 000000000..dd8d56fb7 --- /dev/null +++ b/benchmarks/workloads/ruby_template/Gemfile.lock @@ -0,0 +1,26 @@ +GEM + remote: https://rubygems.org/ + specs: + mustermann (1.0.3) + puma (3.12.0) + rack (2.0.6) + rack-protection (2.0.5) + rack + sinatra (2.0.5) + mustermann (~> 1.0) + rack (~> 2.0) + rack-protection (= 2.0.5) + tilt (~> 2.0) + tilt (2.0.9) + redis (4.1.0) + +PLATFORMS + ruby + +DEPENDENCIES + puma + sinatra + redis + +BUNDLED WITH + 1.17.1
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/config.ru b/benchmarks/workloads/ruby_template/config.ru new file mode 100755 index 000000000..b2d135cc0 --- /dev/null +++ b/benchmarks/workloads/ruby_template/config.ru @@ -0,0 +1,2 @@ +require './main' +run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/index.erb b/benchmarks/workloads/ruby_template/index.erb new file mode 100755 index 000000000..7f7300e80 --- /dev/null +++ b/benchmarks/workloads/ruby_template/index.erb @@ -0,0 +1,8 @@ +<!DOCTYPE html> +<html> +<body> + <% text.each do |t| %> + <p><%= t %></p> + <% end %> +</body> +</html> diff --git a/benchmarks/workloads/ruby_template/main.rb b/benchmarks/workloads/ruby_template/main.rb new file mode 100755 index 000000000..35c239377 --- /dev/null +++ b/benchmarks/workloads/ruby_template/main.rb @@ -0,0 +1,27 @@ +require "sinatra" +require "securerandom" +require "redis" + +redis_host = ENV["host"] +$redis = Redis.new(host: redis_host) + +def generateText + for i in 0..99 + $redis.set(i, randomBody(1024)) + end +end + +def randomBody(length) + return SecureRandom.alphanumeric(length) +end + +generateText +template = ERB.new(File.read('./index.erb')) + +get "/" do + texts = Array.new + for i in 0..4 + texts.push($redis.get(rand(0..99))) + end + template.result_with_hash(text: texts) +end
\ No newline at end of file diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/sleep/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/sleep/Dockerfile b/benchmarks/workloads/sleep/Dockerfile new file mode 100644 index 000000000..24c72e07a --- /dev/null +++ b/benchmarks/workloads/sleep/Dockerfile @@ -0,0 +1,3 @@ +FROM alpine:latest + +CMD ["sleep", "315360000"] diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD new file mode 100644 index 000000000..35f4d460b --- /dev/null +++ b/benchmarks/workloads/sysbench/BUILD @@ -0,0 +1,35 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "sysbench", + srcs = ["__init__.py"], +) + +py_test( + name = "sysbench_test", + srcs = ["sysbench_test.py"], + python_version = "PY3", + deps = [ + ":sysbench", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/sysbench/Dockerfile b/benchmarks/workloads/sysbench/Dockerfile new file mode 100644 index 000000000..8225e0e14 --- /dev/null +++ b/benchmarks/workloads/sysbench/Dockerfile @@ -0,0 +1,16 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + sysbench \ + && rm -rf /var/lib/apt/lists/* + +# Parameterize the tests. +ENV test cpu +ENV threads 1 +ENV options "" + +# run sysbench once as a warm-up and take the second result +CMD ["sh", "-c", "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null && \ +sysbench --threads=${threads} ${options} ${test} run"] diff --git a/benchmarks/workloads/sysbench/__init__.py b/benchmarks/workloads/sysbench/__init__.py new file mode 100644 index 000000000..de357b4db --- /dev/null +++ b/benchmarks/workloads/sysbench/__init__.py @@ -0,0 +1,167 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sysbench.""" + +import re + +STD_REGEX = r"events per second:\s*(\d*.?\d*)\n" +MEM_REGEX = r"Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)" +ALT_REGEX = r"execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)" +AVG_REGEX = r"avg:[^\n^\d]*(\d*\.?\d*)" + +SAMPLE_CPU_DATA = """ +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Prime numbers limit: 10000 + +Initializing worker threads... + +Threads started! + +CPU speed: + events per second: 9093.38 + +General statistics: + total time: 10.0007s + total number of events: 90949 + +Latency (ms): + min: 0.64 + avg: 0.88 + max: 24.65 + 95th percentile: 1.55 + sum: 79936.91 + +Threads fairness: + events (avg/stddev): 11368.6250/831.38 + execution time (avg/stddev): 9.9921/0.01 +""" + +SAMPLE_MEMORY_DATA = """ +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Running memory speed test with the following options: + block size: 1KiB + total size: 102400MiB + operation: write + scope: global + +Initializing worker threads... + +Threads started! + +Total operations: 47999046 (9597428.64 per second) + +46874.07 MiB transferred (9372.49 MiB/sec) + + +General statistics: + total time: 5.0001s + total number of events: 47999046 + +Latency (ms): + min: 0.00 + avg: 0.00 + max: 0.21 + 95th percentile: 0.00 + sum: 33165.91 + +Threads fairness: + events (avg/stddev): 5999880.7500/111242.52 + execution time (avg/stddev): 4.1457/0.09 +""" + +SAMPLE_MUTEX_DATA = """ +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Initializing worker threads... + +Threads started! + + +General statistics: + total time: 3.7869s + total number of events: 8 + +Latency (ms): + min: 3688.56 + avg: 3754.03 + max: 3780.94 + 95th percentile: 3773.42 + sum: 30032.28 + +Threads fairness: + events (avg/stddev): 1.0000/0.00 + execution time (avg/stddev): 3.7540/0.03 +""" + + +# pylint: disable=unused-argument +def sample(test, **kwargs): + switch = { + "cpu": SAMPLE_CPU_DATA, + "memory": SAMPLE_MEMORY_DATA, + "mutex": SAMPLE_MUTEX_DATA, + "randwr": SAMPLE_CPU_DATA + } + return switch[test] + + +# pylint: disable=unused-argument +def cpu_events_per_second(data: str, **kwargs) -> float: + """Returns events per second.""" + return float(re.compile(STD_REGEX).search(data).group(1)) + + +# pylint: disable=unused-argument +def memory_ops_per_second(data: str, **kwargs) -> float: + """Returns memory operations per second.""" + return float(re.compile(MEM_REGEX).search(data).group(1)) + + +# pylint: disable=unused-argument +def mutex_time(data: str, count: int, locks: int, threads: int, + **kwargs) -> float: + """Returns normalized mutex time (lower is better).""" + value = float(re.compile(ALT_REGEX).search(data).group(1)) + contention = float(threads) / float(locks) + scale = contention * float(count) / 100000000.0 + return value / scale + + +# pylint: disable=unused-argument +def mutex_deviation(data: str, **kwargs) -> float: + """Returns deviation for threads.""" + return float(re.compile(ALT_REGEX).search(data).group(2)) + + +# pylint: disable=unused-argument +def mutex_latency(data: str, **kwargs) -> float: + """Returns average mutex latency.""" + return float(re.compile(AVG_REGEX).search(data).group(1)) diff --git a/benchmarks/workloads/sysbench/sysbench_test.py b/benchmarks/workloads/sysbench/sysbench_test.py new file mode 100644 index 000000000..3fb541fd2 --- /dev/null +++ b/benchmarks/workloads/sysbench/sysbench_test.py @@ -0,0 +1,34 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parser test.""" + +import sys + +import pytest + +from benchmarks.workloads import sysbench + + +def test_sysbench_parser(): + """Test the basic parser.""" + assert sysbench.cpu_events_per_second(sysbench.sample("cpu")) == 9093.38 + assert sysbench.memory_ops_per_second(sysbench.sample("memory")) == 9597428.64 + assert sysbench.mutex_time(sysbench.sample("mutex"), 1, 1, + 100000000.0) == 3.754 + assert sysbench.mutex_deviation(sysbench.sample("mutex")) == 0.03 + assert sysbench.mutex_latency(sysbench.sample("mutex")) == 3754.03 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD new file mode 100644 index 000000000..e1ff3059b --- /dev/null +++ b/benchmarks/workloads/syscall/BUILD @@ -0,0 +1,36 @@ +load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") + +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "syscall", + srcs = ["__init__.py"], +) + +py_test( + name = "syscall_test", + srcs = ["syscall_test.py"], + python_version = "PY3", + deps = [ + ":syscall", + requirement("attrs", False), + requirement("atomicwrites", False), + requirement("more-itertools", False), + requirement("pathlib2", False), + requirement("pluggy", False), + requirement("py", False), + requirement("pytest", True), + requirement("six", False), + ], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + "syscall.c", + ], +) diff --git a/benchmarks/workloads/syscall/Dockerfile b/benchmarks/workloads/syscall/Dockerfile new file mode 100644 index 000000000..a2088d953 --- /dev/null +++ b/benchmarks/workloads/syscall/Dockerfile @@ -0,0 +1,6 @@ +FROM gcc:latest +COPY . /usr/src/syscall +WORKDIR /usr/src/syscall +RUN gcc -O2 -o syscall syscall.c +ENV count 1000000 +CMD ["sh", "-c", "./syscall ${count}"] diff --git a/benchmarks/workloads/syscall/__init__.py b/benchmarks/workloads/syscall/__init__.py new file mode 100644 index 000000000..dc9028faa --- /dev/null +++ b/benchmarks/workloads/syscall/__init__.py @@ -0,0 +1,29 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple syscall test.""" + +import re + +SAMPLE_DATA = "Called getpid syscall 1000000 times: 1117 ms, 500 ns each." + + +# pylint: disable=unused-argument +def sample(**kwargs) -> str: + return SAMPLE_DATA + + +# pylint: disable=unused-argument +def syscall_time_ns(data: str, **kwargs) -> int: + """Returns average system call time.""" + return float(re.compile(r"(\d+)\sns each.").search(data).group(1)) diff --git a/benchmarks/workloads/syscall/syscall.c b/benchmarks/workloads/syscall/syscall.c new file mode 100644 index 000000000..ded030397 --- /dev/null +++ b/benchmarks/workloads/syscall/syscall.c @@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define _GNU_SOURCE +#include <stdio.h> +#include <stdlib.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <time.h> +#include <unistd.h> + +// Short program that calls getpid() a number of times and outputs time +// diference from the MONOTONIC clock. +int main(int argc, char** argv) { + struct timespec start, stop; + long result; + char buf[80]; + + if (argc < 2) { + printf("Usage:./syscall NUM_TIMES_TO_CALL"); + return 1; + } + + if (clock_gettime(CLOCK_MONOTONIC, &start)) return 1; + + long loops = atoi(argv[1]); + for (long i = 0; i < loops; i++) { + syscall(SYS_gettimeofday, 0, 0); + } + + if (clock_gettime(CLOCK_MONOTONIC, &stop)) return 1; + + if ((stop.tv_nsec - start.tv_nsec) < 0) { + result = (stop.tv_sec - start.tv_sec - 1) * 1000; + result += (stop.tv_nsec - start.tv_nsec + 1000000000) / (1000 * 1000); + } else { + result = (stop.tv_sec - start.tv_sec) * 1000; + result += (stop.tv_nsec - start.tv_nsec) / (1000 * 1000); + } + + printf("Called getpid syscall %d times: %lu ms, %lu ns each.\n", loops, + result, result * 1000000 / loops); + + return 0; +} diff --git a/benchmarks/workloads/syscall/syscall_test.py b/benchmarks/workloads/syscall/syscall_test.py new file mode 100644 index 000000000..72f027de1 --- /dev/null +++ b/benchmarks/workloads/syscall/syscall_test.py @@ -0,0 +1,27 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +from benchmarks.workloads import syscall + + +def test_syscall_time_ns(): + assert syscall.syscall_time_ns(syscall.sample()) == 500 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD new file mode 100644 index 000000000..17f1f8ebb --- /dev/null +++ b/benchmarks/workloads/tensorflow/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +py_library( + name = "tensorflow", + srcs = ["__init__.py"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/benchmarks/workloads/tensorflow/Dockerfile new file mode 100644 index 000000000..262643b98 --- /dev/null +++ b/benchmarks/workloads/tensorflow/Dockerfile @@ -0,0 +1,14 @@ +FROM tensorflow/tensorflow:1.13.2 + +RUN apt-get update \ + && apt-get install -y git +RUN git clone https://github.com/aymericdamien/TensorFlow-Examples.git +RUN python -m pip install -U pip setuptools +RUN python -m pip install matplotlib + +WORKDIR /TensorFlow-Examples/examples + +ENV PYTHONPATH="$PYTHONPATH:/TensorFlow-Examples/examples" + +ENV workload "3_NeuralNetworks/convolutional_network.py" +CMD python ${workload} diff --git a/benchmarks/workloads/tensorflow/__init__.py b/benchmarks/workloads/tensorflow/__init__.py new file mode 100644 index 000000000..b5ec213f8 --- /dev/null +++ b/benchmarks/workloads/tensorflow/__init__.py @@ -0,0 +1,20 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Tensorflow example.""" + + +# pylint: disable=unused-argument +def run_time(value, **kwargs): + """Returns the startup and runtime of the Tensorflow workload in seconds.""" + return value diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD new file mode 100644 index 000000000..83f3c71a0 --- /dev/null +++ b/benchmarks/workloads/true/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//benchmarks:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "files", + srcs = [ + "Dockerfile", + ], +) diff --git a/benchmarks/workloads/true/Dockerfile b/benchmarks/workloads/true/Dockerfile new file mode 100644 index 000000000..2e97c921e --- /dev/null +++ b/benchmarks/workloads/true/Dockerfile @@ -0,0 +1,3 @@ +FROM alpine:latest + +CMD ["true"] diff --git a/kokoro/kythe/generate_xrefs.cfg b/kokoro/kythe/generate_xrefs.cfg new file mode 100644 index 000000000..03e65c54e --- /dev/null +++ b/kokoro/kythe/generate_xrefs.cfg @@ -0,0 +1,28 @@ +build_file: "gvisor/kokoro/kythe/generate_xrefs.sh" + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73898 + keyname: "kokoro-rbe-service-account" + } + } +} + +bazel_setting { + project_id: "gvisor-rbe" + local_execution: false + auth_credential: { + keystore_config_id: 73898 + keyname: "kokoro-rbe-service-account" + } + bes_backend_address: "buildeventservice.googleapis.com" + foundry_backend_address: "remotebuildexecution.googleapis.com" + upsalite_frontend_address: "https://source.cloud.google.com" +} + +action { + define_artifacts { + regex: "*.kzip" + } +} diff --git a/kokoro/kythe/generate_xrefs.sh b/kokoro/kythe/generate_xrefs.sh new file mode 100644 index 000000000..49186eeeb --- /dev/null +++ b/kokoro/kythe/generate_xrefs.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -ex + +# Install the latest version of Bazel. The default on Kokoro images is out of +# date. +if command -v use_bazel.sh >/dev/null; then + use_bazel.sh latest +fi +bazel version + +python3 -V + +readonly KYTHE_VERSION='v0.0.37' +readonly WORKDIR="$(mktemp -d)" +readonly KYTHE_DIR="${WORKDIR}/kythe-${KYTHE_VERSION}" +if [[ -n "$KOKORO_GIT_COMMIT" ]]; then + readonly KZIP_FILENAME="${KOKORO_ARTIFACTS_DIR}/${KOKORO_GIT_COMMIT}.kzip" +else + readonly KZIP_FILENAME="$(git rev-parse HEAD).kzip" +fi + +wget -q -O "${WORKDIR}/kythe.tar.gz" \ + "https://github.com/kythe/kythe/releases/download/${KYTHE_VERSION}/kythe-${KYTHE_VERSION}.tar.gz" +tar --no-same-owner -xzf "${WORKDIR}/kythe.tar.gz" --directory "$WORKDIR" + +if [[ -n "$KOKORO_ARTIFACTS_DIR" ]]; then + cd "${KOKORO_ARTIFACTS_DIR}/github/gvisor" +fi +bazel \ + --bazelrc="${KYTHE_DIR}/extractors.bazelrc" \ + build \ + --override_repository kythe_release="${KYTHE_DIR}" \ + --define=kythe_corpus=gvisor.dev \ + //... + +"${KYTHE_DIR}/tools/kzip" merge \ + --output "$KZIP_FILENAME" \ + $(find -L bazel-out/*/extra_actions/ -name '*.kzip') diff --git a/kokoro/ubuntu1604/10_core.sh b/kokoro/ubuntu1604/10_core.sh index e87a6eee8..46dda6bb1 100755 --- a/kokoro/ubuntu1604/10_core.sh +++ b/kokoro/ubuntu1604/10_core.sh @@ -21,8 +21,8 @@ apt-get update && apt-get -y install make git-core build-essential linux-headers # Install a recent go toolchain. if ! [[ -d /usr/local/go ]]; then - wget https://dl.google.com/go/go1.12.linux-amd64.tar.gz - tar -xvf go1.12.linux-amd64.tar.gz + wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz + tar -xvf go1.13.5.linux-amd64.tar.gz mv go /usr/local fi diff --git a/kokoro/ubuntu1604/README.md b/kokoro/ubuntu1604/README.md new file mode 100644 index 000000000..64f913b9a --- /dev/null +++ b/kokoro/ubuntu1604/README.md @@ -0,0 +1,34 @@ +## Image Update + +After making changes to files in the directory, you must run the following +commands to update the image Kokoro uses: + +```shell +gcloud config set project gvisor-kokoro-testing +third_party/gvisor/kokoro/ubuntu1604/build.sh +third_party/gvisor/kokoro/ubuntu1804/build.sh +``` + +Note: the command above will change your default project for `gcloud`. Run +`gcloud config set project` again to revert back to your default project. + +Note: Files in `third_party/gvisor/kokoro/ubuntu1804/` as symlinks to +`ubuntu1604`, therefore both images must be updated. + +After the script finishes, the last few lines of the output will container the +image name. If the output was lost, you can run `build.sh` again to print the +image name. + +``` +NAME PROJECT FAMILY DEPRECATED STATUS +image-6777fa4666a968c8 gvisor-kokoro-testing READY ++ cleanup ++ gcloud compute instances delete --quiet build-tlfrdv +Deleted [https://www.googleapis.com/compute/v1/projects/gvisor-kokoro-testing/zones/us-central1-f/instances/build-tlfrdv]. +``` + +To setup Kokoro to use the new image, copy the image names to their +corresponding file below: + +* //devtools/kokoro/config/gcp/gvisor/ubuntu1604.gcl +* //devtools/kokoro/config/gcp/gvisor/ubuntu1804.gcl diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index f78315ebf..6663a199c 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -16,15 +16,17 @@ package linux // Commands from linux/fcntl.h. const ( - F_DUPFD = 0x0 - F_GETFD = 0x1 - F_SETFD = 0x2 - F_GETFL = 0x3 - F_SETFL = 0x4 - F_SETLK = 0x6 - F_SETLKW = 0x7 - F_SETOWN = 0x8 - F_GETOWN = 0x9 + F_DUPFD = 0 + F_GETFD = 1 + F_SETFD = 2 + F_GETFL = 3 + F_SETFL = 4 + F_SETLK = 6 + F_SETLKW = 7 + F_SETOWN = 8 + F_GETOWN = 9 + F_SETOWN_EX = 15 + F_GETOWN_EX = 16 F_DUPFD_CLOEXEC = 1024 + 6 F_SETPIPE_SZ = 1024 + 7 F_GETPIPE_SZ = 1024 + 8 @@ -32,9 +34,9 @@ const ( // Commands for F_SETLK. const ( - F_RDLCK = 0x0 - F_WRLCK = 0x1 - F_UNLCK = 0x2 + F_RDLCK = 0 + F_WRLCK = 1 + F_UNLCK = 2 ) // Flags for fcntl. @@ -42,7 +44,7 @@ const ( FD_CLOEXEC = 00000001 ) -// Lock structure for F_SETLK. +// Flock is the lock structure for F_SETLK. type Flock struct { Type int16 Whence int16 @@ -52,3 +54,16 @@ type Flock struct { Pid int32 _ [4]byte } + +// Flags for F_SETOWN_EX and F_GETOWN_EX. +const ( + F_OWNER_TID = 0 + F_OWNER_PID = 1 + F_OWNER_PGRP = 2 +) + +// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +type FOwnerEx struct { + Type int32 + PID int32 +} diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index c9ee098f4..0f014d27f 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -144,9 +144,13 @@ const ( ModeCharacterDevice = S_IFCHR ModeNamedPipe = S_IFIFO - ModeSetUID = 04000 - ModeSetGID = 02000 - ModeSticky = 01000 + S_ISUID = 04000 + S_ISGID = 02000 + S_ISVTX = 01000 + + ModeSetUID = S_ISUID + ModeSetGID = S_ISGID + ModeSticky = S_ISVTX ModeUserAll = 0700 ModeUserRead = 0400 diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 9e7db8b30..67daa6c24 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -305,7 +305,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs()) return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return c.Native(0), nil } @@ -320,6 +320,6 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { _, err := c.PtraceSetRegs(bytes.NewBuffer(buf)) return err } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return nil } diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 1f78d54a2..e1f2fea60 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -22,6 +22,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/urpc" ) @@ -56,6 +57,9 @@ type Profile struct { // traceFile is the current execution trace output file. traceFile *fd.FD + + // Kernel is the kernel under profile. + Kernel *kernel.Kernel } // StartCPUProfile is an RPC stub which starts recording the CPU profile in a @@ -147,6 +151,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { return err } + // Ensure all trace contexts are registered. + p.Kernel.RebuildTraceContexts() + p.traceFile = output return nil } @@ -158,9 +165,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error { defer p.mu.Unlock() if p.traceFile == nil { - return errors.New("Execution tracing not start") + return errors.New("Execution tracing not started") } + // Similarly to the case above, if tasks have not ended traces, we will + // lose information. Thus we need to rebuild the tasks in order to have + // complete information. This will not lose information if multiple + // traces are overlapping. + p.Kernel.RebuildTraceContexts() + trace.Stop() p.traceFile.Close() p.traceFile = nil diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index c35faeb4c..ced51c66c 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -268,14 +268,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error { } // Process contains information about a single process in a Sandbox. -// TODO(b/117881927): Implement TTY field. type Process struct { UID auth.KUID `json:"uid"` PID kernel.ThreadID `json:"pid"` // Parent PID - PPID kernel.ThreadID `json:"ppid"` + PPID kernel.ThreadID `json:"ppid"` + Threads []kernel.ThreadID `json:"threads"` // Processor utilization C int32 `json:"c"` + // TTY name of the process. Will be of the form "pts/N" if there is a + // TTY, or "?" if there is not. + TTY string `json:"tty"` // Start time STime string `json:"stime"` // CPU time @@ -285,18 +288,19 @@ type Process struct { } // ProcessListToTable prints a table with the following format: -// UID PID PPID C STIME TIME CMD -// 0 1 0 0 14:04 505262ns tail +// UID PID PPID C TTY STIME TIME CMD +// 0 1 0 0 pty/4 14:04 505262ns tail func ProcessListToTable(pl []*Process) string { var buf bytes.Buffer tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0) - fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD") + fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD") for _, d := range pl { - fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s", + fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s", d.UID, d.PID, d.PPID, d.C, + d.TTY, d.STime, d.Time, d.Cmd) @@ -307,7 +311,7 @@ func ProcessListToTable(pl []*Process) string { // ProcessListToJSON will return the JSON representation of ps. func ProcessListToJSON(pl []*Process) (string, error) { - b, err := json.Marshal(pl) + b, err := json.MarshalIndent(pl, "", " ") if err != nil { return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err) } @@ -334,7 +338,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ts := k.TaskSet() now := k.RealtimeClock().Now() for _, tg := range ts.Root.ThreadGroups() { - pid := tg.PIDNamespace().IDOfThreadGroup(tg) + pidns := tg.PIDNamespace() + pid := pidns.IDOfThreadGroup(tg) + // If tg has already been reaped ignore it. if pid == 0 { continue @@ -345,16 +351,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ppid := kernel.ThreadID(0) if p := tg.Leader().Parent(); p != nil { - ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup()) + ppid = pidns.IDOfThreadGroup(p.ThreadGroup()) } + threads := tg.MemberIDs(pidns) *out = append(*out, &Process{ - UID: tg.Leader().Credentials().EffectiveKUID, - PID: pid, - PPID: ppid, - STime: formatStartTime(now, tg.Leader().StartTime()), - C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), - Time: tg.CPUStats().SysTime.String(), - Cmd: tg.Leader().Name(), + UID: tg.Leader().Credentials().EffectiveKUID, + PID: pid, + PPID: ppid, + Threads: threads, + STime: formatStartTime(now, tg.Leader().StartTime()), + C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), + Time: tg.CPUStats().SysTime.String(), + Cmd: tg.Leader().Name(), + TTY: ttyName(tg.TTY()), }) } sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID }) @@ -395,3 +404,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 { } return int32(percentCPU) } + +func ttyName(tty *kernel.TTY) string { + if tty == nil { + return "?" + } + return fmt.Sprintf("pts/%d", tty.Index) +} diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go index d8ada2694..0a88459b2 100644 --- a/pkg/sentry/control/proc_test.go +++ b/pkg/sentry/control/proc_test.go @@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) { }{ { pl: []*Process{}, - expected: "UID PID PPID C STIME TIME CMD", + expected: "UID PID PPID C TTY STIME TIME CMD", }, { pl: []*Process{ @@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) { PID: 0, PPID: 0, C: 0, + TTY: "?", STime: "0", Time: "0", Cmd: "zero", @@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) { PID: 1, PPID: 1, C: 1, + TTY: "pts/4", STime: "1", Time: "1", Cmd: "one", }, }, - expected: `UID PID PPID C STIME TIME CMD -0 0 0 0 0 0 zero -1 1 1 1 1 1 one`, + expected: `UID PID PPID C TTY STIME TIME CMD +0 0 0 0 ? 0 0 zero +1 1 1 1 pts/4 1 1 one`, }, } diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 0da608548..4e358a46a 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -143,9 +143,9 @@ type session struct { // socket files. This allows unix domain sockets to be used with paths that // belong to a gofer. // - // TODO(b/77154739): there are few possible races with someone stat'ing the - // file and another deleting it concurrently, where the file will not be - // reported as socket file. + // TODO(gvisor.dev/issue/1200): there are few possible races with someone + // stat'ing the file and another deleting it concurrently, where the file + // will not be reported as socket file. endpoints *endpointMaps `state:"wait"` } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 87184ec67..0e46c5fb7 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -67,29 +67,28 @@ type taskDir struct { var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. -func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode { +func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - // FIXME(b/123511468): create the correct io file for threads. - "io": newIO(t, msrc), + "auxv": newAuxvec(t, msrc), + "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + "comm": newComm(t, msrc), + "environ": newExecArgInode(t, msrc, environExecArg), + "exe": newExe(t, msrc), + "fd": newFdDir(t, msrc), + "fdinfo": newFdInfoDir(t, msrc), + "gid_map": newGIDMap(t, msrc), + "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "ns": newNamespaceDir(t, msrc), "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, showSubtasks, p.pidns), + "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), "statm": newStatm(t, msrc), "status": newStatus(t, msrc, p.pidns), "uid_map": newUIDMap(t, msrc), } - if showSubtasks { + if isThreadGroup { contents["task"] = p.newSubtasks(t, msrc) } if len(p.cgroupControllers) > 0 { @@ -619,8 +618,11 @@ type ioData struct { ioUsage } -func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) +func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { + if isThreadGroup { + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) + } + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -639,7 +641,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se io.Accumulate(i.IOUsage()) var buf bytes.Buffer - fmt.Fprintf(&buf, "char: %d\n", io.CharsRead) + fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead) fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten) fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls) fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls) diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index ff8138820..917f90cc0 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -53,8 +53,8 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal d: d, n: n, ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{}, - slaveKTTY: &kernel.TTY{}, + masterKTTY: &kernel.TTY{Index: n}, + slaveKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 94cd74095..177ce2cb9 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -81,7 +81,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf ctx := contexttest.Context(b) creds := auth.CredentialsFromContext(ctx) - if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}); err != nil { + if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } return func() { diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 307e4d68c..e9f756732 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -147,55 +147,54 @@ func TestSeek(t *testing.T) { t.Fatalf("vfsfs.OpenAt failed: %v", err) } - if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { + if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { t.Errorf("expected seek position 0, got %d and error %v", n, err) } - stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{}) + stat, err := fd.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err) } // We should be able to seek beyond the end of file. size := int64(stat.Size) - if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { + if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } - if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { + if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err) } // Make sure negative offsets work with SEEK_CUR. - if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } // Make sure SEEK_END works with regular files. - switch fd.Impl().(type) { - case *regularFileFD: + if _, ok := fd.Impl().(*regularFileFD); ok { // Seek back to 0. - if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { + if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", 0, n, err) } // Seek forward beyond EOF. - if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } } @@ -456,7 +455,7 @@ func TestRead(t *testing.T) { want := make([]byte, 1) for { n, err := f.Read(want) - fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) + fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) if diff := cmp.Diff(got, want); diff != "" { t.Errorf("file data mismatch (-want +got):\n%s", diff) @@ -464,7 +463,7 @@ func TestRead(t *testing.T) { // Make sure there is no more file data left after getting EOF. if n == 0 || err == io.EOF { - if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { + if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image) } @@ -574,7 +573,7 @@ func TestIterDirents(t *testing.T) { } cb := &iterDirentsCb{} - if err = fd.Impl().IterDirents(ctx, cb); err != nil { + if err = fd.IterDirents(ctx, cb); err != nil { t.Fatalf("dir fd.IterDirents() failed: %v", err) } diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/memfs/benchmark_test.go index ea6417ce7..4a7a94a52 100644 --- a/pkg/sentry/fsimpl/memfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/memfs/benchmark_test.go @@ -394,7 +394,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { } defer mountPoint.DecRef() // Create and mount the submount. - if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.GetFilesystemOptions{}); err != nil { + if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go index a3a870571..5bf527c80 100644 --- a/pkg/sentry/fsimpl/memfs/pipe_test.go +++ b/pkg/sentry/fsimpl/memfs/pipe_test.go @@ -194,7 +194,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { readData := make([]byte, 1) dst := usermem.BytesIOSequence(readData) - bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{}) + bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) if err != syserror.ErrWouldBlock { t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) } @@ -207,7 +207,7 @@ func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { writeData := []byte(msg) src := usermem.BytesIOSequence(writeData) - bytesWritten, err := fd.Impl().Write(ctx, src, vfs.WriteOptions{}) + bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{}) if err != nil { t.Fatalf("error writing to pipe %q: %v", fileName, err) } @@ -220,7 +220,7 @@ func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { readData := make([]byte, len(msg)) dst := usermem.BytesIOSequence(readData) - bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{}) + bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) if err != nil { t.Fatalf("error reading from pipe %q: %v", fileName, err) } diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go index c36c4aff5..0e016bca5 100644 --- a/pkg/sentry/fsimpl/proc/filesystems.go +++ b/pkg/sentry/fsimpl/proc/filesystems.go @@ -19,7 +19,7 @@ package proc // +stateify savable type filesystemsData struct{} -// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for +// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for // filesystemsData. We would need to retrive filesystem names from // vfs.VirtualFilesystem. Also needs vfs replacement for // fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev. diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go index e81b1e910..8683cf677 100644 --- a/pkg/sentry/fsimpl/proc/mounts.go +++ b/pkg/sentry/fsimpl/proc/mounts.go @@ -16,7 +16,7 @@ package proc import "gvisor.dev/gvisor/pkg/sentry/kernel" -// TODO(b/138862512): Implement mountInfoFile and mountsFile. +// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile. // mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo. // diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 28ba950bd..bd3fb4c03 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -841,9 +841,11 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, AbstractSocketNamespace: args.AbstractSocketNamespace, ContainerID: args.ContainerID, } - if _, err := k.tasks.NewTask(config); err != nil { + t, err := k.tasks.NewTask(config) + if err != nil { return nil, 0, err } + t.traceExecEvent(tc) // Simulate exec for tracing. // Success. tgid := k.tasks.Root.IDOfThreadGroup(tg) @@ -1118,6 +1120,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { return lastErr } +// RebuildTraceContexts rebuilds the trace context for all tasks. +// +// Unfortunately, if these are built while tracing is not enabled, then we will +// not have meaningful trace data. Rebuilding here ensures that we can do so +// after tracing has been enabled. +func (k *Kernel) RebuildTraceContexts() { + k.extMu.Lock() + defer k.extMu.Unlock() + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + + for t, tid := range k.tasks.Root.tids { + t.rebuildTraceContext(tid) + } +} + // FeatureSet returns the FeatureSet. func (k *Kernel) FeatureSet() *cpuid.FeatureSet { return k.featureSet diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 93fe68a3e..de9617e9d 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred return syserror.ERANGE } - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = val sem.pid = pid s.changeTime = ktime.NowFromContext(ctx) @@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti for i, val := range vals { sem := &s.sems[i] - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = int16(val) sem.pid = pid sem.wakeWaiters() @@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch } // All operations succeeded, apply them. - // TODO(b/29354920): handle undo operations. + // TODO(gvisor.dev/issue/137): handle undo operations. for i, v := range tmpVals { s.sems[i].value = v s.sems[i].wakeWaiters() diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 220fa73a2..2fdee0282 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -339,6 +339,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn { return nil } +// LookupName looks up a syscall name. +func (s *SyscallTable) LookupName(sysno uintptr) string { + if sc, ok := s.Table[sysno]; ok { + return sc.Name + } + return fmt.Sprintf("sys_%d", sysno) // Unlikely. +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 80c8e5464..ab0c6c4aa 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -15,6 +15,8 @@ package kernel import ( + gocontext "context" + "runtime/trace" "sync" "sync/atomic" @@ -390,7 +392,14 @@ type Task struct { // logPrefix is a string containing the task's thread ID in the root PID // namespace, and is prepended to log messages emitted by Task.Infof etc. - logPrefix atomic.Value `state:".(string)"` + logPrefix atomic.Value `state:"nosave"` + + // traceContext and traceTask are both used for tracing, and are + // updated along with the logPrefix in updateInfoLocked. + // + // These are exclusive to the task goroutine. + traceContext gocontext.Context `state:"nosave"` + traceTask *trace.Task `state:"nosave"` // creds is the task's credentials. // @@ -528,14 +537,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) { t.ptraceTracer.Store(tracer) } -func (t *Task) saveLogPrefix() string { - return t.logPrefix.Load().(string) -} - -func (t *Task) loadLogPrefix(prefix string) { - t.logPrefix.Store(prefix) -} - func (t *Task) saveSyscallFilters() []bpf.Program { if f := t.syscallFilters.Load(); f != nil { return f.([]bpf.Program) @@ -549,6 +550,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) { // afterLoad is invoked by stateify. func (t *Task) afterLoad() { + t.updateInfoLocked() t.interruptChan = make(chan struct{}, 1) t.gosched.State = TaskGoroutineNonexistent if t.stop != nil { diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index dd69939f9..4a4a69ee2 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -16,6 +16,7 @@ package kernel import ( "runtime" + "runtime/trace" "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { runtime.Gosched() } + region := trace.StartRegion(t.traceContext, blockRegion) select { case <-C: + region.End() t.SleepFinish(true) + // Woken by event. return nil case <-interrupt: + region.End() t.SleepFinish(false) // Return the indicated error on interrupt. return syserror.ErrInterrupted case <-timerChan: - // We've timed out. + region.End() t.SleepFinish(true) + // We've timed out. return syserror.ETIMEDOUT } } diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 0916fd658..3eadfedb4 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -299,6 +299,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { // nt that it must receive before its task goroutine starts running. tid := nt.k.tasks.Root.IDOfTask(nt) defer nt.Start(tid) + t.traceCloneEvent(tid) // "If fork/clone and execve are allowed by @prog, any child processes will // be constrained to the same filters and system call ABI as the parent." - diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 17a089b90..90a6190f1 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -129,6 +129,7 @@ type runSyscallAfterExecStop struct { } func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { + t.traceExecEvent(r.tc) t.tg.pidns.owner.mu.Lock() t.tg.execing = nil if t.killed() { @@ -253,7 +254,7 @@ func (t *Task) promoteLocked() { t.tg.leader = t t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t]) - t.updateLogPrefixLocked() + t.updateInfoLocked() // Reap the original leader. If it has a tracer, detach it instead of // waiting for it to acknowledge the original leader's death. oldLeader.exitParentNotified = true diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 535f03e50..435761e5a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState { type runExitMain struct{} func (*runExitMain) execute(t *Task) taskRunState { + t.traceExitEvent() lastExiter := t.exitThreadGroup() // If the task has a cleartid, and the thread group wasn't killed by a diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index a29e9b9eb..0fb3661de 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -16,6 +16,7 @@ package kernel import ( "fmt" + "runtime/trace" "sort" "gvisor.dev/gvisor/pkg/log" @@ -127,11 +128,88 @@ func (t *Task) debugDumpStack() { } } -// updateLogPrefix updates the task's cached log prefix to reflect its -// current thread ID. +// trace definitions. +// +// Note that all region names are prefixed by ':' in order to ensure that they +// are lexically ordered before all system calls, which use the naked system +// call name (e.g. "read") for maximum clarity. +const ( + traceCategory = "task" + runRegion = ":run" + blockRegion = ":block" + cpuidRegion = ":cpuid" + faultRegion = ":fault" +) + +// updateInfoLocked updates the task's cached log prefix and tracing +// information to reflect its current thread ID. // // Preconditions: The task's owning TaskSet.mu must be locked. -func (t *Task) updateLogPrefixLocked() { +func (t *Task) updateInfoLocked() { // Use the task's TID in the root PID namespace for logging. - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t])) + tid := t.tg.pidns.owner.Root.tids[t] + t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.rebuildTraceContext(tid) +} + +// rebuildTraceContext rebuilds the trace context. +// +// Precondition: the passed tid must be the tid in the root namespace. +func (t *Task) rebuildTraceContext(tid ThreadID) { + // Re-initialize the trace context. + if t.traceTask != nil { + t.traceTask.End() + } + + // Note that we define the "task type" to be the dynamic TID. This does + // not align perfectly with the documentation for "tasks" in the + // tracing package. Tasks may be assumed to be bounded by analysis + // tools. However, if we just use a generic "task" type here, then the + // "user-defined tasks" page on the tracing dashboard becomes nearly + // unusable, as it loads all traces from all tasks. + // + // We can assume that the number of tasks in the system is not + // arbitrarily large (in general it won't be, especially for cases + // where we're collecting a brief profile), so using the TID is a + // reasonable compromise in this case. + t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid)) +} + +// traceCloneEvent is called when a new task is spawned. +// +// ntid must be the new task's ThreadID in the root namespace. +func (t *Task) traceCloneEvent(ntid ThreadID) { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid) +} + +// traceExitEvent is called when a task exits. +func (t *Task) traceExitEvent() { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status()) +} + +// traceExecEvent is called when a task calls exec. +func (t *Task) traceExecEvent(tc *TaskContext) { + if !trace.IsEnabled() { + return + } + d := tc.MemoryManager.Executable() + if d == nil { + trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") + return + } + defer d.DecRef() + root := t.fsContext.RootDirectory() + if root == nil { + trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>") + return + } + defer root.DecRef() + n, _ := d.FullName(root) + trace.Logf(t.traceContext, traceCategory, "exec: %s", n) } diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index c92266c59..d97f8c189 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -17,6 +17,7 @@ package kernel import ( "bytes" "runtime" + "runtime/trace" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -205,9 +206,11 @@ func (*runApp) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.RUnlock() } + region := trace.StartRegion(t.traceContext, runRegion) t.accountTaskGoroutineEnter(TaskGoroutineRunningApp) info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU) t.accountTaskGoroutineLeave(TaskGoroutineRunningApp) + region.End() if clearSinglestep { t.Arch().ClearSingleStep() @@ -225,6 +228,7 @@ func (*runApp) execute(t *Task) taskRunState { case platform.ErrContextSignalCPUID: // Is this a CPUID instruction? + region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) @@ -232,10 +236,12 @@ func (*runApp) execute(t *Task) taskRunState { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) + region.End() // Resume execution. return (*runApp)(nil) } + region.End() // Not an actual CPUID, but required copy-in. // The instruction at the given RIP was not a CPUID, and we // fallthrough to the default signal deliver behavior below. @@ -251,8 +257,10 @@ func (*runApp) execute(t *Task) taskRunState { // an application-generated signal and we should continue execution // normally. if at.Any() { + region := trace.StartRegion(t.traceContext, faultRegion) addr := usermem.Addr(info.Addr()) err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack())) + region.End() if err == nil { // The fault was handled appropriately. // We can resume running the application. @@ -260,6 +268,12 @@ func (*runApp) execute(t *Task) taskRunState { } // Is this a vsyscall that we need emulate? + // + // Note that we don't track vsyscalls as part of a + // specific trace region. This is because regions don't + // stack, and the actual system call will count as a + // region. We should be able to easily identify + // vsyscalls by having a <fault><syscall> pair. if at.Execute { if sysno, ok := t.tc.st.LookupEmulate(addr); ok { return t.doVsyscall(addr, sysno) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index ae6fc4025..3522a4ae5 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -154,10 +154,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { // Below this point, newTask is expected not to fail (there is no rollback // of assignTIDsLocked or any of the following). - // Logging on t's behalf will panic if t.logPrefix hasn't been initialized. - // This is the earliest point at which we can do so (since t now has thread - // IDs). - t.updateLogPrefixLocked() + // Logging on t's behalf will panic if t.logPrefix hasn't been + // initialized. This is the earliest point at which we can do so + // (since t now has thread IDs). + t.updateInfoLocked() if cfg.InheritParent != nil { t.parent = cfg.InheritParent.parent diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index b543d536a..3180f5560 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -17,6 +17,7 @@ package kernel import ( "fmt" "os" + "runtime/trace" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u ctrl = ctrlStopAndReinvokeSyscall } else { fn := s.Lookup(sysno) + var region *trace.Region // Only non-nil if tracing == true. + if trace.IsEnabled() { + region = trace.StartRegion(t.traceContext, s.LookupName(sysno)) + } if fn != nil { // Call our syscall implementation. rval, ctrl, err = fn(t, args) @@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u // Use the missing function if not found. rval, err = t.SyscallTable().Missing(t, sysno, args) } + if region != nil { + region.End() + } } if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) { diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index 34f84487a..048de26dc 100644 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go @@ -21,8 +21,19 @@ import "sync" // // +stateify savable type TTY struct { + // Index is the terminal index. It is immutable. + Index uint32 + mu sync.Mutex `state:"nosave"` // tg is protected by mu. tg *ThreadGroup } + +// TTY returns the thread group's controlling terminal. If nil, there is no +// controlling terminal. +func (tg *ThreadGroup) TTY() *TTY { + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + return tg.tty +} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index c2c3ec06e..6299a3e2f 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -408,6 +408,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el start = vaddr } if vaddr < end { + // NOTE(b/37474556): Linux allows out-of-order + // segments, in violation of the spec. ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end) return loadedELF{}, syserror.ENOEXEC } diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go index 2f65db70b..ba1f9043d 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sentry/sighandling/sighandling.go @@ -16,7 +16,6 @@ package sighandling import ( - "fmt" "os" "os/signal" "reflect" @@ -31,37 +30,25 @@ const numSignals = 32 // handleSignals listens for incoming signals and calls the given handler // function. // -// It starts when the start channel is closed, stops when the stop channel -// is closed, and closes done once it will no longer deliver signals to k. -func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) { +// It stops when the stop channel is closed. The done channel is closed once it +// will no longer deliver signals to k. +func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) { // Build a select case. - sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}} + sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}} for _, sigchan := range sigchans { sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)}) } - started := false for { // Wait for a notification. index, _, ok := reflect.Select(sc) - // Was it the start / stop channel? + // Was it the stop channel? if index == 0 { if !ok { - if !started { - // start channel; start forwarding and - // swap this case for the stop channel - // to select stop requests. - started = true - sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)} - } else { - // stop channel; stop forwarding and - // clear this case so it is never - // selected again. - started = false - close(done) - sc[0].Chan = reflect.Value{} - } + // Stop forwarding and notify that it's done. + close(done) + return } continue } @@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, // Otherwise, it was a signal on channel N. Index 0 represents the stop // channel, so index N represents the channel for signal N. - signal := linux.Signal(index) - - if !started { - // Kernel cannot receive signals, either because it is - // not ready yet or is shutting down. - // - // Kill ourselves if this signal would have killed the - // process before PrepareForwarding was called. i.e., all - // _SigKill signals; see Go - // src/runtime/sigtab_linux_generic.go. - // - // Otherwise ignore the signal. - // - // TODO(b/114489875): Drop in Go 1.12, which uses tgkill - // in runtime.raise. - switch signal { - case linux.SIGHUP, linux.SIGINT, linux.SIGTERM: - dieFromSignal(signal) - panic(fmt.Sprintf("Failed to die from signal %d", signal)) - default: - continue - } - } - - // Pass the signal to the handler. - handler(signal) + handler(linux.Signal(index)) } } -// PrepareHandler ensures that synchronous signals are passed to the given -// handler function and returns a callback that starts signal delivery, which -// itself returns a callback that stops signal handling. +// StartSignalForwarding ensures that synchronous signals are passed to the +// given handler function and returns a callback that stops signal delivery. // // Note that this function permanently takes over signal handling. After the // stop callback, signals revert to the default Go runtime behavior, which // cannot be overridden with external calls to signal.Notify. -func PrepareHandler(handler func(linux.Signal)) func() func() { - start := make(chan struct{}) +func StartSignalForwarding(handler func(linux.Signal)) func() { stop := make(chan struct{}) done := make(chan struct{}) @@ -128,13 +88,10 @@ func PrepareHandler(handler func(linux.Signal)) func() func() { signal.Notify(sigchan, syscall.Signal(sig)) } // Start up our listener. - go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. + go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. - return func() func() { - close(start) - return func() { - close(stop) - <-done - } + return func() { + close(stop) + <-done } } diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index c303435d5..1ebe22d34 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -15,8 +15,6 @@ package sighandling import ( - "fmt" - "runtime" "syscall" "unsafe" @@ -48,27 +46,3 @@ func IgnoreChildStop() error { return nil } - -// dieFromSignal kills the current process with sig. -// -// Preconditions: The default action of sig is termination. -func dieFromSignal(sig linux.Signal) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - sa := sigaction{handler: linux.SIG_DFL} - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigaction failed: %v", e)) - } - - set := linux.MakeSignalSet(sig) - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigprocmask failed: %v", e)) - } - - if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil { - panic(fmt.Sprintf("tgkill failed: %v", err)) - } - - panic("failed to die") -} diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 782a3cb92..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -195,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ // the available space, we must align down. // // align must be >= 4 and each data int32 is 4 bytes. The length of the - // header is already aligned, so if we align to the with of the data there + // header is already aligned, so if we align to the width of the data there // are two cases: // 1. The aligned length is less than the length of the header. The // unaligned length was also less than the length of the header, so we // can't write anything. // 2. The aligned length is greater than or equal to the length of the - // header. We can write the header plus zero or more datas. We can't write - // a partial int32, so the length of the message will be - // min(aligned length, header + datas). + // header. We can write the header plus zero or more bytes of data. We can't + // write a partial int32, so the length of the message will be + // min(aligned length, header + data). if space < linux.SizeOfControlMessageHeader { flags |= linux.MSG_CTRUNC return buf, flags @@ -240,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = binary.Marshal(buf, usermem.ByteOrder, data) - // Check if we went over. + // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { return hdrBuf } - // Fix up length. + // Update control message length to include data. putUint64(ob, uint64(len(buf)-len(ob))) return alignSlice(buf, align) @@ -348,43 +348,62 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { ) } -func addSpaceForCmsg(cmsgDataLen int, buf []byte) []byte { - newBuf := make([]byte, 0, len(buf)+linux.SizeOfControlMessageHeader+cmsgDataLen) - return append(newBuf, buf...) -} - -// PackControlMessages converts the given ControlMessages struct into a buffer. +// PackControlMessages packs control messages into the given buffer. +// // We skip control messages specific to Unix domain sockets. -func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages) []byte { - var buf []byte - // The use of t.Arch().Width() is analogous to Linux's use of sizeof(long) in - // CMSG_ALIGN. - width := t.Arch().Width() - +// +// Note that some control messages may be truncated if they do not fit under +// the capacity of buf. +func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { if cmsgs.IP.HasTimestamp { - buf = addSpaceForCmsg(int(width), buf) buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) } if cmsgs.IP.HasInq { // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. - buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageInq, width), buf) buf = PackInq(t, cmsgs.IP.Inq, buf) } if cmsgs.IP.HasTOS { - buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTOS, width), buf) buf = PackTOS(t, cmsgs.IP.TOS, buf) } if cmsgs.IP.HasTClass { - buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTClass, width), buf) buf = PackTClass(t, cmsgs.IP.TClass, buf) } return buf } +// cmsgSpace is equivalent to CMSG_SPACE in Linux. +func cmsgSpace(t *kernel.Task, dataLen int) int { + return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width()) +} + +// CmsgsSpace returns the number of bytes needed to fit the control messages +// represented in cmsgs. +func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { + space := 0 + + if cmsgs.IP.HasTimestamp { + space += cmsgSpace(t, linux.SizeOfTimeval) + } + + if cmsgs.IP.HasInq { + space += cmsgSpace(t, linux.SizeOfControlMessageInq) + } + + if cmsgs.IP.HasTOS { + space += cmsgSpace(t, linux.SizeOfControlMessageTOS) + } + + if cmsgs.IP.HasTClass { + space += cmsgSpace(t, linux.SizeOfControlMessageTClass) + } + + return space +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 8d9363aac..c957b0f1d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -45,7 +45,7 @@ const ( sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) // maxControlLen is the maximum size of a control message buffer used in a - // recvmsg syscall. + // recvmsg or sendmsg syscall. maxControlLen = 1024 ) @@ -289,12 +289,12 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt switch level { case linux.SOL_IP: switch name { - case linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } case linux.SOL_SOCKET: @@ -334,12 +334,12 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ switch level { case linux.SOL_IP: switch name { - case linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } case linux.SOL_SOCKET: @@ -412,9 +412,12 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Namelen = uint32(len(senderAddrBuf)) } if controlLen > 0 { - controlBuf = make([]byte, maxControlLen) + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) msg.Control = &controlBuf[0] - msg.Controllen = maxControlLen + msg.Controllen = controlLen } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { @@ -489,7 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.ErrInvalidArgument } - controlBuf := control.PackControlMessages(t, controlMessages) + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. if uint64(src.NumBytes()) != srcs.NumBytes() { diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d92399efd..fe5a46aa3 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -151,6 +151,8 @@ var Metrics = tcpip.Stats{ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto index 9586f5923..b677e9eb3 100644 --- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto +++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto @@ -3,7 +3,6 @@ syntax = "proto3"; // package syscall_rpc is a set of networking related system calls that can be // forwarded to a socket gofer. // -// TODO(b/77963526): Document individual RPCs. package syscall_rpc; message SendmsgRequest { diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 8c250c325..2389a9cdb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -43,6 +43,11 @@ type ControlMessages struct { IP tcpip.ControlMessages } +// Release releases Unix domain socket credentials and rights. +func (c *ControlMessages) Release() { + c.Unix.Release() +} + // Socket is the interface containing socket syscalls used by the syscall layer // to redirect them to the appropriate implementation. type Socket interface { diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 72ebf766d..d46421199 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -14,6 +14,7 @@ go_library( "open.go", "poll.go", "ptrace.go", + "select.go", "signal.go", "socket.go", "strace.go", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 5d57b75af..e603f858f 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -40,7 +40,7 @@ var linuxAMD64 = SyscallMap{ 20: makeSyscallInfo("writev", FD, WriteIOVec, Hex), 21: makeSyscallInfo("access", Path, Oct), 22: makeSyscallInfo("pipe", PipeFDs), - 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval), + 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval), 24: makeSyscallInfo("sched_yield"), 25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex), 26: makeSyscallInfo("msync", Hex, Hex, Hex), @@ -287,7 +287,7 @@ var linuxAMD64 = SyscallMap{ 267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex), 268: makeSyscallInfo("fchmodat", FD, Path, Mode), 269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex), - 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex), + 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet), 271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex), 272: makeSyscallInfo("unshare", CloneFlags), 273: makeSyscallInfo("set_robust_list", Hex, Hex), @@ -335,5 +335,33 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 318: makeSyscallInfo("getrandom", Hex, Hex, Hex), + 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close. + 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex), + 321: makeSyscallInfo("bpf", Hex, Hex, Hex), + 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex), + 323: makeSyscallInfo("userfaultfd", Hex), + 324: makeSyscallInfo("membarrier", Hex, Hex), + 325: makeSyscallInfo("mlock2", Hex, Hex, Hex), + 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex), + 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex), + 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex), + 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex), + 330: makeSyscallInfo("pkey_alloc", Hex, Hex), + 331: makeSyscallInfo("pkey_free", Hex), 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), + 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet), + 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex), + 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex), + 425: makeSyscallInfo("io_uring_setup", Hex, Hex), + 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex), + 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex), + 428: makeSyscallInfo("open_tree", FD, Path, Hex), + 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex), + 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close. + 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex), + 432: makeSyscallInfo("fsmount", FD, Hex, Hex), + 433: makeSyscallInfo("fspick", FD, Path, Hex), + 434: makeSyscallInfo("pidfd_open", Hex, Hex), + 435: makeSyscallInfo("clone3", Hex, Hex), } diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go new file mode 100644 index 000000000..92c18083d --- /dev/null +++ b/pkg/sentry/strace/select.go @@ -0,0 +1,53 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func fdsFromSet(t *kernel.Task, set []byte) []int { + var fds []int + // Append n if the n-th bit is 1. + for i, v := range set { + for j := 0; j < 8; j++ { + if (v>>uint(j))&1 == 1 { + fds = append(fds, i*8+j) + } + } + } + return fds +} + +func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string { + if addr == 0 { + return "null" + } + + // Calculate the size of the fd set (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := uint(nfds % 8) + + set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte) + if err != nil { + return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err) + } + + return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set)) +} diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 94334f6d2..51f2efb39 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -208,6 +208,15 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) i += linux.SizeOfControlMessageHeader width := t.Arch().Width() length := int(h.Length) - linux.SizeOfControlMessageHeader + if length < 0 { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 311389547..629c1f308 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -439,6 +439,8 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer())) case PollFDs: output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false)) + case SelectFDSet: + output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) case Oct: output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8)) case Hex: diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 3c389d375..e5d486c4e 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -206,6 +206,10 @@ const ( // PollFDs is an array of struct pollfd. The number of entries in the // array is in the next argument. PollFDs + + // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of + // fds represented must be the first argument. + SelectFDSet ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index 81e4f93a6..797542d28 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -260,7 +260,7 @@ var AMD64 = &kernel.SyscallTable{ 217: syscalls.Supported("getdents64", Getdents64), 218: syscalls.Supported("set_tid_address", SetTidAddress), 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), 222: syscalls.Supported("timer_create", TimerCreate), 223: syscalls.Supported("timer_settime", TimerSettime), @@ -367,11 +367,31 @@ var AMD64 = &kernel.SyscallTable{ 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - // Syscalls after 325 are "backports" from versions of Linux after 4.4. + // Syscalls implemented after 325 are "backports" from versions + // of Linux after 4.4. 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), 327: syscalls.Supported("preadv2", Preadv2), 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 332: syscalls.Supported("statx", Statx), + 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, Emulate: map[usermem.Addr]uintptr{ diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index f1dd4b0c0..2bc7faff5 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -224,7 +224,7 @@ var ARM64 = &kernel.SyscallTable{ 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), - 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), @@ -302,7 +302,26 @@ var ARM64 = &kernel.SyscallTable{ 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), 286: syscalls.Supported("preadv2", Preadv2), 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 291: syscalls.Supported("statx", Statx), + 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, Emulate: map[usermem.Addr]uintptr{}, diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 3b9181002..9bc2445a5 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -840,25 +840,42 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(newfd), nil, nil } -func fGetOwn(t *kernel.Task, file *fs.File) int32 { +func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx { ma := file.Async(nil) if ma == nil { - return 0 + return linux.FOwnerEx{} } a := ma.(*fasync.FileAsync) ot, otg, opg := a.Owner() switch { case ot != nil: - return int32(t.PIDNamespace().IDOfTask(ot)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_TID, + PID: int32(t.PIDNamespace().IDOfTask(ot)), + } case otg != nil: - return int32(t.PIDNamespace().IDOfThreadGroup(otg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PID, + PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)), + } case opg != nil: - return int32(-t.PIDNamespace().IDOfProcessGroup(opg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PGRP, + PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)), + } default: - return 0 + return linux.FOwnerEx{} } } +func fGetOwn(t *kernel.Task, file *fs.File) int32 { + owner := fGetOwnEx(t, file) + if owner.Type == linux.F_OWNER_PGRP { + return -owner.PID + } + return owner.PID +} + // fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux. // // If who is positive, it represents a PID. If negative, it represents a PGID. @@ -901,11 +918,13 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall t.FDTable().SetFlags(fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) + return 0, nil, nil case linux.F_GETFL: return uintptr(file.Flags().ToLinux()), nil, nil case linux.F_SETFL: flags := uint(args[2].Uint()) file.SetFlags(linuxToFlags(flags).Settable()) + return 0, nil, nil case linux.F_SETLK, linux.F_SETLKW: // In Linux the file system can choose to provide lock operations for an inode. // Normally pipe and socket types lack lock operations. We diverge and use a heavy @@ -1008,6 +1027,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_SETOWN: fSetOwn(t, file, args[2].Int()) return 0, nil, nil + case linux.F_GETOWN_EX: + addr := args[2].Pointer() + owner := fGetOwnEx(t, file) + _, err := t.CopyOut(addr, &owner) + return 0, nil, err + case linux.F_SETOWN_EX: + addr := args[2].Pointer() + var owner linux.FOwnerEx + n, err := t.CopyIn(addr, &owner) + if err != nil { + return 0, nil, err + } + a := file.Async(fasync.New).(*fasync.FileAsync) + switch owner.Type { + case linux.F_OWNER_TID: + task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) + if task == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerTask(t, task) + return uintptr(n), nil, nil + case linux.F_OWNER_PID: + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID)) + if tg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerThreadGroup(t, tg) + return uintptr(n), nil, nil + case linux.F_OWNER_PGRP: + pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID)) + if pg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerProcessGroup(t, pg) + return uintptr(n), nil, nil + default: + return 0, nil, syserror.EINVAL + } case linux.F_GET_SEALS: val, err := tmpfs.GetSeals(file.Dirent.Inode) return uintptr(val), nil, err @@ -1035,7 +1092,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Everything else is not yet supported. return 0, nil, syserror.EINVAL } - return 0, nil, nil } const ( diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 7a13beac2..631dffec6 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) return remainingTimeout, n, err } +// CopyInFDSet copies an fd set from select(2)/pselect(2). +func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes int, nBitsInLastPartialByte uint) ([]byte, error) { + set := make([]byte, nBytes) + + if addr != 0 { + if _, err := t.CopyIn(addr, &set); err != nil { + return nil, err + } + // If we only use part of the last byte, mask out the extraneous bits. + // + // N.B. This only works on little-endian architectures. + if nBitsInLastPartialByte != 0 { + set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte + } + } + return set, nil +} + func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { return 0, syserror.EINVAL } - // Capture all the provided input vectors. - // - // N.B. This only works on little-endian architectures. - byteCount := (nfds + 7) / 8 - - bitsInLastPartialByte := uint(nfds % 8) - r := make([]byte, byteCount) - w := make([]byte, byteCount) - e := make([]byte, byteCount) + // Calculate the size of the fd sets (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := uint(nfds % 8) - if readFDs != 0 { - if _, err := t.CopyIn(readFDs, &r); err != nil { - return 0, err - } - // Mask out bits above nfds. - if bitsInLastPartialByte != 0 { - r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + // Capture all the provided input vectors. + r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if writeFDs != 0 { - if _, err := t.CopyIn(writeFDs, &w); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if exceptFDs != 0 { - if _, err := t.CopyIn(exceptFDs, &e); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } // Count how many FDs are actually being requested so that we can build // a PollFD array. fdCount := 0 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { v := r[i] | w[i] | e[i] for v != 0 { v &= (v - 1) @@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Build the PollFD array. pfd := make([]linux.PollFD, 0, fdCount) var fd int32 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { rV, wV, eV := r[i], w[i], e[i] v := rV | wV | eV m := byte(1) @@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add } // Do the syscall, then count the number of bits set. - _, _, err := pollBlock(t, pfd, timeout) - if err != nil { + if _, _, err = pollBlock(t, pfd, timeout); err != nil { return 0, syserror.ConvertIntr(err, syserror.EINTR) } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index d8acae063..4b5aafcc0 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -764,7 +764,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Unix.Release() + cms.Release() } if int(msg.Flags) != mflags { @@ -784,32 +784,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } - defer cms.Unix.Release() + defer cms.Release() controlData := make([]byte, 0, msg.ControlLen) + controlData = control.PackControlMessages(t, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) controlData, mflags = control.PackCredentials(t, creds, controlData, mflags) } - if cms.IP.HasTimestamp { - controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData) - } - - if cms.IP.HasInq { - // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. - controlData = control.PackInq(t, cms.IP.Inq, controlData) - } - - if cms.IP.HasTOS { - controlData = control.PackTOS(t, cms.IP.TOS, controlData) - } - - if cms.IP.HasTClass { - controlData = control.PackTClass(t, cms.IP.TClass, controlData) - } - if cms.Unix.Rights != nil { controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) } @@ -885,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Unix.Release() + cm.Release() if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -1071,7 +1055,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { - controlMessages.Unix.Release() + controlMessages.Release() } return uintptr(n), err } diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 74a325309..59237c3b9 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -19,7 +19,6 @@ go_library( "options.go", "permissions.go", "resolving_path.go", - "syscalls.go", "testutil.go", "vfs.go", ], diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 34007eb57..4473dfce8 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -241,3 +241,96 @@ type IterDirentsCallback interface { // called. Handle(dirent Dirent) bool } + +// OnClose is called when a file descriptor representing the FileDescription is +// closed. Returning a non-nil error should not prevent the file descriptor +// from being closed. +func (fd *FileDescription) OnClose(ctx context.Context) error { + return fd.impl.OnClose(ctx) +} + +// StatusFlags returns file description status flags, as for fcntl(F_GETFL). +func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { + flags, err := fd.impl.StatusFlags(ctx) + flags |= linux.O_LARGEFILE + return flags, err +} + +// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL). +func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { + return fd.impl.SetStatusFlags(ctx, flags) +} + +// Stat returns metadata for the file represented by fd. +func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + return fd.impl.Stat(ctx, opts) +} + +// SetStat updates metadata for the file represented by fd. +func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error { + return fd.impl.SetStat(ctx, opts) +} + +// StatFS returns metadata for the filesystem containing the file represented +// by fd. +func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + return fd.impl.StatFS(ctx) +} + +// PRead reads from the file represented by fd into dst, starting at the given +// offset, and returns the number of bytes read. PRead is permitted to return +// partial reads with a nil error. +func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return fd.impl.PRead(ctx, dst, offset, opts) +} + +// Read is similar to PRead, but does not specify an offset. +func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + return fd.impl.Read(ctx, dst, opts) +} + +// PWrite writes src to the file represented by fd, starting at the given +// offset, and returns the number of bytes written. PWrite is permitted to +// return partial writes with a nil error. +func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return fd.impl.PWrite(ctx, src, offset, opts) +} + +// Write is similar to PWrite, but does not specify an offset. +func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return fd.impl.Write(ctx, src, opts) +} + +// IterDirents invokes cb on each entry in the directory represented by fd. If +// IterDirents has been called since the last call to Seek, it continues +// iteration from the end of the last call. +func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error { + return fd.impl.IterDirents(ctx, cb) +} + +// Seek changes fd's offset (assuming one exists) and returns its new value. +func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.impl.Seek(ctx, offset, whence) +} + +// Sync has the semantics of fsync(2). +func (fd *FileDescription) Sync(ctx context.Context) error { + return fd.impl.Sync(ctx) +} + +// ConfigureMMap mutates opts to implement mmap(2) for the file represented by +// fd. +func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.impl.ConfigureMMap(ctx, opts) +} + +// Ioctl implements the ioctl(2) syscall. +func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.impl.Ioctl(ctx, uio, args) +} + +// SyncFS instructs the filesystem containing fd to execute the semantics of +// syncfs(2). +func (fd *FileDescription) SyncFS(ctx context.Context) error { + return fd.vd.mount.fs.impl.Sync(ctx) +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index a5561dcbe..ac7799296 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -103,7 +103,7 @@ func TestGenCountFD(t *testing.T) { // The first read causes Generate to be called to fill the FD's buffer. buf := make([]byte, 2) ioseq := usermem.BytesIOSequence(buf) - n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err := fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -112,17 +112,17 @@ func TestGenCountFD(t *testing.T) { } // A second read without seeking is still at EOF. - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 0 || err != io.EOF { t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) } // Seeking to the beginning of the file causes it to be regenerated. - n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET) + n, err = fd.Seek(ctx, 0, linux.SEEK_SET) if n != 0 || err != nil { t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) } - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -131,7 +131,7 @@ func TestGenCountFD(t *testing.T) { } // PRead at the beginning of the file also causes it to be regenerated. - n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{}) + n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 76ff8cf51..dfbd2372a 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -47,6 +47,9 @@ func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) { fs.refs = 1 fs.vfs = vfsObj fs.impl = impl + vfsObj.filesystemsMu.Lock() + vfsObj.filesystems[fs] = struct{}{} + vfsObj.filesystemsMu.Unlock() } // VirtualFilesystem returns the containing VirtualFilesystem. @@ -66,9 +69,28 @@ func (fs *Filesystem) IncRef() { } } +// TryIncRef increments fs' reference count and returns true. If fs' reference +// count is zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fs. +func (fs *Filesystem) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fs.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) { + return true + } + } +} + // DecRef decrements fs' reference count. func (fs *Filesystem) DecRef() { if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { + fs.vfs.filesystemsMu.Lock() + delete(fs.vfs.filesystems, fs) + fs.vfs.filesystemsMu.Unlock() fs.impl.Release() } else if refs < 0 { panic("Filesystem.decRef() called without holding a reference") diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 1c3b2e987..ec23ab0dd 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -18,6 +18,7 @@ import ( "math" "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -133,13 +134,13 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return mntns, nil } -// NewMount creates and mounts a Filesystem configured by the given arguments. -func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *GetFilesystemOptions) error { +// MountAt creates and mounts a Filesystem configured by the given arguments. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return syserror.ENODEV } - fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return err } @@ -207,6 +208,68 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti return nil } +// UmountAt removes the Mount at the given path. +func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { + if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { + return syserror.EINVAL + } + + // MNT_FORCE is currently unimplemented except for the permission check. + if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) { + return syserror.EPERM + } + + vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{}) + if err != nil { + return err + } + defer vd.DecRef() + if vd.dentry != vd.mount.root { + return syserror.EINVAL + } + vfs.mountMu.Lock() + if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns { + vfs.mountMu.Unlock() + return syserror.EINVAL + } + + // TODO(jamieliu): Linux special-cases umount of the caller's root, which + // we don't implement yet (we'll just fail it since the caller holds a + // reference on it). + + vfs.mounts.seq.BeginWrite() + if opts.Flags&linux.MNT_DETACH == 0 { + if len(vd.mount.children) != 0 { + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + // We are holding a reference on vd.mount. + expectedRefs := int64(1) + if !vd.mount.umounted { + expectedRefs = 2 + } + if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + } + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{ + eager: opts.Flags&linux.MNT_DETACH == 0, + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + return nil +} + type umountRecursiveOptions struct { // If eager is true, ensure that future calls to Mount.tryIncMountedRef() // on umounted mounts fail. diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3aa73d911..3ecbc8fc1 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -46,6 +46,12 @@ type MknodOptions struct { DevMinor uint32 } +// MountOptions contains options to VirtualFilesystem.MountAt(). +type MountOptions struct { + // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). + GetFilesystemOptions GetFilesystemOptions +} + // OpenOptions contains options to VirtualFilesystem.OpenAt() and // FilesystemImpl.OpenAt(). type OpenOptions struct { @@ -114,6 +120,12 @@ type StatOptions struct { Sync uint32 } +// UmountOptions contains options to VirtualFilesystem.UmountAt(). +type UmountOptions struct { + // Flags contains flags as specified for umount2(2). + Flags uint32 +} + // WriteOptions contains options to FileDescription.PWrite(), // FileDescriptionImpl.PWrite(), FileDescription.Write(), and // FileDescriptionImpl.Write(). diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index f8e74355c..f1edb0680 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -119,3 +119,65 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { return false } } + +// CheckSetStat checks that creds has permission to change the metadata of a +// file with the given permissions, UID, and GID as specified by stat, subject +// to the rules of Linux's fs/attr.c:setattr_prepare(). +func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { + if stat.Mask&linux.STATX_MODE != 0 { + if !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + // TODO(b/30815691): "If the calling process is not privileged (Linux: + // does not have the CAP_FSETID capability), and the group of the file + // does not match the effective group ID of the process or one of its + // supplementary group IDs, the S_ISGID bit will be turned off, but + // this will not cause an error to be returned." - chmod(2) + } + if stat.Mask&linux.STATX_UID != 0 { + if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&linux.STATX_GID != 0 { + if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { + if !CanActAsOwner(creds, kuid) { + if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) { + return syserror.EPERM + } + // isDir is irrelevant in the following call to + // GenericCheckPermissions since ats == MayWrite means that + // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE + // applies, regardless of isDir. + if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil { + return err + } + } + } + return nil +} + +// CanActAsOwner returns true if creds can act as the owner of a file with the +// given owning UID, consistent with Linux's +// fs/inode.c:inode_owner_or_capable(). +func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool { + if creds.EffectiveKUID == kuid { + return true + } + return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok() +} + +// HasCapabilityOnFile returns true if creds has the given capability with +// respect to a file with the given owning UID and GID, consistent with Linux's +// kernel/capability.c:capable_wrt_inode_uidgid(). +func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool { + return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok() +} diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go deleted file mode 100644 index 436151afa..000000000 --- a/pkg/sentry/vfs/syscalls.go +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" -) - -// PathOperation specifies the path operated on by a VFS method. -// -// PathOperation is passed to VFS methods by pointer to reduce memory copying: -// it's somewhat large and should never escape. (Options structs are passed by -// pointer to VFS and FileDescription methods for the same reason.) -type PathOperation struct { - // Root is the VFS root. References on Root are borrowed from the provider - // of the PathOperation. - // - // Invariants: Root.Ok(). - Root VirtualDentry - - // Start is the starting point for the path traversal. References on Start - // are borrowed from the provider of the PathOperation (i.e. the caller of - // the VFS method to which the PathOperation was passed). - // - // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. - Start VirtualDentry - - // Path is the pathname traversed by this operation. - Pathname string - - // If FollowFinalSymlink is true, and the Dentry traversed by the final - // path component represents a symbolic link, the symbolic link should be - // followed. - FollowFinalSymlink bool -} - -// GetDentryAt returns a VirtualDentry representing the given path, at which a -// file must exist. A reference is taken on the returned VirtualDentry. -func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return VirtualDentry{}, err - } - for { - d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) - if err == nil { - vd := VirtualDentry{ - mount: rp.mount, - dentry: d, - } - rp.mount.IncRef() - vfs.putResolvingPath(rp) - return vd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return VirtualDentry{}, err - } - } -} - -// MkdirAt creates a directory at the given path. -func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { - // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is - // also honored." - mkdir(2) - opts.Mode &= 01777 - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err - } - for { - err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return err - } - } -} - -// MknodAt creates a file of the given mode at the given path. It returns an -// error from the syserror package. -func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil - } - for { - if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { - vfs.putResolvingPath(rp) - return nil - } - // Handle mount traversals. - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return err - } - } -} - -// OpenAt returns a FileDescription providing access to the file at the given -// path. A reference is taken on the returned FileDescription. -func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { - // Remove: - // - // - O_LARGEFILE, which we always report in FileDescription status flags - // since only 64-bit architectures are supported at this time. - // - // - O_CLOEXEC, which affects file descriptors and therefore must be - // handled outside of VFS. - // - // - Unknown flags. - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE - // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. - if opts.Flags&linux.O_SYNC != 0 { - opts.Flags |= linux.O_DSYNC - } - // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified - // with O_DIRECTORY and a writable access mode (to ensure that it fails on - // filesystem implementations that do not support it). - if opts.Flags&linux.O_TMPFILE != 0 { - if opts.Flags&linux.O_DIRECTORY == 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { - return nil, syserror.EINVAL - } - } - // O_PATH causes most other flags to be ignored. - if opts.Flags&linux.O_PATH != 0 { - opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH - } - // "On Linux, the following bits are also honored in mode: [S_ISUID, - // S_ISGID, S_ISVTX]" - open(2) - opts.Mode &= 07777 - - if opts.Flags&linux.O_NOFOLLOW != 0 { - pop.FollowFinalSymlink = false - } - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil, err - } - if opts.Flags&linux.O_DIRECTORY != 0 { - rp.mustBeDir = true - rp.mustBeDirOrig = true - } - for { - fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return fd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return nil, err - } - } -} - -// StatAt returns metadata for the file at the given path. -func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return linux.Statx{}, err - } - for { - stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return stat, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return linux.Statx{}, err - } - } -} - -// StatusFlags returns file description status flags. -func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { - flags, err := fd.impl.StatusFlags(ctx) - flags |= linux.O_LARGEFILE - return flags, err -} - -// SetStatusFlags sets file description status flags. -func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - return fd.impl.SetStatusFlags(ctx, flags) -} - -// TODO: -// -// - VFS.SyncAllFilesystems() for sync(2) -// -// - Something for syncfs(2) -// -// - VFS.LinkAt() -// -// - VFS.ReadlinkAt() -// -// - VFS.RenameAt() -// -// - VFS.RmdirAt() -// -// - VFS.SetStatAt() -// -// - VFS.StatFSAt() -// -// - VFS.SymlinkAt() -// -// - VFS.UmountAt() -// -// - VFS.UnlinkAt() -// -// - FileDescription.(almost everything) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index f0cd3ffe5..7262b0d0a 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -20,6 +20,7 @@ // VirtualFilesystem.mountMu // Dentry.mu // Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry +// VirtualFilesystem.filesystemsMu // VirtualFilesystem.fsTypesMu // // Locking Dentry.mu in multiple Dentries requires holding @@ -28,6 +29,11 @@ package vfs import ( "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" ) // A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts. @@ -67,6 +73,11 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // filesystems contains all Filesystems. filesystems is protected by + // filesystemsMu. + filesystemsMu sync.Mutex + filesystems map[*Filesystem]struct{} + // fsTypes contains all FilesystemTypes that are usable in the // VirtualFilesystem. fsTypes is protected by fsTypesMu. fsTypesMu sync.RWMutex @@ -77,12 +88,379 @@ type VirtualFilesystem struct { func New() *VirtualFilesystem { vfs := &VirtualFilesystem{ mountpoints: make(map[*Dentry]map[*Mount]struct{}), + filesystems: make(map[*Filesystem]struct{}), fsTypes: make(map[string]FilesystemType), } vfs.mounts.Init() return vfs } +// PathOperation specifies the path operated on by a VFS method. +// +// PathOperation is passed to VFS methods by pointer to reduce memory copying: +// it's somewhat large and should never escape. (Options structs are passed by +// pointer to VFS and FileDescription methods for the same reason.) +type PathOperation struct { + // Root is the VFS root. References on Root are borrowed from the provider + // of the PathOperation. + // + // Invariants: Root.Ok(). + Root VirtualDentry + + // Start is the starting point for the path traversal. References on Start + // are borrowed from the provider of the PathOperation (i.e. the caller of + // the VFS method to which the PathOperation was passed). + // + // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. + Start VirtualDentry + + // Path is the pathname traversed by this operation. + Pathname string + + // If FollowFinalSymlink is true, and the Dentry traversed by the final + // path component represents a symbolic link, the symbolic link should be + // followed. + FollowFinalSymlink bool +} + +// GetDentryAt returns a VirtualDentry representing the given path, at which a +// file must exist. A reference is taken on the returned VirtualDentry. +func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return VirtualDentry{}, err + } + for { + d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) + if err == nil { + vd := VirtualDentry{ + mount: rp.mount, + dentry: d, + } + rp.mount.IncRef() + vfs.putResolvingPath(rp) + return vd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return VirtualDentry{}, err + } + } +} + +// LinkAt creates a hard link at newpop representing the existing file at +// oldpop. +func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// MkdirAt creates a directory at the given path. +func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { + // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is + // also honored." - mkdir(2) + opts.Mode &= 0777 | linux.S_ISVTX + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// MknodAt creates a file of the given mode at the given path. It returns an +// error from the syserror package. +func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil + } + for { + if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { + vfs.putResolvingPath(rp) + return nil + } + // Handle mount traversals. + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// OpenAt returns a FileDescription providing access to the file at the given +// path. A reference is taken on the returned FileDescription. +func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { + // Remove: + // + // - O_LARGEFILE, which we always report in FileDescription status flags + // since only 64-bit architectures are supported at this time. + // + // - O_CLOEXEC, which affects file descriptors and therefore must be + // handled outside of VFS. + // + // - Unknown flags. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE + // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. + if opts.Flags&linux.O_SYNC != 0 { + opts.Flags |= linux.O_DSYNC + } + // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified + // with O_DIRECTORY and a writable access mode (to ensure that it fails on + // filesystem implementations that do not support it). + if opts.Flags&linux.O_TMPFILE != 0 { + if opts.Flags&linux.O_DIRECTORY == 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_CREAT != 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { + return nil, syserror.EINVAL + } + } + // O_PATH causes most other flags to be ignored. + if opts.Flags&linux.O_PATH != 0 { + opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH + } + // "On Linux, the following bits are also honored in mode: [S_ISUID, + // S_ISGID, S_ISVTX]" - open(2) + opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX + + if opts.Flags&linux.O_NOFOLLOW != 0 { + pop.FollowFinalSymlink = false + } + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + if opts.Flags&linux.O_DIRECTORY != 0 { + rp.mustBeDir = true + rp.mustBeDirOrig = true + } + for { + fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return fd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// ReadlinkAt returns the target of the symbolic link at the given path. +func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return target, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// RenameAt renames the file at oldpop to newpop. +func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// RmdirAt removes the directory at the given path. +func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RmdirAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SetStatAt changes metadata for the file at the given path. +func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// StatAt returns metadata for the file at the given path. +func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statx{}, err + } + for { + stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return stat, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statx{}, err + } + } +} + +// StatFSAt returns metadata for the filesystem containing the file at the +// given path. +func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statfs{}, err + } + for { + statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return statfs, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statfs{}, err + } + } +} + +// SymlinkAt creates a symbolic link at the given path with the given target. +func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// UnlinkAt deletes the non-directory file at the given path. +func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.UnlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SyncAllFilesystems has the semantics of Linux's sync(2). +func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + fss := make(map[*Filesystem]struct{}) + vfs.filesystemsMu.Lock() + for fs := range vfs.filesystems { + if !fs.TryIncRef() { + continue + } + fss[fs] = struct{}{} + } + vfs.filesystemsMu.Unlock() + var retErr error + for fs := range fss { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef() + } + return retErr +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index ecce6c69f..5e4611333 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -287,7 +287,9 @@ func (w *Watchdog) runTurn() { if !ok { // New stuck task detected. // - // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel. + // Note that tasks blocked doing IO may be considered stuck in kernel, + // unless they are surrounded b + // Task.UninterruptibleSleepStart/Finish. tc = &offender{lastUpdateTime: lastUpdateTime} stuckTasks.Increment() newTaskFound = true diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD index b06a90bef..cb1f41628 100644 --- a/pkg/syncutil/BUILD +++ b/pkg/syncutil/BUILD @@ -31,8 +31,6 @@ go_template( go_library( name = "syncutil", srcs = [ - "downgradable_rwmutex_1_12_unsafe.go", - "downgradable_rwmutex_1_13_unsafe.go", "downgradable_rwmutex_unsafe.go", "memmove_unsafe.go", "norace_unsafe.go", diff --git a/pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go b/pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go deleted file mode 100644 index 7c6336e62..000000000 --- a/pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.13 - -// TODO(b/133868570): Delete once Go 1.12 is no longer supported. - -package syncutil - -import _ "unsafe" - -//go:linkname runtimeSemrelease112 sync.runtime_Semrelease -func runtimeSemrelease112(s *uint32, handoff bool) - -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) { - // 'skipframes' is only available starting from 1.13. - runtimeSemrelease112(s, handoff) -} diff --git a/pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go b/pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go deleted file mode 100644 index 3c3673119..000000000 --- a/pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.13 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -package syncutil - -import _ "unsafe" - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go index 07feca402..51e11555d 100644 --- a/pkg/syncutil/downgradable_rwmutex_unsafe.go +++ b/pkg/syncutil/downgradable_rwmutex_unsafe.go @@ -3,7 +3,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.12 +// +build go1.13 // +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -27,6 +27,9 @@ import ( //go:linkname runtimeSemacquire sync.runtime_Semacquire func runtimeSemacquire(s *uint32) +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + // DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock // method. type DowngradableRWMutex struct { diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index a3485b35c..8392cb9e5 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -55,5 +55,8 @@ go_test( "ndp_test.go", ], embed = [":header"], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 0caa51c1e..5275b34d4 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -90,6 +90,18 @@ const ( // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -266,6 +278,28 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToEUI64IntoBuf populates buf with a EUI-64 from a 48-bit +// Ethernet/MAC address. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFE + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToEUI64 computes an EUI-64 from a 48-bit Ethernet/MAC address. +func EthernetAddressToEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { @@ -275,18 +309,11 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 1ca6199ef..06e0bace2 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "errors" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -85,6 +86,23 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType = 25 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS + // Server option's length field value when it contains at least one + // IPv6 address. + minNDPRecursiveDNSServerLength = 3 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. @@ -92,13 +110,13 @@ const ( ) var ( - // NDPPrefixInformationInfiniteLifetime is a value that represents - // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix - // Information option. Its value is (2^32 - 1)s = 4294967295s + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. // // This is a variable instead of a constant so that tests can change // this value to a smaller value. It should only be modified by tests. - NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 + NDPInfiniteLifetime = time.Second * math.MaxUint32 ) // NDPOptionIterator is an iterator of NDPOption. @@ -118,6 +136,7 @@ var ( ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") ) // Next returns the next element in the backing NDPOptions, or true if we are @@ -182,6 +201,22 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { } return NDPPrefixInformation(body), false, nil + + case NDPRecursiveDNSServerOptionType: + // RFC 8106 section 5.3.1 outlines that the RDNSS option + // must have a minimum length of 3 so it contains at + // least one IPv6 address. + if l < minNDPRecursiveDNSServerLength { + return nil, true, ErrNDPInvalidLength + } + + opt := NDPRecursiveDNSServer(body) + if len(opt.Addresses()) == 0 { + return nil, true, ErrNDPOptMalformedBody + } + + return opt, false, nil + default: // We do not yet recognize the option, just skip for // now. This is okay because RFC 4861 allows us to @@ -434,7 +469,7 @@ func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { // // Note, a value of 0 implies the prefix should not be considered as on-link, // and a value of infinity/forever is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. func (o NDPPrefixInformation) ValidLifetime() time.Duration { // The field is the time in seconds, as per RFC 4861 section 4.6.2. return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) @@ -447,7 +482,7 @@ func (o NDPPrefixInformation) ValidLifetime() time.Duration { // // Note, a value of 0 implies that addresses generated from the prefix should // no longer remain preferred, and a value of infinity is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. // // Also note that the value of this field MUST NOT exceed the Valid Lifetime // field to avoid preferring addresses that are no longer valid, for the @@ -476,3 +511,79 @@ func (o NDPPrefixInformation) Subnet() tcpip.Subnet { } return addrWithPrefix.Subnet() } + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() uint8 { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, some of the addresses returned MAY be link-local addresses. +// +// Addresses may panic if o does not hold valid IPv6 addresses. +func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { + l := len(o) + if l < ndpRecursiveDNSServerAddressesOffset { + return nil + } + + l -= ndpRecursiveDNSServerAddressesOffset + if l%IPv6AddressSize != 0 { + return nil + } + + buf := o[ndpRecursiveDNSServerAddressesOffset:] + var addrs []tcpip.Address + for len(buf) > 0 { + addr := tcpip.Address(buf[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return nil + } + addrs = append(addrs, addr) + buf = buf[IPv6AddressSize:] + } + return addrs +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index ad6daafcd..2c439d70c 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -369,6 +370,175 @@ func TestNDPPrefixInformationOption(t *testing.T) { } } +func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 25, 3, 0, 0, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPRecursiveDNSServer(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Type(); got != 25 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16909320*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + want := []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + } + if got := opt.Addresses(); !cmp.Equal(got, want) { + t.Errorf("got Addresses = %v, want = %v", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + addrs []tcpip.Address + }{ + { + "Valid1Addr", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + }, + }, + { + "Valid2Addr", + []byte{ + 25, 5, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + }, + }, + { + "Valid3Addr", + []byte{ + 25, 7, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Iterator should get our option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { + t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { @@ -473,6 +643,51 @@ func TestNDPOptionsIterCheck(t *testing.T) { }, nil, }, + { + "InvalidRecursiveDNSServerCutsOffAddress", + []byte{ + 25, 4, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7, + }, + ErrNDPOptMalformedBody, + }, + { + "InvalidRecursiveDNSServerInvalidLengthField", + []byte{ + 25, 2, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, + }, + ErrNDPInvalidLength, + }, + { + "RecursiveDNSServerTooSmall", + []byte{ + 25, 1, 0, 0, + 0, 0, 0, + }, + ErrNDPOptBufExhausted, + }, + { + "RecursiveDNSServerMulticast", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + ErrNDPOptMalformedBody, + }, + { + "RecursiveDNSServerUnspecified", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + ErrNDPOptMalformedBody, + }, } for _, test := range tests { diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 4839f0a65..e156b01f6 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 30cea8996..6c5e19e8f 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -41,6 +41,30 @@ type portDescriptor struct { port uint16 } +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool +} + +func (f Flags) bits() reuseFlag { + var rf reuseFlag + if f.MostRecent { + rf |= mostRecentFlag + } + if f.LoadBalanced { + rf |= loadBalancedFlag + } + return rf +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -54,9 +78,59 @@ type PortManager struct { hint uint32 } +type reuseFlag int + +const ( + mostRecentFlag reuseFlag = 1 << iota + loadBalancedFlag + nextFlag + + flagMask = nextFlag - 1 +) + type portNode struct { - reuse bool - refs int + // refs stores the count for each possible flag combination. + refs [nextFlag]int +} + +func (p portNode) totalRefs() int { + var total int + for _, r := range p.refs { + total += r + } + return total +} + +// flagRefs returns the number of references with all specified flags. +func (p portNode) flagRefs(flags reuseFlag) int { + var total int + for i, r := range p.refs { + if reuseFlag(i)&flags == flags { + total += r + } + } + return total +} + +// allRefsHave returns if all references have all specified flags. +func (p portNode) allRefsHave(flags reuseFlag) bool { + for i, r := range p.refs { + if reuseFlag(i)&flags == flags && r > 0 { + return false + } + } + return true +} + +// intersectionRefs returns the set of flags shared by all references. +func (p portNode) intersectionRefs() reuseFlag { + intersection := flagMask + for i, r := range p.refs { + if r > 0 { + intersection &= reuseFlag(i) + } + } + return intersection } // deviceNode is never empty. When it has no elements, it is removed from the @@ -66,30 +140,44 @@ type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all portNodes. If binding to a specific device, check // against the unspecified device and the provided device. -func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { +// +// If either of the port reuse flags is enabled on any of the nodes, all nodes +// sharing a port must share at least one reuse flag. This matches Linux's +// behavior. +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { + flagBits := flags.bits() if bindToDevice == 0 { // Trying to binding all devices. - if !reuse { + if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } + intersection := flagMask for _, p := range d { - if !p.reuse { - // Can't bind because the (addr,port) was previously bound without reuse. + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { + // Can't bind because the (addr,port) was + // previously bound without reuse. return false } } return true } + intersection := flagMask + if p, ok := d[0]; ok { - if !reuse || !p.reuse { + intersection = p.intersectionRefs() + if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - if !reuse || !p.reuse { + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { return false } } @@ -103,12 +191,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -117,14 +205,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -190,17 +278,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice) { return false } } @@ -212,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -227,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } + flagBits := flags.bits() // Reserve port on all network protocols. for _, network := range networks { @@ -250,12 +339,9 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - if n, ok := d[bindToDevice]; ok { - n.refs++ - d[bindToDevice] = n - } else { - d[bindToDevice] = portNode{reuse: reuse, refs: 1} - } + n := d[bindToDevice] + n.refs[flagBits]++ + d[bindToDevice] = n } return true @@ -263,10 +349,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() + flagBits := flags.bits() + for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -278,9 +366,9 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n.refs-- + n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == 0 { + if n.refs == [nextFlag]int{} { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 19f4833fc..d6969d050 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -33,7 +33,7 @@ type portReserveTestAction struct { port uint16 ip tcpip.Address want *tcpip.Error - reuse bool + flags Flags release bool device tcpip.NICID } @@ -50,7 +50,7 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, }, }, { @@ -61,7 +61,7 @@ func TestPortReservation(t *testing.T) { /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -71,36 +71,36 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 00, ip: fakeIPAddress, want: nil}, {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to ip with reuseport", actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to inaddr any with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, }, }, { tname: "bind twice with device fails", @@ -125,88 +125,152 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { - tname: "bind with reuse", + tname: "bind with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { - tname: "binding with reuse and device", + tname: "binding with reuseport and device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, }, }, { - tname: "mixing reuse and not reuse by binding to device", + tname: "mixing reuseport and not reuseport by binding to device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: nil}, }, }, { - tname: "can't bind to 0 after mixing reuse and not reuse", + tname: "can't bind to 0 after mixing reuseport and not reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, }, }, { - tname: "bind twice with reuse once", + tname: "bind twice with reuseport once", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "release an unreserved device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, // Release all. - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, + }, + }, { + tname: "bind with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind twice with reuseaddr once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, }, } { @@ -216,12 +280,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index cfdd0496e..060a2e7c6 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -58,6 +58,14 @@ const ( // Default = true. defaultDiscoverOnLinkPrefixes = true + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -87,6 +95,24 @@ const ( // // Max = 10. MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 +) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) // NDPDispatcher is the interface integrators of netstack must implement to @@ -139,6 +165,33 @@ type NDPDispatcher interface { // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) } // NDPConfigurations is the NDP configurations for the netstack. @@ -168,6 +221,17 @@ type NDPConfigurations struct { // will be discovered from Router Advertisements' Prefix Information // option. This configuration is ignored if HandleRAs is false. DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -179,6 +243,7 @@ func DefaultNDPConfigurations() NDPConfigurations { HandleRAs: defaultHandleRAs, DiscoverDefaultRouters: defaultDiscoverDefaultRouters, DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -210,6 +275,9 @@ type ndpState struct { // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The addresses generated by SLAAC. + autoGenAddresses map[tcpip.Address]autoGenAddressState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -270,6 +338,32 @@ type onLinkPrefixState struct { doNotInvalidate *bool } +// autoGenAddressState holds data associated with an address generated via +// SLAAC. +type autoGenAddressState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the SLAAC address (A) in + // a race condition (T1 is a goroutine that handles a PI for A and T2 + // is the goroutine that handles A's invalidation timer firing): + // T1: Receive a new PI for A + // T1: Obtain the NIC's lock before processing the PI + // T2: A's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends A's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates A immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate A, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool + + // Nonzero only when the address is not valid forever (invalidationTimer + // is not nil). + validUntil time.Time +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently @@ -534,19 +628,21 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // we do not check the iterator for errors on calls to Next. it, _ := ra.Options().Iter(false) for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { - switch opt.Type() { - case header.NDPPrefixInformationType: - if !ndp.configs.DiscoverOnLinkPrefixes { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { continue } - pi := opt.(header.NDPPrefixInformation) + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) - prefix := pi.Subnet() + case header.NDPPrefixInformation: + prefix := opt.Subnet() // Is the prefix a link-local? if header.IsV6LinkLocalAddress(prefix.ID()) { - // ...Yes, skip as per RFC 4861 section 6.3.4. + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). continue } @@ -557,82 +653,13 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - if !pi.OnLinkFlag() { - // Not on-link so don't "discover" it as an - // on-link prefix. - continue - } - - prefixState, ok := ndp.onLinkPrefixes[prefix] - vl := pi.ValidLifetime() - switch { - case !ok && vl == 0: - // Don't know about this prefix but has a zero - // valid lifetime, so just ignore. - continue - - case !ok && vl != 0: - // This is a new on-link prefix we are - // discovering. - // - // Only remember it if we currently know about - // less than MaxDiscoveredOnLinkPrefixes on-link - // prefixes. - if len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { - ndp.rememberOnLinkPrefix(prefix, vl) - } - continue - - case ok && vl == 0: - // We know about the on-link prefix, but it is - // no longer to be considered on-link, so - // invalidate it. - ndp.invalidateOnLinkPrefix(prefix) - continue - } - - // This is an already discovered on-link prefix with a - // new non-zero valid lifetime. - // Update the invalidation timer. - timer := prefixState.invalidationTimer - - if timer == nil && vl >= header.NDPPrefixInformationInfiniteLifetime { - // Had infinite valid lifetime before and - // continues to have an invalid lifetime. Do - // nothing further. - continue - } - - if timer != nil && !timer.Stop() { - // If we reach this point, then we know the - // timer already fired after we took the NIC - // lock. Inform the timer to not invalidate - // the prefix once it obtains the lock as we - // just got a new PI that refeshes its lifetime - // to a non-zero value. See - // onLinkPrefixState.doNotInvalidate for more - // details. - *prefixState.doNotInvalidate = true + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) } - if vl >= header.NDPPrefixInformationInfiniteLifetime { - // Prefix is now valid forever so we don't need - // an invalidation timer. - prefixState.invalidationTimer = nil - ndp.onLinkPrefixes[prefix] = prefixState - continue + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) } - - if timer != nil { - // We already have a timer so just reset it to - // expire after the new valid lifetime. - timer.Reset(vl) - continue - } - - // We do not have a timer so just create a new one. - prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) - ndp.onLinkPrefixes[prefix] = prefixState } // TODO(b/141556115): Do (MTU) Parameter Discovery. @@ -734,7 +761,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) var timer *time.Timer // Only create a timer if the lifetime is not infinite. - if l < header.NDPPrefixInformationInfiniteLifetime { + if l < header.NDPInfiniteLifetime { timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) } @@ -795,3 +822,345 @@ func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Dur ndp.invalidateOnLinkPrefix(prefix) }) } + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // Update the invalidation timer. + timer := prefixState.invalidationTimer + + if timer == nil && vl >= header.NDPInfiniteLifetime { + // Had infinite valid lifetime before and + // continues to have an invalid lifetime. Do + // nothing further. + return + } + + if timer != nil && !timer.Stop() { + // If we reach this point, then we know the timer alread fired + // after we took the NIC lock. Inform the timer to not + // invalidate the prefix once it obtains the lock as we just + // got a new PI that refreshes its lifetime to a non-zero value. + // See onLinkPrefixState.doNotInvalidate for more details. + *prefixState.doNotInvalidate = true + } + + if vl >= header.NDPInfiniteLifetime { + // Prefix is now valid forever so we don't need + // an invalidation timer. + prefixState.invalidationTimer = nil + ndp.onLinkPrefixes[prefix] = prefixState + return + } + + if timer != nil { + // We already have a timer so just reset it to + // expire after the new valid lifetime. + timer.Reset(vl) + return + } + + // We do not have a timer so just create a new one. + prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already have an auto-generated address for prefix. + for _, ref := range ndp.nic.endpoints { + if ref.protocol != header.IPv6ProtocolNumber { + continue + } + + if ref.configType != slaac { + continue + } + + addr := ref.ep.ID().LocalAddress + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + if refAddrWithPrefix.Subnet() != prefix { + continue + } + + // + // At this point, we know we are refreshing a SLAAC generated + // IPv6 address with the prefix, prefix. Do the work as outlined + // by RFC 4862 section 5.5.3.e. + // + + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) + } + + // TODO(b/143713887): Handle deprecating auto-generated address + // after the preferred lifetime. + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the + // address generated by SLAAC is as follows: + // + // 1) If the received Valid Lifetime is greater than 2 hours or + // greater than RemainingLifetime, set the valid lifetime of + // the address to the advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, + // ignore the advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 + // hours. + + // Handle the infinite valid lifetime separately as we do not + // keep a timer in this case. + if vl >= header.NDPInfiniteLifetime { + if addrState.invalidationTimer != nil { + // Valid lifetime was finite before, but now it + // is valid forever. + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer = nil + addrState.validUntil = time.Time{} + ndp.autoGenAddresses[addr] = addrState + } + + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, + // assume the remaining time to be the maximum possible value. + if addrState.invalidationTimer == nil { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + ndp.autoGenAddresses[addr] = addrState + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if addrState.invalidationTimer == nil { + addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) + } else { + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer.Reset(effectiveVl) + } + + addrState.validUntil = time.Now().Add(effectiveVl) + ndp.autoGenAddresses[addr] = addrState + return + } + + // We do not already have an address within the prefix, prefix. Do the + // work as outlined by RFC 4862 section 5.5.3.d if n is configured + // to auto-generated global addresses by SLAAC. + + // Are we configured to auto-generate new global addresses? + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } + + // Generate an address within prefix from the EUI-64 of ndp's NIC's + // Ethernet MAC address. + addrBytes := make([]byte, header.IPv6AddressSize) + copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) + header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.Address(addrBytes) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + } + + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr) { + return + } + + // Inform the integrator that we have a new SLAAC address. + if ndp.nic.stack.ndpDisp == nil { + return + } + if !ndp.nic.stack.ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + // Informed by the integrator not to add the address. + return + } + + if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + }, FirstPrimaryEndpoint, permanent, slaac); err != nil { + panic(err) + } + + // Setup the timers to deprecate and invalidate this newly generated + // address. + + // TODO(b/143713887): Handle deprecating auto-generated addresses + // after the preferred lifetime. + + var doNotInvalidate bool + var vTimer *time.Timer + if vl < header.NDPInfiniteLifetime { + vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + } + + ndp.autoGenAddresses[addr] = autoGenAddressState{ + invalidationTimer: vTimer, + doNotInvalidate: &doNotInvalidate, + validUntil: time.Now().Add(vl), + } +} + +// invalidateAutoGenAddress invalidates an auto-generated address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { + if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { + return + } + + ndp.nic.removePermanentAddressLocked(addr) +} + +// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated +// address's resources from ndp. If the stack has an NDP dispatcher, it will +// be notified that addr has been invalidated. +// +// Returns true if ndp had resources for addr to cleanup. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { + state, ok := ndp.autoGenAddresses[addr] + + if !ok { + return false + } + + if state.invalidationTimer != nil { + state.invalidationTimer.Stop() + state.invalidationTimer = nil + *state.doNotInvalidate = true + } + + state.doNotInvalidate = nil + + delete(ndp.autoGenAddresses, addr) + + if ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } + + return true +} + +// autoGenAddrInvalidationTimer returns a new invalidation timer for an +// auto-generated address that fires after vl. +// +// doNotInvalidate is used to inform the timer when it fires at the same time +// that an auto-generated address's valid lifetime gets refreshed. See +// autoGenAddrState.doNotInvalidate for more details. +func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateAutoGenAddress(addr) + }) +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5b901f947..8d811eb8e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -38,7 +38,7 @@ const ( linkAddr1 = "\x02\x02\x03\x04\x05\x06" linkAddr2 = "\x02\x02\x03\x04\x05\x07" linkAddr3 = "\x02\x02\x03\x04\x05\x08" - defaultTimeout = 250 * time.Millisecond + defaultTimeout = 100 * time.Millisecond ) var ( @@ -47,6 +47,31 @@ var ( llAddr3 = header.LinkLocalAddr(linkAddr3) ) +// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent +// tcpip.Subnet, and an address where the lower half of the address is composed +// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. +func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { + prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixBytes), + PrefixLen: 64, + } + + subnet := prefix.Subnet() + + var addr tcpip.AddressWithPrefix + if header.IsValidUnicastEthernetAddress(linkAddr) { + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + } + + return prefix, subnet, addr +} + // TestDADDisabled tests that an address successfully resolves immediately // when DAD is not enabled (the default for an empty stack.Options). func TestDADDisabled(t *testing.T) { @@ -103,6 +128,29 @@ type ndpPrefixEvent struct { discovered bool } +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP @@ -113,6 +161,8 @@ type ndpDispatcher struct { rememberRouter bool prefixC chan ndpPrefixEvent rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent routeTable []tcpip.Route } @@ -211,7 +261,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } - rt := make([]tcpip.Route, 0) + var rt []tcpip.Route exclude := tcpip.Route{ Destination: prefix, NIC: nicID, @@ -226,6 +276,40 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi return rt } +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if n.autoGenAddrC != nil { + n.autoGenAddrC <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } + } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if n.autoGenAddrC != nil { + n.autoGenAddrC <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if n.rdnssC != nil { + n.rdnssC <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -247,6 +331,8 @@ func TestDADResolve(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -781,16 +867,33 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink bool, vl uint32) tcpip.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { flags := uint8(0) if onLink { - flags |= 128 + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 } + // A valid header.NDPPrefixInformation must be 30 bytes. buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. copy(buf[14:], prefix.Address) return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ header.NDPPrefixInformation(buf[:]), @@ -800,6 +903,8 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { + t.Parallel() + // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -812,6 +917,8 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), } @@ -844,6 +951,8 @@ func TestNoRouterDiscovery(t *testing.T) { // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), } @@ -909,6 +1018,8 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), rememberRouter: true, @@ -1040,6 +1151,8 @@ func TestRouterDiscovery(t *testing.T) { // TestRouterDiscoveryMaxRouters tests that only // stack.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), rememberRouter: true, @@ -1104,6 +1217,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { + t.Parallel() + prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1121,6 +1236,8 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 10), } @@ -1140,7 +1257,7 @@ func TestNoPrefixDiscovery(t *testing.T) { } // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 10)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) select { case <-ndpDisp.prefixC: @@ -1154,11 +1271,9 @@ func TestNoPrefixDiscovery(t *testing.T) { // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - subnet := prefix.Subnet() + t.Parallel() + + prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 10), @@ -1189,7 +1304,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { // Rx an RA with prefix with a short lifetime. const lifetime = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, lifetime)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0)) select { case r := <-ndpDisp.prefixC: if r.nicID != 1 { @@ -1226,21 +1341,11 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } func TestPrefixDiscovery(t *testing.T) { - prefix1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - prefix2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - prefix3 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x0a\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 72, - } - subnet1 := prefix1.Subnet() - subnet2 := prefix2.Subnet() - subnet3 := prefix3.Subnet() + t.Parallel() + + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 10), @@ -1281,7 +1386,7 @@ func TestPrefixDiscovery(t *testing.T) { // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix with 0 lifetime") @@ -1290,7 +1395,7 @@ func TestPrefixDiscovery(t *testing.T) { // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) waitForEvent(subnet1, true, defaultTimeout) // Should have added a device route for subnet1 through the nic. @@ -1299,7 +1404,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) waitForEvent(subnet2, true, defaultTimeout) // Should have added a device route for subnet2 through the nic. @@ -1308,7 +1413,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) waitForEvent(subnet3, true, defaultTimeout) // Should have added a device route for subnet3 through the nic. @@ -1317,7 +1422,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) waitForEvent(subnet1, false, defaultTimeout) // Should have removed the device route for subnet1 through the nic. @@ -1327,7 +1432,7 @@ func TestPrefixDiscovery(t *testing.T) { // Receive an RA with prefix2 in a PI with lesser lifetime. lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, lifetime)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly received prefix event when updating lifetime") @@ -1349,7 +1454,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) waitForEvent(subnet3, false, defaultTimeout) // Should not have any routes. @@ -1364,10 +1469,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // invalidate the prefix. const testInfiniteLifetimeSeconds = 2 const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPPrefixInformationInfiniteLifetime - header.NDPPrefixInformationInfiniteLifetime = testInfiniteLifetime + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime defer func() { - header.NDPPrefixInformationInfiniteLifetime = saved + header.NDPInfiniteLifetime = saved }() prefix := tcpip.AddressWithPrefix{ @@ -1415,7 +1520,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) waitForEvent(true, defaultTimeout) select { case <-ndpDisp.prefixC: @@ -1425,16 +1530,16 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) waitForEvent(false, testInfiniteLifetime) // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) waitForEvent(true, defaultTimeout) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1443,7 +1548,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with a prefix with a lifetime value greater than the // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds+1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1452,13 +1557,15 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with 0 lifetime. // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) waitForEvent(false, defaultTimeout) } // TestPrefixDiscoveryMaxRouters tests that only // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, @@ -1537,3 +1644,606 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) } } + +// Checks to see if list contains an IPv6 address, item. +func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + for _, i := range list { + if i == protocolAddress { + return true + } + } + + return false +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + t.Parallel() + + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + case <-time.After(defaultTimeout): + } + }) + } +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. +func TestAutoGenAddr(t *testing.T) { + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != eventType { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType) + } + case <-time.After(timeout): + t.Fatal("timeout waiting for addr auto gen event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + case <-time.After(defaultTimeout): + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + waitForEvent(addr1, newAddr, defaultTimeout) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + case <-time.After(defaultTimeout): + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + waitForEvent(addr2, newAddr, defaultTimeout) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + case <-time.After(defaultTimeout): + } + + // Wait for addr of prefix1 to be invalidated. + waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout) + if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } +} + +// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an +// auto-generated address only gets updated when required to, as specified in +// RFC 4862 section 5.5.3.e. +func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { + const infiniteVL = 4294967295 + const newMinVL = 5 + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + ovl uint32 + nvl uint32 + evl uint32 + }{ + // Should update the VL to the minimum VL for updating if the + // new VL is less than newMinVL but was originally greater than + // it. + { + "LargeVLToVLLessThanMinVLForUpdate", + 9999, + 1, + newMinVL, + }, + { + "LargeVLTo0", + 9999, + 0, + newMinVL, + }, + { + "InfiniteVLToVLLessThanMinVLForUpdate", + infiniteVL, + 1, + newMinVL, + }, + { + "InfiniteVLTo0", + infiniteVL, + 0, + newMinVL, + }, + + // Should not update VL if original VL was less than newMinVL + // and the new VL is also less than newMinVL. + { + "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", + newMinVL - 1, + newMinVL - 3, + newMinVL - 1, + }, + + // Should take the new VL if the new VL is greater than the + // remaining time or is greater than newMinVL. + { + "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", + newMinVL + 5, + newMinVL + 3, + newMinVL + 3, + }, + { + "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", + newMinVL - 3, + newMinVL - 1, + newMinVL - 1, + }, + { + "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", + newMinVL - 3, + newMinVL + 1, + newMinVL + 1, + }, + } + + const delta = 500 * time.Millisecond + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with initial VL, test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != newAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + + // Receive an new RA with prefix with new VL, test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + + // + // Validate that the VL for the address got set to + // test.evl. + // + + // Make sure we do not get any invalidation events + // until atleast 500ms (delta) before test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): + } + + // Wait for another second (2x delta), but now we expect + // the invalidation event. + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != invalidatedAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } +} + +// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed +// by the user, its resources will be cleaned up and an invalidation event will +// be sent to the integrator. +func TestAutoGenAddrRemoval(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with its valid lifetime = lifetime. + const lifetime = 5 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != newAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + + // Remove the address. + if err := s.RemoveAddress(1, addr.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) + } + + // Should get the invalidation event immediately. + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != invalidatedAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + + // Wait for the original valid lifetime to make sure the original timer + // got stopped/cleaned up. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(lifetime*time.Second + defaultTimeout): + } +} + +// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that +// is already assigned to the NIC, the static address remains. +func TestAutoGenAddrStaticConflict(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Add the address as a static address before SLAAC tries to add it. + if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive a PI where the generated address will be the same as the one + // that we already have assigned statically. + const lifetime = 5 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") + case <-time.After(defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Should not get an invalidation event after the PI's invalidation + // time. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetime*time.Second + defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } +} + +// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event +// to the integrator when an RA is received with the NDP Recursive DNS Server +// option with at least one valid address. +func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. + select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3f8d7312c..e8401c673 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -115,10 +115,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), }, } nic.ndp.nic = nic @@ -244,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative +} + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) } @@ -335,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) + }, peb, temporary, static) n.mu.Unlock() return ref @@ -384,10 +399,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent) + return n.addAddressLocked(protocolAddress, peb, permanent, static) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -417,11 +432,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, } // Set up cache if link address resolution exists for this protocol. @@ -624,9 +640,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) - // If we are removing a tentative IPv6 unicast address, stop DAD. - if isIPv6Unicast && kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + if isIPv6Unicast { + // If we are removing a tentative IPv6 unicast address, stop + // DAD. + if kind == permanentTentative { + n.ndp.stopDuplicateAddressDetection(addr) + } + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + if r.configType == slaac { + n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + } } r.setKind(permanentExpired) @@ -989,7 +1014,7 @@ const ( // removing the permanent address from the NIC. permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. If its address is re-added before the endpoint is removed, its type @@ -1035,6 +1060,19 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A slaac configured endpoint is an IPv6 endpoint that was + // added by SLAAC as per RFC 4862 section 5.5.3. + slaac +) + type referencedNetworkEndpoint struct { ep NetworkEndpoint nic *NIC @@ -1050,6 +1088,10 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change after the endpoint is added to a NIC. + configType networkEndpointConfigType } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 5746043cc..d5bb5b6ed 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -924,6 +924,14 @@ type TCPStats struct { // ESTABLISHED state or the CLOSE-WAIT state. EstablishedResets *StatCounter + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index dd1728f9c..455a1c098 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -52,6 +52,7 @@ go_library( "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f543a6105..74df3edfb 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -298,8 +298,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished ep.isConnectNotified = true ep.mu.Unlock() @@ -546,6 +544,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4206db8b6..3d059c302 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -252,6 +252,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -352,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -880,6 +889,30 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + // transitionToStateCloseLocked ensures that the endpoint is // cleaned up from the transport demuxer, "before" moving to // StateClose. This will ensure that no packet will be @@ -891,6 +924,7 @@ func (e *endpoint) transitionToStateCloseLocked() { } e.cleanupLocked() e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() } // tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed @@ -953,20 +987,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - return nil - } - s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false @@ -1024,6 +1044,24 @@ func (e *endpoint) handleSegments() *tcpip.Error { s.decRef() continue } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return nil + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -1057,6 +1095,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -1142,8 +1181,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError e.HardError = err @@ -1152,25 +1189,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.mu.Lock() - e.state = StateEstablished - e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -1371,7 +1389,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() e.transitionToStateCloseLocked() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9d4a87e30..4861ab513 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -343,6 +344,7 @@ type endpoint struct { // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -737,9 +739,10 @@ func (e *endpoint) Close() { e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } // Mark endpoint as closed. @@ -800,10 +803,11 @@ func (e *endpoint) cleanupLocked() { } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1775,7 +1779,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // reusePort is false below because connect cannot reuse a port even if // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { return false, nil } @@ -1802,7 +1806,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } @@ -2034,28 +2038,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) if err != nil { return err } e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = flags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.boundBindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 857dc445f..5ee499c36 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -205,7 +205,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() switch r.ep.state { case StateFinWait1: diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d3f7c9125..8332a0179 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -674,7 +674,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } - s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 50829ae27..bc5cfcf0e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) { if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestConnectIncrementActiveConnection(t *testing.T) { @@ -541,13 +555,21 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ep.(interface{ ResumeWork() }).ResumeWork() // Wait for the protocolMainLoop to resume and update state. - time.Sleep(1 * time.Millisecond) + time.Sleep(10 * time.Millisecond) // Expect the endpoint to be closed. if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + // Check if the endpoint was moved to CLOSED and netstack a reset in // response to the ACK packet that we sent after last-ACK. checker.IPv4(t, c.GetPacket(), @@ -2694,6 +2716,13 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -4363,9 +4392,17 @@ func TestKeepalive(t *testing.T) { ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -5632,6 +5669,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5750,6 +5788,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5856,6 +5895,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5929,6 +5969,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) @@ -5941,6 +5982,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5987,6 +6029,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) } + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -6007,6 +6051,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -6114,4 +6159,184 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SeqNum(uint32(ackHeaders.AckNum)), checker.AckNum(uint32(ackHeaders.SeqNum)), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } +} + +func TestTCPCloseWithData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + + c.EP.Close() + // Check the FIN. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.AckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Check the RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(uint32(ackHeaders.SeqNum)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6cb66c1af..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -231,6 +231,7 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -259,6 +260,7 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -483,6 +485,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv6.ProtocolNumber { diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 8d4c3808f..97e4d5825 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 24cb88c13..1ac4705af 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -107,6 +108,7 @@ type endpoint struct { // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags // sendTOS represents IPv4 TOS or IPv6 TrafficClass, // applied while sending packets. Defaults to 0 as on Linux. @@ -180,8 +182,9 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } for _, mem := range e.multicastMemberships { @@ -895,7 +898,8 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundPortFlags = ports.Flags{} } e.state = StateInitial } @@ -1042,16 +1046,23 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + // FIXME(b/129164367): Support SO_REUSEADDR. + MostRecent: false, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) if err != nil { return id, e.bindToDevice, err } + e.boundPortFlags = flags id.LocalPort = port } err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.boundPortFlags = ports.Flags{} } return id, e.bindToDevice, err } @@ -1134,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + return tcpip.FullAddress{ NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, + Addr: addr, Port: e.ID.LocalPort, }, nil } diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 847d2f91c..6226b63f8 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "compat.go", "compat_amd64.go", + "compat_arm64.go", "config.go", "controller.go", "debug.go", @@ -110,7 +111,6 @@ go_test( "//pkg/control/server", "//pkg/log", "//pkg/p9", - "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 07e35ab10..352e710d2 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -21,10 +21,8 @@ import ( "syscall" "github.com/golang/protobuf/proto" - "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/arch" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/strace" @@ -53,9 +51,9 @@ type compatEmitter struct { } func newCompatEmitter(logFD int) (*compatEmitter, error) { - nameMap, ok := strace.Lookup(abi.Linux, arch.AMD64) + nameMap, ok := getSyscallNameMap() if !ok { - return nil, fmt.Errorf("amd64 Linux syscall table not found") + return nil, fmt.Errorf("Linux syscall table not found") } c := &compatEmitter{ @@ -86,16 +84,16 @@ func (c *compatEmitter) Emit(msg proto.Message) (bool, error) { } func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) { - regs := us.Registers.GetArch().(*rpb.Registers_Amd64).Amd64 + regs := us.Registers c.mu.Lock() defer c.mu.Unlock() - sysnr := regs.OrigRax + sysnr := syscallNum(regs) tr := c.trackers[sysnr] if tr == nil { switch sysnr { - case syscall.SYS_PRCTL, syscall.SYS_ARCH_PRCTL: + case syscall.SYS_PRCTL: // args: cmd, ... tr = newArgsTracker(0) @@ -112,10 +110,14 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) { tr = newArgsTracker(2) default: - tr = &onceTracker{} + tr = newArchArgsTracker(sysnr) + if tr == nil { + tr = &onceTracker{} + } } c.trackers[sysnr] = tr } + if tr.shouldReport(regs) { c.sink.Infof("Unsupported syscall: %s, regs: %+v", c.nameMap.Name(uintptr(sysnr)), regs) tr.onReported(regs) @@ -139,10 +141,10 @@ func (c *compatEmitter) Close() error { // the syscall and arguments. type syscallTracker interface { // shouldReport returns true is the syscall should be reported. - shouldReport(regs *rpb.AMD64Registers) bool + shouldReport(regs *rpb.Registers) bool // onReported marks the syscall as reported. - onReported(regs *rpb.AMD64Registers) + onReported(regs *rpb.Registers) } // onceTracker reports only a single time, used for most syscalls. @@ -150,10 +152,45 @@ type onceTracker struct { reported bool } -func (o *onceTracker) shouldReport(_ *rpb.AMD64Registers) bool { +func (o *onceTracker) shouldReport(_ *rpb.Registers) bool { return !o.reported } -func (o *onceTracker) onReported(_ *rpb.AMD64Registers) { +func (o *onceTracker) onReported(_ *rpb.Registers) { o.reported = true } + +// argsTracker reports only once for each different combination of arguments. +// It's used for generic syscalls like ioctl to report once per 'cmd'. +type argsTracker struct { + // argsIdx is the syscall arguments to use as unique ID. + argsIdx []int + reported map[string]struct{} + count int +} + +func newArgsTracker(argIdx ...int) *argsTracker { + return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})} +} + +// key returns the command based on the syscall argument index. +func (a *argsTracker) key(regs *rpb.Registers) string { + var rv string + for _, idx := range a.argsIdx { + rv += fmt.Sprintf("%d|", argVal(idx, regs)) + } + return rv +} + +func (a *argsTracker) shouldReport(regs *rpb.Registers) bool { + if a.count >= reportLimit { + return false + } + _, ok := a.reported[a.key(regs)] + return !ok +} + +func (a *argsTracker) onReported(regs *rpb.Registers) { + a.count++ + a.reported[a.key(regs)] = struct{}{} +} diff --git a/runsc/boot/compat_amd64.go b/runsc/boot/compat_amd64.go index 43cd0db94..42b0ca8b0 100644 --- a/runsc/boot/compat_amd64.go +++ b/runsc/boot/compat_amd64.go @@ -16,62 +16,81 @@ package boot import ( "fmt" + "syscall" + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/sentry/arch" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + "gvisor.dev/gvisor/pkg/sentry/strace" ) // reportLimit is the max number of events that should be reported per tracker. const reportLimit = 100 -// argsTracker reports only once for each different combination of arguments. -// It's used for generic syscalls like ioctl to report once per 'cmd'. -type argsTracker struct { - // argsIdx is the syscall arguments to use as unique ID. - argsIdx []int - reported map[string]struct{} - count int +// newRegs create a empty Registers instance. +func newRegs() *rpb.Registers { + return &rpb.Registers{ + Arch: &rpb.Registers_Amd64{ + Amd64: &rpb.AMD64Registers{}, + }, + } } -func newArgsTracker(argIdx ...int) *argsTracker { - return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})} -} +func argVal(argIdx int, regs *rpb.Registers) uint32 { + amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64 -// cmd returns the command based on the syscall argument index. -func (a *argsTracker) key(regs *rpb.AMD64Registers) string { - var rv string - for _, idx := range a.argsIdx { - rv += fmt.Sprintf("%d|", argVal(idx, regs)) + switch argIdx { + case 0: + return uint32(amd64Regs.Rdi) + case 1: + return uint32(amd64Regs.Rsi) + case 2: + return uint32(amd64Regs.Rdx) + case 3: + return uint32(amd64Regs.R10) + case 4: + return uint32(amd64Regs.R8) + case 5: + return uint32(amd64Regs.R9) } - return rv + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) } -func argVal(argIdx int, regs *rpb.AMD64Registers) uint32 { +func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) { + amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64 + switch argIdx { case 0: - return uint32(regs.Rdi) + amd64Regs.Rdi = argVal case 1: - return uint32(regs.Rsi) + amd64Regs.Rsi = argVal case 2: - return uint32(regs.Rdx) + amd64Regs.Rdx = argVal case 3: - return uint32(regs.R10) + amd64Regs.R10 = argVal case 4: - return uint32(regs.R8) + amd64Regs.R8 = argVal case 5: - return uint32(regs.R9) + amd64Regs.R9 = argVal + default: + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) } - panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) } -func (a *argsTracker) shouldReport(regs *rpb.AMD64Registers) bool { - if a.count >= reportLimit { - return false - } - _, ok := a.reported[a.key(regs)] - return !ok +func getSyscallNameMap() (strace.SyscallMap, bool) { + return strace.Lookup(abi.Linux, arch.AMD64) +} + +func syscallNum(regs *rpb.Registers) uint64 { + amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64 + return amd64Regs.OrigRax } -func (a *argsTracker) onReported(regs *rpb.AMD64Registers) { - a.count++ - a.reported[a.key(regs)] = struct{}{} +func newArchArgsTracker(sysnr uint64) syscallTracker { + switch sysnr { + case syscall.SYS_ARCH_PRCTL: + // args: cmd, ... + return newArgsTracker(0) + } + return nil } diff --git a/runsc/boot/compat_arm64.go b/runsc/boot/compat_arm64.go new file mode 100644 index 000000000..f784cd237 --- /dev/null +++ b/runsc/boot/compat_arm64.go @@ -0,0 +1,91 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package boot + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/sentry/arch" + rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + "gvisor.dev/gvisor/pkg/sentry/strace" +) + +// reportLimit is the max number of events that should be reported per tracker. +const reportLimit = 100 + +// newRegs create a empty Registers instance. +func newRegs() *rpb.Registers { + return &rpb.Registers{ + Arch: &rpb.Registers_Arm64{ + Arm64: &rpb.ARM64Registers{}, + }, + } +} + +func argVal(argIdx int, regs *rpb.Registers) uint32 { + arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64 + + switch argIdx { + case 0: + return uint32(arm64Regs.R0) + case 1: + return uint32(arm64Regs.R1) + case 2: + return uint32(arm64Regs.R2) + case 3: + return uint32(arm64Regs.R3) + case 4: + return uint32(arm64Regs.R4) + case 5: + return uint32(arm64Regs.R5) + } + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) +} + +func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) { + arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64 + + switch argIdx { + case 0: + arm64Regs.R0 = argVal + case 1: + arm64Regs.R1 = argVal + case 2: + arm64Regs.R2 = argVal + case 3: + arm64Regs.R3 = argVal + case 4: + arm64Regs.R4 = argVal + case 5: + arm64Regs.R5 = argVal + default: + panic(fmt.Sprintf("invalid syscall argument index %d", argIdx)) + } +} + +func getSyscallNameMap() (strace.SyscallMap, bool) { + return strace.Lookup(abi.Linux, arch.ARM64) +} + +func syscallNum(regs *rpb.Registers) uint64 { + arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64 + return arm64Regs.R8 +} + +func newArchArgsTracker(sysnr uint64) syscallTracker { + // currently, no arch specific syscalls need to be handled here. + return nil +} diff --git a/runsc/boot/compat_test.go b/runsc/boot/compat_test.go index 388298d8d..839c5303b 100644 --- a/runsc/boot/compat_test.go +++ b/runsc/boot/compat_test.go @@ -16,8 +16,6 @@ package boot import ( "testing" - - rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" ) func TestOnceTracker(t *testing.T) { @@ -35,31 +33,34 @@ func TestOnceTracker(t *testing.T) { func TestArgsTracker(t *testing.T) { for _, tc := range []struct { - name string - idx []int - rdi1 uint64 - rdi2 uint64 - rsi1 uint64 - rsi2 uint64 - want bool + name string + idx []int + arg1_1 uint64 + arg1_2 uint64 + arg2_1 uint64 + arg2_2 uint64 + want bool }{ - {name: "same rdi", idx: []int{0}, rdi1: 123, rdi2: 123, want: false}, - {name: "same rsi", idx: []int{1}, rsi1: 123, rsi2: 123, want: false}, - {name: "diff rdi", idx: []int{0}, rdi1: 123, rdi2: 321, want: true}, - {name: "diff rsi", idx: []int{1}, rsi1: 123, rsi2: 321, want: true}, - {name: "cmd is uint32", idx: []int{0}, rsi1: 0xdead00000123, rsi2: 0xbeef00000123, want: false}, - {name: "same 2 args", idx: []int{0, 1}, rsi1: 123, rdi1: 321, rsi2: 123, rdi2: 321, want: false}, - {name: "diff 2 args", idx: []int{0, 1}, rsi1: 123, rdi1: 321, rsi2: 789, rdi2: 987, want: true}, + {name: "same arg1", idx: []int{0}, arg1_1: 123, arg1_2: 123, want: false}, + {name: "same arg2", idx: []int{1}, arg2_1: 123, arg2_2: 123, want: false}, + {name: "diff arg1", idx: []int{0}, arg1_1: 123, arg1_2: 321, want: true}, + {name: "diff arg2", idx: []int{1}, arg2_1: 123, arg2_2: 321, want: true}, + {name: "cmd is uint32", idx: []int{0}, arg2_1: 0xdead00000123, arg2_2: 0xbeef00000123, want: false}, + {name: "same 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 123, arg1_2: 321, want: false}, + {name: "diff 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 789, arg1_2: 987, want: true}, } { t.Run(tc.name, func(t *testing.T) { c := newArgsTracker(tc.idx...) - regs := &rpb.AMD64Registers{Rdi: tc.rdi1, Rsi: tc.rsi1} + regs := newRegs() + setArgVal(0, tc.arg1_1, regs) + setArgVal(1, tc.arg2_1, regs) if !c.shouldReport(regs) { t.Error("first call to shouldReport, got: false, want: true") } c.onReported(regs) - regs.Rdi, regs.Rsi = tc.rdi2, tc.rsi2 + setArgVal(0, tc.arg1_2, regs) + setArgVal(1, tc.arg2_2, regs) if got := c.shouldReport(regs); tc.want != got { t.Errorf("second call to shouldReport, got: %t, want: %t", got, tc.want) } @@ -70,7 +71,9 @@ func TestArgsTracker(t *testing.T) { func TestArgsTrackerLimit(t *testing.T) { c := newArgsTracker(0, 1) for i := 0; i < reportLimit; i++ { - regs := &rpb.AMD64Registers{Rdi: 123, Rsi: uint64(i)} + regs := newRegs() + setArgVal(0, 123, regs) + setArgVal(1, uint64(i), regs) if !c.shouldReport(regs) { t.Error("shouldReport before limit was reached, got: false, want: true") } @@ -78,7 +81,9 @@ func TestArgsTrackerLimit(t *testing.T) { } // Should hit the count limit now. - regs := &rpb.AMD64Registers{Rdi: 123, Rsi: 123456} + regs := newRegs() + setArgVal(0, 123, regs) + setArgVal(1, 123456, regs) if c.shouldReport(regs) { t.Error("shouldReport after limit was reached, got: true, want: false") } diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index f62be4c59..9c9e94864 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -152,7 +152,9 @@ func newController(fd int, l *Loader) (*controller, error) { srv.Register(&debug{}) srv.Register(&control.Logging{}) if l.conf.ProfileEnable { - srv.Register(&control.Profile{}) + srv.Register(&control.Profile{ + Kernel: l.k, + }) } return &controller{ diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index bf690160c..4fb9adca6 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -313,11 +313,21 @@ func hostInetFilters() seccomp.SyscallRules { { seccomp.AllowAny{}, seccomp.AllowValue(syscall.SOL_IP), + seccomp.AllowValue(syscall.IP_TOS), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IP), seccomp.AllowValue(syscall.IP_RECVTOS), }, { seccomp.AllowAny{}, seccomp.AllowValue(syscall.SOL_IPV6), + seccomp.AllowValue(syscall.IPV6_TCLASS), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IPV6), seccomp.AllowValue(syscall.IPV6_RECVTCLASS), }, { @@ -426,6 +436,13 @@ func hostInetFilters() seccomp.SyscallRules { { seccomp.AllowAny{}, seccomp.AllowValue(syscall.SOL_IP), + seccomp.AllowValue(syscall.IP_TOS), + seccomp.AllowAny{}, + seccomp.AllowValue(4), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IP), seccomp.AllowValue(syscall.IP_RECVTOS), seccomp.AllowAny{}, seccomp.AllowValue(4), @@ -433,6 +450,13 @@ func hostInetFilters() seccomp.SyscallRules { { seccomp.AllowAny{}, seccomp.AllowValue(syscall.SOL_IPV6), + seccomp.AllowValue(syscall.IPV6_TCLASS), + seccomp.AllowAny{}, + seccomp.AllowValue(4), + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.SOL_IPV6), seccomp.AllowValue(syscall.IPV6_RECVTCLASS), seccomp.AllowAny{}, seccomp.AllowValue(4), diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index bc9ffaf81..421ccd255 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -16,7 +16,6 @@ package boot import ( "fmt" - "path" "path/filepath" "sort" "strconv" @@ -52,7 +51,7 @@ const ( rootDevice = "9pfs-/" // MountPrefix is the annotation prefix for mount hints. - MountPrefix = "gvisor.dev/spec/mount" + MountPrefix = "dev.gvisor.spec.mount." // Filesystems that runsc supports. bind = "bind" @@ -490,14 +489,15 @@ type podMountHints struct { func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { mnts := make(map[string]*mountHint) for k, v := range spec.Annotations { - // Look for 'gvisor.dev/spec/mount' annotations and parse them. + // Look for 'dev.gvisor.spec.mount' annotations and parse them. if strings.HasPrefix(k, MountPrefix) { - parts := strings.Split(k, "/") - if len(parts) != 5 { + // Remove the prefix and split the rest. + parts := strings.Split(k[len(MountPrefix):], ".") + if len(parts) != 2 { return nil, fmt.Errorf("invalid mount annotation: %s=%s", k, v) } - name := parts[3] - if len(name) == 0 || path.Clean(name) != name { + name := parts[0] + if len(name) == 0 { return nil, fmt.Errorf("invalid mount name: %s", name) } mnt := mnts[name] @@ -505,7 +505,7 @@ func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { mnt = &mountHint{name: name} mnts[name] = mnt } - if err := mnt.setField(parts[4], v); err != nil { + if err := mnt.setField(parts[1], v); err != nil { return nil, err } } @@ -575,6 +575,11 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin func (c *containerMounter) processHints(conf *Config) error { ctx := c.k.SupervisorContext() for _, hint := range c.hints.mounts { + // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a + // common gofer to mount all shared volumes. + if hint.mount.Type != tmpfs { + continue + } log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type) inode, err := c.mountSharedMaster(ctx, conf, hint) if err != nil { @@ -851,7 +856,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns return fmt.Errorf("mount %q error: %v", m.Destination, err) } - log.Infof("Mounted %q to %q type %s", m.Source, m.Destination, m.Type) + log.Infof("Mounted %q to %q type: %s, internal-options: %q", m.Source, m.Destination, m.Type, opts) return nil } diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index 0396a4cfb..912037075 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -15,7 +15,6 @@ package boot import ( - "path" "reflect" "strings" "testing" @@ -26,19 +25,19 @@ import ( func TestPodMountHintsHappy(t *testing.T) { spec := &specs.Spec{ Annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", - path.Join(MountPrefix, "mount2", "source"): "bar", - path.Join(MountPrefix, "mount2", "type"): "bind", - path.Join(MountPrefix, "mount2", "share"): "container", - path.Join(MountPrefix, "mount2", "options"): "rw,private", + MountPrefix + "mount2.source": "bar", + MountPrefix + "mount2.type": "bind", + MountPrefix + "mount2.share": "container", + MountPrefix + "mount2.options": "rw,private", }, } podHints, err := newPodMountHints(spec) if err != nil { - t.Errorf("newPodMountHints failed: %v", err) + t.Fatalf("newPodMountHints failed: %v", err) } // Check that fields were set correctly. @@ -86,95 +85,95 @@ func TestPodMountHintsErrors(t *testing.T) { { name: "too short", annotations: map[string]string{ - path.Join(MountPrefix, "mount1"): "foo", + MountPrefix + "mount1": "foo", }, error: "invalid mount annotation", }, { name: "no name", annotations: map[string]string{ - MountPrefix + "//source": "foo", + MountPrefix + ".source": "foo", }, error: "invalid mount name", }, { name: "missing source", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", }, error: "source field", }, { name: "missing type", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.share": "pod", }, error: "type field", }, { name: "missing share", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", }, error: "share field", }, { name: "invalid field name", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "invalid"): "foo", + MountPrefix + "mount1.invalid": "foo", }, error: "invalid mount annotation", }, { name: "invalid source", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", }, error: "source cannot be empty", }, { name: "invalid type", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "invalid-type", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "invalid-type", + MountPrefix + "mount1.share": "pod", }, error: "invalid type", }, { name: "invalid share", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "invalid-share", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "invalid-share", }, error: "invalid share", }, { name: "invalid options", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", - path.Join(MountPrefix, "mount1", "options"): "invalid-option", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", + MountPrefix + "mount1.options": "invalid-option", }, error: "unknown mount option", }, { name: "duplicate source", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): "foo", - path.Join(MountPrefix, "mount1", "type"): "tmpfs", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": "foo", + MountPrefix + "mount1.type": "tmpfs", + MountPrefix + "mount1.share": "pod", - path.Join(MountPrefix, "mount2", "source"): "foo", - path.Join(MountPrefix, "mount2", "type"): "bind", - path.Join(MountPrefix, "mount2", "share"): "container", + MountPrefix + "mount2.source": "foo", + MountPrefix + "mount2.type": "bind", + MountPrefix + "mount2.share": "container", }, error: "have the same mount source", }, @@ -202,36 +201,36 @@ func TestGetMountAccessType(t *testing.T) { { name: "container=exclusive", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): source, - path.Join(MountPrefix, "mount1", "type"): "bind", - path.Join(MountPrefix, "mount1", "share"): "container", + MountPrefix + "mount1.source": source, + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "container", }, want: FileAccessExclusive, }, { name: "pod=shared", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): source, - path.Join(MountPrefix, "mount1", "type"): "bind", - path.Join(MountPrefix, "mount1", "share"): "pod", + MountPrefix + "mount1.source": source, + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "pod", }, want: FileAccessShared, }, { name: "shared=shared", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): source, - path.Join(MountPrefix, "mount1", "type"): "bind", - path.Join(MountPrefix, "mount1", "share"): "shared", + MountPrefix + "mount1.source": source, + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "shared", }, want: FileAccessShared, }, { name: "default=shared", annotations: map[string]string{ - path.Join(MountPrefix, "mount1", "source"): source + "mismatch", - path.Join(MountPrefix, "mount1", "type"): "bind", - path.Join(MountPrefix, "mount1", "share"): "container", + MountPrefix + "mount1.source": source + "mismatch", + MountPrefix + "mount1.type": "bind", + MountPrefix + "mount1.share": "container", }, want: FileAccessShared, }, diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index df6052c88..bc1d0c1bb 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -93,10 +93,6 @@ type Loader struct { // spec is the base configuration for the root container. spec *specs.Spec - // startSignalForwarding enables forwarding of signals to the sandboxed - // container. It should be called after the init process is loaded. - startSignalForwarding func() func() - // stopSignalForwarding disables forwarding of signals to the sandboxed // container. It should be called when a sandbox is destroyed. stopSignalForwarding func() @@ -336,29 +332,6 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("ignore child stop signals failed: %v", err) } - // Handle signals by forwarding them to the root container process - // (except for panic signal, which should cause a panic). - l.startSignalForwarding = sighandling.PrepareHandler(func(sig linux.Signal) { - // Panic signal should cause a panic. - if args.Conf.PanicSignal != -1 && sig == linux.Signal(args.Conf.PanicSignal) { - panic("Signal-induced panic") - } - - // Otherwise forward to root container. - deliveryMode := DeliverToProcess - if args.Console { - // Since we are running with a console, we should - // forward the signal to the foreground process group - // so that job control signals like ^C can be handled - // properly. - deliveryMode = DeliverToForegroundProcessGroup - } - log.Infof("Received external signal %d, mode: %v", sig, deliveryMode) - if err := l.signal(args.ID, 0, int32(sig), deliveryMode); err != nil { - log.Warningf("error sending signal %v to container %q: %v", sig, args.ID, err) - } - }) - // Create the control server using the provided FD. // // This must be done *after* we have initialized the kernel since the @@ -566,8 +539,27 @@ func (l *Loader) run() error { ep.tty.InitForegroundProcessGroup(ep.tg.ProcessGroup()) } - // Start signal forwarding only after an init process is created. - l.stopSignalForwarding = l.startSignalForwarding() + // Handle signals by forwarding them to the root container process + // (except for panic signal, which should cause a panic). + l.stopSignalForwarding = sighandling.StartSignalForwarding(func(sig linux.Signal) { + // Panic signal should cause a panic. + if l.conf.PanicSignal != -1 && sig == linux.Signal(l.conf.PanicSignal) { + panic("Signal-induced panic") + } + + // Otherwise forward to root container. + deliveryMode := DeliverToProcess + if l.console { + // Since we are running with a console, we should forward the signal to + // the foreground process group so that job control signals like ^C can + // be handled properly. + deliveryMode = DeliverToForegroundProcessGroup + } + log.Infof("Received external signal %d, mode: %v", sig, deliveryMode) + if err := l.signal(l.sandboxID, 0, int32(sig), deliveryMode); err != nil { + log.Warningf("error sending signal %v to container %q: %v", sig, l.sandboxID, err) + } + }) log.Infof("Process should have started...") l.watchdog.Start() diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 7313e473f..f37415810 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -32,16 +32,17 @@ import ( // Debug implements subcommands.Command for the "debug" command. type Debug struct { - pid int - stacks bool - signal int - profileHeap string - profileCPU string - profileDelay int - trace string - strace string - logLevel string - logPackets string + pid int + stacks bool + signal int + profileHeap string + profileCPU string + trace string + strace string + logLevel string + logPackets string + duration time.Duration + ps bool } // Name implements subcommands.Command. @@ -65,12 +66,13 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log") f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.") f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.") - f.IntVar(&d.profileDelay, "profile-delay", 5, "amount of time to wait before stoping CPU profile") + f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles") f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.") f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox") f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all`) f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).") f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.") + f.BoolVar(&d.ps, "ps", false, "lists processes") } // Execute implements subcommands.Command.Execute. @@ -163,7 +165,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if err := c.Sandbox.StartCPUProfile(f); err != nil { return Errorf(err.Error()) } - log.Infof("CPU profile started for %d sec, writing to %q", d.profileDelay, d.profileCPU) + log.Infof("CPU profile started for %v, writing to %q", d.duration, d.profileCPU) } if d.trace != "" { delay = true @@ -181,8 +183,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if err := c.Sandbox.StartTrace(f); err != nil { return Errorf(err.Error()) } - log.Infof("Tracing started for %d sec, writing to %q", d.profileDelay, d.trace) - + log.Infof("Tracing started for %v, writing to %q", d.duration, d.trace) } if d.strace != "" || len(d.logLevel) != 0 || len(d.logPackets) != 0 { @@ -241,9 +242,20 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } log.Infof("Logging options changed") } + if d.ps { + pList, err := c.Processes() + if err != nil { + Fatalf("getting processes for container: %v", err) + } + o, err := control.ProcessListToJSON(pList) + if err != nil { + Fatalf("generating JSON: %v", err) + } + log.Infof(o) + } if delay { - time.Sleep(time.Duration(d.profileDelay) * time.Second) + time.Sleep(d.duration) } return subcommands.ExitSuccess diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 7d67c3a75..5ed131a7f 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -28,6 +28,7 @@ import ( "github.com/kr/pty" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/control" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/testutil" @@ -219,9 +220,9 @@ func TestJobControlSignalExec(t *testing.T) { // Make sure all the processes are running. expectedPL := []*control.Process{ // Root container process. - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, // Bash from exec process. - {PID: 2, Cmd: "bash"}, + {PID: 2, Cmd: "bash", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(c, expectedPL); err != nil { t.Error(err) @@ -231,7 +232,7 @@ func TestJobControlSignalExec(t *testing.T) { ptyMaster.Write([]byte("sleep 100\n")) // Wait for it to start. Sleep's PPID is bash's PID. - expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep"}) + expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}}) if err := waitForProcessList(c, expectedPL); err != nil { t.Error(err) } @@ -361,7 +362,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { // Wait for bash to start. expectedPL := []*control.Process{ - {PID: 1, Cmd: "bash"}, + {PID: 1, Cmd: "bash", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(c, expectedPL); err != nil { t.Fatal(err) @@ -371,7 +372,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { ptyMaster.Write([]byte("sleep 100\n")) // Wait for sleep to start. - expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep"}) + expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{2}}) if err := waitForProcessList(c, expectedPL); err != nil { t.Fatal(err) } diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 07eacaac0..2ced028f6 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" @@ -52,13 +53,14 @@ func waitForProcessList(cont *Container, want []*control.Process) error { err = fmt.Errorf("error getting process data from container: %v", err) return &backoff.PermanentError{Err: err} } - if !procListsEqual(got, want) { - return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want)) + if r, err := procListsEqual(got, want); !r { + return fmt.Errorf("container got process list: %s, want: %s: error: %v", + procListToString(got), procListToString(want), err) } return nil } // Gives plenty of time as tests can run slow under --race. - return testutil.Poll(cb, 30*time.Second) + return testutil.Poll(cb, 10*time.Second) } func waitForProcessCount(cont *Container, want int) error { @@ -91,22 +93,34 @@ func blockUntilWaitable(pid int) error { // procListsEqual is used to check whether 2 Process lists are equal for all // implemented fields. -func procListsEqual(got, want []*control.Process) bool { +func procListsEqual(got, want []*control.Process) (bool, error) { if len(got) != len(want) { - return false + return false, nil } for i := range got { pd1 := got[i] pd2 := want[i] - // Zero out unimplemented and timing dependant fields. + // Zero out timing dependant fields. pd1.Time = "" pd1.STime = "" pd1.C = 0 - if *pd1 != *pd2 { - return false + // Ignore TTY field too, since it's not relevant in the cases + // where we use this method. Tests that care about the TTY + // field should check for it themselves. + pd1.TTY = "" + pd1Json, err := control.ProcessListToJSON([]*control.Process{pd1}) + if err != nil { + return false, err + } + pd2Json, err := control.ProcessListToJSON([]*control.Process{pd2}) + if err != nil { + return false, err + } + if pd1Json != pd2Json { + return false, nil } } - return true + return true, nil } // getAndCheckProcLists is similar to waitForProcessList, but does not wait and retry the @@ -116,7 +130,11 @@ func getAndCheckProcLists(cont *Container, want []*control.Process) error { if err != nil { return fmt.Errorf("error getting process data from container: %v", err) } - if procListsEqual(got, want) { + equal, err := procListsEqual(got, want) + if err != nil { + return err + } + if equal { return nil } return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want)) @@ -288,11 +306,12 @@ func TestLifecycle(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, } // Create the container. @@ -590,18 +609,20 @@ func TestExec(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: uid, + PID: 2, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{2}, }, } @@ -1062,18 +1083,20 @@ func TestPauseResume(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "bash", + UID: uid, + PID: 2, + PPID: 0, + C: 0, + Cmd: "bash", + Threads: []kernel.ThreadID{2}, }, } @@ -1126,11 +1149,12 @@ func TestPauseResume(t *testing.T) { expectedPL2 := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, } @@ -1241,18 +1265,20 @@ func TestCapabilities(t *testing.T) { // expectedPL lists the expected process state of the container. expectedPL := []*control.Process{ { - UID: 0, - PID: 1, - PPID: 0, - C: 0, - Cmd: "sleep", + UID: 0, + PID: 1, + PPID: 0, + C: 0, + Cmd: "sleep", + Threads: []kernel.ThreadID{1}, }, { - UID: uid, - PID: 2, - PPID: 0, - C: 0, - Cmd: "exe", + UID: uid, + PID: 2, + PPID: 0, + C: 0, + Cmd: "exe", + Threads: []kernel.ThreadID{2}, }, } if err := waitForProcessList(cont, expectedPL[:1]); err != nil { @@ -2112,6 +2138,95 @@ func TestOverlayfsStaleRead(t *testing.T) { } } +// TestTTYField checks TTY field returned by container.Processes(). +func TestTTYField(t *testing.T) { + stop := testutil.StartReaper() + defer stop() + + testApp, err := testutil.FindFile("runsc/container/test_app/test_app") + if err != nil { + t.Fatal("error finding test_app:", err) + } + + testCases := []struct { + name string + useTTY bool + wantTTYField string + }{ + { + name: "no tty", + useTTY: false, + wantTTYField: "?", + }, + { + name: "tty used", + useTTY: true, + wantTTYField: "pts/0", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + conf := testutil.TestConfig() + + // We will run /bin/sleep, possibly with an open TTY. + cmd := []string{"/bin/sleep", "10000"} + if test.useTTY { + // Run inside the "pty-runner". + cmd = append([]string{testApp, "pty-runner"}, cmd...) + } + + spec := testutil.NewSpecWithArgs(cmd...) + rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) + if err != nil { + t.Fatalf("error setting up container: %v", err) + } + defer os.RemoveAll(rootDir) + defer os.RemoveAll(bundleDir) + + // Create and start the container. + args := Args{ + ID: testutil.UniqueContainerID(), + Spec: spec, + BundleDir: bundleDir, + } + c, err := New(conf, args) + if err != nil { + t.Fatalf("error creating container: %v", err) + } + defer c.Destroy() + if err := c.Start(conf); err != nil { + t.Fatalf("error starting container: %v", err) + } + + // Wait for sleep to be running, and check the TTY + // field. + var gotTTYField string + cb := func() error { + ps, err := c.Processes() + if err != nil { + err = fmt.Errorf("error getting process data from container: %v", err) + return &backoff.PermanentError{Err: err} + } + for _, p := range ps { + if strings.Contains(p.Cmd, "sleep") { + gotTTYField = p.TTY + return nil + } + } + return fmt.Errorf("sleep not running") + } + if err := testutil.Poll(cb, 30*time.Second); err != nil { + t.Fatalf("error waiting for sleep process: %v", err) + } + + if gotTTYField != test.wantTTYField { + t.Errorf("tty field got %q, want %q", gotTTYField, test.wantTTYField) + } + }) + } +} + // executeSync synchronously executes a new process. func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { pid, err := cont.Execute(args) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index a5a62378c..4ad09ceab 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -123,11 +123,11 @@ func execMany(execs []execDesc) error { func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { for _, spec := range pod { - spec.Annotations[path.Join(boot.MountPrefix, name, "source")] = mount.Source - spec.Annotations[path.Join(boot.MountPrefix, name, "type")] = mount.Type - spec.Annotations[path.Join(boot.MountPrefix, name, "share")] = "pod" + spec.Annotations[boot.MountPrefix+name+".source"] = mount.Source + spec.Annotations[boot.MountPrefix+name+".type"] = mount.Type + spec.Annotations[boot.MountPrefix+name+".share"] = "pod" if len(mount.Options) > 0 { - spec.Annotations[path.Join(boot.MountPrefix, name, "options")] = strings.Join(mount.Options, ",") + spec.Annotations[boot.MountPrefix+name+".options"] = strings.Join(mount.Options, ",") } } } @@ -156,13 +156,13 @@ func TestMultiContainerSanity(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) } expectedPL = []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -202,13 +202,13 @@ func TestMultiPIDNS(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) } expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -264,7 +264,7 @@ func TestMultiPIDNSPath(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -274,7 +274,7 @@ func TestMultiPIDNSPath(t *testing.T) { } expectedPL = []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -306,7 +306,7 @@ func TestMultiContainerWait(t *testing.T) { // Check via ps that multiple processes are running. expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -351,7 +351,7 @@ func TestMultiContainerWait(t *testing.T) { // After Wait returns, ensure that the root container is running and // the child has finished. expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err) @@ -383,7 +383,7 @@ func TestExecWait(t *testing.T) { // Check via ps that process is running. expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { t.Fatalf("failed to wait for sleep to start: %v", err) @@ -418,7 +418,7 @@ func TestExecWait(t *testing.T) { // Wait for the exec'd process to exit. expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Fatalf("failed to wait for second container to stop: %v", err) @@ -505,7 +505,7 @@ func TestMultiContainerSignal(t *testing.T) { // Check via ps that container 1 process is running. expectedPL := []*control.Process{ - {PID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}}, } if err := waitForProcessList(containers[1], expectedPL); err != nil { @@ -519,7 +519,7 @@ func TestMultiContainerSignal(t *testing.T) { // Make sure process 1 is still running. expectedPL = []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -633,9 +633,10 @@ func TestMultiContainerDestroy(t *testing.T) { if err != nil { t.Fatalf("error getting process data from sandbox: %v", err) } - expectedPL := []*control.Process{{PID: 1, Cmd: "sleep"}} - if !procListsEqual(pss, expectedPL) { - t.Errorf("container got process list: %s, want: %s", procListToString(pss), procListToString(expectedPL)) + expectedPL := []*control.Process{{PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}} + if r, err := procListsEqual(pss, expectedPL); !r { + t.Errorf("container got process list: %s, want: %s: error: %v", + procListToString(pss), procListToString(expectedPL), err) } // Check that cont.Destroy is safe to call multiple times. @@ -669,7 +670,7 @@ func TestMultiContainerProcesses(t *testing.T) { // Check root's container process list doesn't include other containers. expectedPL0 := []*control.Process{ - {PID: 1, Cmd: "sleep"}, + {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}, } if err := waitForProcessList(containers[0], expectedPL0); err != nil { t.Errorf("failed to wait for process to start: %v", err) @@ -677,8 +678,8 @@ func TestMultiContainerProcesses(t *testing.T) { // Same for the other container. expectedPL1 := []*control.Process{ - {PID: 2, Cmd: "sh"}, - {PID: 3, PPID: 2, Cmd: "sleep"}, + {PID: 2, Cmd: "sh", Threads: []kernel.ThreadID{2}}, + {PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}}, } if err := waitForProcessList(containers[1], expectedPL1); err != nil { t.Errorf("failed to wait for process to start: %v", err) @@ -692,7 +693,7 @@ func TestMultiContainerProcesses(t *testing.T) { if _, err := containers[1].Execute(args); err != nil { t.Fatalf("error exec'ing: %v", err) } - expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep"}) + expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep", Threads: []kernel.ThreadID{4}}) if err := waitForProcessList(containers[1], expectedPL1); err != nil { t.Errorf("failed to wait for process to start: %v", err) } @@ -1513,7 +1514,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Ensure container is running c := containers[2] expectedPL := []*control.Process{ - {PID: 3, Cmd: "sleep"}, + {PID: 3, Cmd: "sleep", Threads: []kernel.ThreadID{3}}, } if err := waitForProcessList(c, expectedPL); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) @@ -1541,7 +1542,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { continue // container[2] has been killed. } pl := []*control.Process{ - {PID: kernel.ThreadID(i + 1), Cmd: "sleep"}, + {PID: kernel.ThreadID(i + 1), Cmd: "sleep", Threads: []kernel.ThreadID{kernel.ThreadID(i + 1)}}, } if err := waitForProcessList(c, pl); err != nil { t.Errorf("Container %q was affected by another container: %v", c.ID, err) @@ -1561,7 +1562,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Wait until sandbox stops. waitForProcessList will loop until sandbox exits // and RPC errors out. impossiblePL := []*control.Process{ - {PID: 100, Cmd: "non-existent-process"}, + {PID: 100, Cmd: "non-existent-process", Threads: []kernel.ThreadID{100}}, } if err := waitForProcessList(c, impossiblePL); err == nil { t.Fatalf("Sandbox was not killed after gofer death") diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD index 9bf9e6e9d..bfd338bb6 100644 --- a/runsc/container/test_app/BUILD +++ b/runsc/container/test_app/BUILD @@ -15,5 +15,6 @@ go_binary( "//pkg/unet", "//runsc/testutil", "@com_github_google_subcommands//:go_default_library", + "@com_github_kr_pty//:go_default_library", ], ) diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go index 913d781c6..a1c8a741a 100644 --- a/runsc/container/test_app/test_app.go +++ b/runsc/container/test_app/test_app.go @@ -19,6 +19,7 @@ package main import ( "context" "fmt" + "io" "io/ioutil" "log" "net" @@ -31,6 +32,7 @@ import ( "flag" "github.com/google/subcommands" + "github.com/kr/pty" "gvisor.dev/gvisor/runsc/testutil" ) @@ -41,6 +43,7 @@ func main() { subcommands.Register(new(fdReceiver), "") subcommands.Register(new(fdSender), "") subcommands.Register(new(forkBomb), "") + subcommands.Register(new(ptyRunner), "") subcommands.Register(new(reaper), "") subcommands.Register(new(syscall), "") subcommands.Register(new(taskTree), "") @@ -352,3 +355,40 @@ func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...inter return subcommands.ExitSuccess } + +type ptyRunner struct{} + +// Name implements subcommands.Command. +func (*ptyRunner) Name() string { + return "pty-runner" +} + +// Synopsis implements subcommands.Command. +func (*ptyRunner) Synopsis() string { + return "runs the given command with an open pty terminal" +} + +// Usage implements subcommands.Command. +func (*ptyRunner) Usage() string { + return "pty-runner [command]" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (*ptyRunner) SetFlags(f *flag.FlagSet) {} + +// Execute implements subcommands.Command. +func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + c := exec.Command(fs.Args()[0], fs.Args()[1:]...) + f, err := pty.Start(c) + if err != nil { + fmt.Printf("pty.Start failed: %v", err) + return subcommands.ExitFailure + } + defer f.Close() + + // Copy stdout from the command to keep this process alive until the + // subprocess exits. + io.Copy(os.Stdout, f) + + return subcommands.ExitSuccess +} diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index c9add64ec..b59e1a70e 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -199,6 +199,7 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // The reason that the file is not opened initially as read-write is for better // performance with 'overlay2' storage driver. overlay2 eagerly copies the // entire file up when it's opened in write mode, and would perform badly when +// multiple files are only being opened for read (esp. startup). type localFile struct { p9.DefaultWalkGetAttr diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ee9327fc8..805233184 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -1004,16 +1004,22 @@ func (s *Sandbox) ChangeLogging(args control.LoggingArgs) error { // DestroyContainer destroys the given container. If it is the root container, // then the entire sandbox is destroyed. func (s *Sandbox) DestroyContainer(cid string) error { + if err := s.destroyContainer(cid); err != nil { + // If the sandbox isn't running, the container has already been destroyed, + // ignore the error in this case. + if s.IsRunning() { + return err + } + } + return nil +} + +func (s *Sandbox) destroyContainer(cid string) error { if s.IsRootContainer(cid) { log.Debugf("Destroying root container %q by destroying sandbox", cid) return s.destroy() } - if !s.IsRunning() { - // Sandbox isn't running anymore, container is already destroyed. - return nil - } - log.Debugf("Destroying container %q in sandbox %q", cid, s.ID) conn, err := s.sandboxConnect() if err != nil { diff --git a/scripts/benchmarks.sh b/scripts/benchmarks.sh new file mode 100755 index 000000000..6b9065b07 --- /dev/null +++ b/scripts/benchmarks.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/usr/bin/env bash + +if [ "$#" -lt "1" ]; then + echo "usage: $0 <--mock |--env=<filename>> ..." + echo "example: $0 --mock --runs=8" + exit 1 +fi + +source $(dirname $0)/common.sh + +readonly TIMESTAMP=`date "+%Y%m%d-%H%M%S"` +readonly OUTDIR="$(mktemp --tmpdir -d run-${TIMESTAMP}-XXX)" +readonly DEFAULT_RUNTIMES="--runtime=runc --runtime=runsc --runtime=runsc-kvm" +readonly ALL_RUNTIMES="--runtime=runc --runtime=runsc --runtime=runsc-kvm" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'fio.(read|write)' --metric=bandwidth --size=5g --ioengine=sync --blocksize=1m > "${OUTDIR}/fio.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} fio.rand --metric=bandwidth --size=5g --ioengine=sync --blocksize=4k --time=30 > "${OUTDIR}/tmp_fio.csv" +cat "${OUTDIR}/tmp_fio.csv" | grep "\(runc\|runsc\)" >> "${OUTDIR}/fio.csv" && rm "${OUTDIR}/tmp_fio.csv" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'fio.(read|write)' --metric=bandwidth --tmpfs=True --size=5g --ioengine=sync --blocksize=1m > "${OUTDIR}/fio-tmpfs.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} fio.rand --metric=bandwidth --tmpfs=True --size=5g --ioengine=sync --blocksize=4k --time=30 > "${OUTDIR}/tmp_fio.csv" +cat "${OUTDIR}/tmp_fio.csv" | grep "\(runc\|runsc\)" >> "${OUTDIR}/fio-tmpfs.csv" && rm "${OUTDIR}/tmp_fio.csv" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} startup --count=50 > "${OUTDIR}/startup.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} density > "${OUTDIR}/density.csv" + +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} sysbench.cpu --threads=1 --max_prime=50000 --options='--max-time=5' > "${OUTDIR}/sysbench-cpu.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} sysbench.memory --threads=1 --options='--memory-block-size=1M --memory-total-size=500G' > "${OUTDIR}/sysbench-memory.csv" +run //benchmarks:perf -- run "$@" ${ALL_RUNTIMES} syscall > "${OUTDIR}/syscall.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'network.(upload|download)' --runs=20 > "${OUTDIR}/iperf.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} ml.tensorflow > "${OUTDIR}/tensorflow.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} media.ffmpeg > "${OUTDIR}/ffmpeg.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} http.httpd --path=latin100k.txt --connections=1 --connections=5 --connections=10 --connections=25 > "${OUTDIR}/httpd100k.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} http.httpd --path=latin10240k.txt --connections=1 --connections=5 --connections=10 --connections=25 > "${OUTDIR}/httpd10240k.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} redis > "${OUTDIR}/redis.csv" +run //benchmarks:perf -- run "$@" ${DEFAULT_RUNTIMES} 'http.(ruby|node)' > "${OUTDIR}/applications.csv" + +echo "${OUTPUT}" && exit 0 diff --git a/scripts/dev.sh b/scripts/dev.sh index c67003018..6238b4d0b 100755 --- a/scripts/dev.sh +++ b/scripts/dev.sh @@ -54,9 +54,10 @@ declare OUTPUT="$(build //runsc)" if [[ ${REFRESH} -eq 0 ]]; then install_runsc "${RUNTIME}" --net-raw install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets + install_runsc "${RUNTIME}-p" --net-raw --profile echo - echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup." + echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed." echo "Use --runtime="${RUNTIME}" with your Docker command." echo " docker run --rm --runtime="${RUNTIME}" hello-world" echo diff --git a/scripts/go.sh b/scripts/go.sh index 0dbfb7747..626ed8fa4 100755 --- a/scripts/go.sh +++ b/scripts/go.sh @@ -25,6 +25,8 @@ tools/go_branch.sh # Checkout the new branch. git checkout go && git clean -f +go version + # Build everything. go build ./... diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh index 585216aae..3a15050c2 100755 --- a/scripts/simple_tests.sh +++ b/scripts/simple_tests.sh @@ -17,4 +17,4 @@ source $(dirname $0)/common.sh # Run all simple tests (locally). -test //pkg/... //runsc/... //tools/... +test //pkg/... //runsc/... //tools/... //benchmarks/... //benchmarks/runner:runner_test diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 722d14b53..829693e8e 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -376,6 +376,8 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:rlimits_test") +syscall_test(test = "//test/syscalls/linux:rseq_test") + syscall_test(test = "//test/syscalls/linux:rtsignal_test") syscall_test(test = "//test/syscalls/linux:sched_test") @@ -669,6 +671,7 @@ syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( size = "medium", + add_hostinet = True, shard_count = 10, test = "//test/syscalls/linux:udp_socket_test", ) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index dcf5b73ed..aaf77c65b 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -9,6 +9,7 @@ def syscall_test( use_tmpfs = False, add_overlay = False, add_uds_tree = False, + add_hostinet = False, tags = None): _syscall_test( test = test, @@ -65,6 +66,18 @@ def syscall_test( file_access = "shared", ) + if add_hostinet: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = use_tmpfs, + network = "host", + add_uds_tree = add_uds_tree, + tags = tags, + ) + def _syscall_test( test, shard_count, @@ -72,6 +85,7 @@ def _syscall_test( platform, use_tmpfs, tags, + network = "none", file_access = "exclusive", overlay = False, add_uds_tree = False): @@ -85,6 +99,8 @@ def _syscall_test( name += "_shared" if overlay: name += "_overlay" + if network != "none": + name += "_" + network + "net" if tags == None: tags = [] @@ -107,6 +123,7 @@ def _syscall_test( # Arguments are passed directly to syscall_test_runner binary. "--test-name=" + test_name, "--platform=" + platform, + "--network=" + network, "--use-tmpfs=" + str(use_tmpfs), "--file-access=" + file_access, "--overlay=" + str(overlay), diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 2dd115409..6ea922fb4 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -6,6 +6,16 @@ package( licenses = ["notice"], ) +exports_files( + [ + "socket.cc", + "socket_ipv4_udp_unbound_loopback.cc", + "tcp_socket.cc", + "udp_socket.cc", + ], + visibility = ["//:sandbox"], +) + cc_binary( name = "sigaltstack_check", testonly = 1, @@ -743,6 +753,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:multiprocess_util", "//test/util:posix_error", + "//test/util:save_util", "//test/util:temp_path", "//test/util:test_util", "//test/util:timer_util", @@ -1795,7 +1806,6 @@ cc_binary( name = "readv_socket_test", testonly = 1, srcs = [ - "file_base.h", "readv_common.cc", "readv_common.h", "readv_socket.cc", @@ -1843,6 +1853,22 @@ cc_binary( ) cc_binary( + name = "rseq_test", + testonly = 1, + srcs = ["rseq.cc"], + data = ["//test/syscalls/linux/rseq"], + linkstatic = 1, + deps = [ + "//test/syscalls/linux/rseq:lib", + "//test/util:logging", + "//test/util:multiprocess_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "rtsignal_test", testonly = 1, srcs = ["rtsignal.cc"], @@ -2142,6 +2168,7 @@ cc_library( ":socket_test_util", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], alwayslink = 1, @@ -3245,8 +3272,6 @@ cc_binary( testonly = 1, srcs = ["tcp_socket.cc"], linkstatic = 1, - # FIXME(b/135470853) - tags = ["flaky"], deps = [ ":socket_test_util", "//test/util:file_descriptor", diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index b27d4e10a..a33daff17 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -129,7 +129,7 @@ TEST_F(AIOTest, BasicWrite) { // aio implementation uses aio_ring. gVisor doesn't and returns all zeroes. // Linux implements aio_ring, so skip the zeroes check. // - // TODO(b/65486370): Remove when gVisor implements aio_ring. + // TODO(gvisor.dev/issue/204): Remove when gVisor implements aio_ring. auto ring = reinterpret_cast<struct aio_ring*>(ctx_); auto magic = IsRunningOnGvisor() ? 0 : AIO_RING_MAGIC; EXPECT_EQ(ring->magic, magic); diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index 581f03533..b5e0a512b 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -47,23 +47,14 @@ namespace testing { namespace { -constexpr char kBasicWorkload[] = "exec_basic_workload"; -constexpr char kExitScript[] = "exit_script"; -constexpr char kStateWorkload[] = "exec_state_workload"; -constexpr char kProcExeWorkload[] = "exec_proc_exe_workload"; -constexpr char kAssertClosedWorkload[] = "exec_assert_closed_workload"; -constexpr char kPriorityWorkload[] = "priority_execve"; - -std::string WorkloadPath(absl::string_view binary) { - std::string full_path; - char* test_src = getenv("TEST_SRCDIR"); - if (test_src) { - full_path = JoinPath(test_src, "__main__/test/syscalls/linux", binary); - } - - TEST_CHECK(full_path.empty() == false); - return full_path; -} +constexpr char kBasicWorkload[] = "test/syscalls/linux/exec_basic_workload"; +constexpr char kExitScript[] = "test/syscalls/linux/exit_script"; +constexpr char kStateWorkload[] = "test/syscalls/linux/exec_state_workload"; +constexpr char kProcExeWorkload[] = + "test/syscalls/linux/exec_proc_exe_workload"; +constexpr char kAssertClosedWorkload[] = + "test/syscalls/linux/exec_assert_closed_workload"; +constexpr char kPriorityWorkload[] = "test/syscalls/linux/priority_execve"; constexpr char kExit42[] = "--exec_exit_42"; constexpr char kExecWithThread[] = "--exec_exec_with_thread"; @@ -171,44 +162,44 @@ TEST(ExecTest, EmptyPath) { } TEST(ExecTest, Basic) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n")); } TEST(ExecTest, OneArg) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"}, - {}, ArgEnvExitStatus(1, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "1"}, {}, + ArgEnvExitStatus(1, 0), + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); } TEST(ExecTest, FiveArg) { - CheckExec(WorkloadPath(kBasicWorkload), - {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, + CheckExec(RunfilePath(kBasicWorkload), + {RunfilePath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, ArgEnvExitStatus(5, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } TEST(ExecTest, OneEnv) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {"1"}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1"}, ArgEnvExitStatus(0, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); } TEST(ExecTest, FiveEnv) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } TEST(ExecTest, OneArgOneEnv) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "arg"}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "arg"}, {"env"}, ArgEnvExitStatus(1, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\narg\nenv\n")); } TEST(ExecTest, InterpreterScript) { - CheckExec(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {}, + CheckExec(RunfilePath(kExitScript), {RunfilePath(kExitScript), "25"}, {}, ArgEnvExitStatus(25, 0), ""); } @@ -216,7 +207,7 @@ TEST(ExecTest, InterpreterScript) { TEST(ExecTest, InterpreterScriptArgSplit) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"), @@ -230,7 +221,7 @@ TEST(ExecTest, InterpreterScriptArgSplit) { TEST(ExecTest, InterpreterScriptArgvZero) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -244,7 +235,7 @@ TEST(ExecTest, InterpreterScriptArgvZero) { TEST(ExecTest, InterpreterScriptArgvZeroRelative) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -261,7 +252,7 @@ TEST(ExecTest, InterpreterScriptArgvZeroRelative) { TEST(ExecTest, InterpreterScriptArgvZeroAdded) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -274,7 +265,7 @@ TEST(ExecTest, InterpreterScriptArgvZeroAdded) { TEST(ExecTest, InterpreterScriptArgNUL) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), @@ -289,7 +280,7 @@ TEST(ExecTest, InterpreterScriptArgNUL) { TEST(ExecTest, InterpreterScriptTrailingWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755)); @@ -302,7 +293,7 @@ TEST(ExecTest, InterpreterScriptTrailingWhitespace) { TEST(ExecTest, InterpreterScriptArgWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755)); @@ -325,7 +316,7 @@ TEST(ExecTest, InterpreterScriptNoPath) { TEST(ExecTest, ExecFn) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " PrintExecFn"), @@ -342,7 +333,7 @@ TEST(ExecTest, ExecFn) { } TEST(ExecTest, ExecName) { - std::string path = WorkloadPath(kStateWorkload); + std::string path = RunfilePath(kStateWorkload); CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), absl::StrCat(Basename(path).substr(0, 15), "\n")); @@ -351,7 +342,7 @@ TEST(ExecTest, ExecName) { TEST(ExecTest, ExecNameScript) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), @@ -405,13 +396,13 @@ TEST(ExecStateTest, HandlerReset) { ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigHandler", absl::StrCat(SIGUSR1), absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))), }; - CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Ignored signal dispositions are not reset. @@ -421,13 +412,13 @@ TEST(ExecStateTest, IgnorePreserved) { ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigHandler", absl::StrCat(SIGUSR1), absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))), }; - CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Signal masks are not reset on exec @@ -438,12 +429,12 @@ TEST(ExecStateTest, SignalMask) { ASSERT_THAT(sigprocmask(SIG_BLOCK, &s, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigBlocked", absl::StrCat(SIGUSR1), }; - CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // itimers persist across execve. @@ -471,7 +462,7 @@ TEST(ExecStateTest, ItimerPreserved) { } }; - std::string filename = WorkloadPath(kStateWorkload); + std::string filename = RunfilePath(kStateWorkload); ExecveArray argv = { filename, "CheckItimerEnabled", @@ -495,8 +486,8 @@ TEST(ExecStateTest, ItimerPreserved) { TEST(ProcSelfExe, ChangesAcrossExecve) { // See exec_proc_exe_workload for more details. We simply // assert that the /proc/self/exe link changes across execve. - CheckExec(WorkloadPath(kProcExeWorkload), - {WorkloadPath(kProcExeWorkload), + CheckExec(RunfilePath(kProcExeWorkload), + {RunfilePath(kProcExeWorkload), ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, {}, W_EXITCODE(0, 0), ""); } @@ -507,8 +498,8 @@ TEST(ExecTest, CloexecNormalFile) { const FileDescriptor fd_closed_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - CheckExec(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), + CheckExec(RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_closed_on_exec.get())}, {}, W_EXITCODE(0, 0), ""); @@ -517,10 +508,10 @@ TEST(ExecTest, CloexecNormalFile) { const FileDescriptor fd_open_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY)); - CheckExec(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), - absl::StrCat(fd_open_on_exec.get())}, - {}, W_EXITCODE(2, 0), ""); + CheckExec( + RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_open_on_exec.get())}, + {}, W_EXITCODE(2, 0), ""); } TEST(ExecTest, CloexecEventfd) { @@ -528,15 +519,15 @@ TEST(ExecTest, CloexecEventfd) { ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds()); FileDescriptor fd(efd); - CheckExec(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, + CheckExec(RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, W_EXITCODE(0, 0), ""); } constexpr int kLinuxMaxSymlinks = 40; TEST(ExecTest, SymlinkLimitExceeded) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); // Hold onto TempPath objects so they are not destructed prematurely. std::vector<TempPath> symlinks; @@ -575,13 +566,13 @@ TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { } TEST(ExecveatTest, BasicWithFDCWD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); } TEST(ExecveatTest, Basic) { - std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string absolute_path = RunfilePath(kBasicWorkload); std::string parent_dir = std::string(Dirname(absolute_path)); std::string base = std::string(Basename(absolute_path)); const FileDescriptor dirfd = @@ -592,7 +583,7 @@ TEST(ExecveatTest, Basic) { } TEST(ExecveatTest, FDNotADirectory) { - std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string absolute_path = RunfilePath(kBasicWorkload); std::string base = std::string(Basename(absolute_path)); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0)); @@ -604,13 +595,13 @@ TEST(ExecveatTest, FDNotADirectory) { } TEST(ExecveatTest, AbsolutePathWithFDCWD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, absl::StrCat(path, "\n")); } TEST(ExecveatTest, AbsolutePath) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); // File descriptor should be ignored when an absolute path is given. const int32_t badFD = -1; CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, @@ -618,7 +609,7 @@ TEST(ExecveatTest, AbsolutePath) { } TEST(ExecveatTest, EmptyPathBasic) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0), @@ -626,7 +617,7 @@ TEST(ExecveatTest, EmptyPathBasic) { } TEST(ExecveatTest, EmptyPathWithDirFD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); std::string parent_dir = std::string(Dirname(path)); const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); @@ -639,7 +630,7 @@ TEST(ExecveatTest, EmptyPathWithDirFD) { } TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); int execve_errno; @@ -649,7 +640,7 @@ TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { } TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH, @@ -657,7 +648,7 @@ TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { } TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { - std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string absolute_path = RunfilePath(kBasicWorkload); std::string parent_dir = std::string(Dirname(absolute_path)); std::string base = std::string(Basename(absolute_path)); const FileDescriptor dirfd = @@ -670,7 +661,7 @@ TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { std::string parent_dir = "/tmp"; TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); std::string base = std::string(Basename(link.path())); @@ -685,7 +676,7 @@ TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { std::string parent_dir = "/tmp"; TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); std::string path = link.path(); int execve_errno; @@ -697,7 +688,7 @@ TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) { TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); std::string path = link.path(); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0)); @@ -723,7 +714,7 @@ TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) { } TEST(ExecveatTest, BasicWithCloexecFD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH, @@ -731,7 +722,7 @@ TEST(ExecveatTest, BasicWithCloexecFD) { } TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { - std::string path = WorkloadPath(kExitScript); + std::string path = RunfilePath(kExitScript); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); int execve_errno; @@ -742,7 +733,7 @@ TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { } TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) { - std::string absolute_path = WorkloadPath(kExitScript); + std::string absolute_path = RunfilePath(kExitScript); std::string parent_dir = std::string(Dirname(absolute_path)); std::string base = std::string(Basename(absolute_path)); const FileDescriptor dirfd = @@ -775,7 +766,7 @@ TEST(GetpriorityTest, ExecveMaintainsPriority) { // Program run (priority_execve) will exit(X) where // X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio. - CheckExec(WorkloadPath(kPriorityWorkload), {WorkloadPath(kPriorityWorkload)}, + CheckExec(RunfilePath(kPriorityWorkload), {RunfilePath(kPriorityWorkload)}, {}, W_EXITCODE(expected_exit_code, 0), ""); } diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 8a45be12a..4f3aa81d6 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -14,6 +14,7 @@ #include <fcntl.h> #include <signal.h> +#include <sys/types.h> #include <syscall.h> #include <unistd.h> @@ -32,6 +33,7 @@ #include "test/util/eventfd_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/timer_util.h" @@ -910,8 +912,166 @@ TEST(FcntlTest, GetOwn) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), 0); + MaybeSave(); +} + +TEST(FcntlTest, GetOwnEx) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); +} + +TEST(FcntlTest, SetOwnExInvalidType) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = __pid_type(-1); + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(FcntlTest, SetOwnExInvalidTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_TID; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExInvalidPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExInvalidPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PGRP; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_TID; + EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceeds()); + + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceeds()); + + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PGRP; + EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceeds()); + + // NOTE(igudger): I don't understand why, but this is flaky on Linux. + // GetOwnExPgrp (below) does not have this issue. + SKIP_IF(!IsRunningOnGvisor()); + + EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), -owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, GetOwnExTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_TID; + EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceeds()); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +TEST(FcntlTest, GetOwnExPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PID; + EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceeds()); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +TEST(FcntlTest, GetOwnExPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PGRP; + EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceeds()); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); } } // namespace diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 4e048320e..6f80bc97c 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -111,95 +111,6 @@ class FileTest : public ::testing::Test { int test_pipe_[2]; }; -class SocketTest : public ::testing::Test { - public: - void SetUp() override { - test_unix_stream_socket_[0] = -1; - test_unix_stream_socket_[1] = -1; - test_unix_dgram_socket_[0] = -1; - test_unix_dgram_socket_[1] = -1; - test_unix_seqpacket_socket_[0] = -1; - test_unix_seqpacket_socket_[1] = -1; - test_tcp_socket_[0] = -1; - test_tcp_socket_[1] = -1; - - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT( - socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - } - - void TearDown() override { - close(test_unix_stream_socket_[0]); - close(test_unix_stream_socket_[1]); - - close(test_unix_dgram_socket_[0]); - close(test_unix_dgram_socket_[1]); - - close(test_unix_seqpacket_socket_[0]); - close(test_unix_seqpacket_socket_[1]); - - close(test_tcp_socket_[0]); - close(test_tcp_socket_[1]); - } - - int test_unix_stream_socket_[2]; - int test_unix_dgram_socket_[2]; - int test_unix_seqpacket_socket_[2]; - int test_tcp_socket_[2]; -}; - -// MatchesStringLength checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string length strlen. -MATCHER_P(MatchesStringLength, strlen, "") { - struct iovec* iovs = arg.first; - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - offset += iovs[i].iov_len; - } - if (offset != static_cast<int>(strlen)) { - *result_listener << offset; - return false; - } - return true; -} - -// MatchesStringValue checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string value str. -MATCHER_P(MatchesStringValue, str, "") { - struct iovec* iovs = arg.first; - int len = strlen(str); - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - struct iovec iov = iovs[i]; - if (len < offset) { - *result_listener << "strlen " << len << " < offset " << offset; - return false; - } - if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { - absl::string_view iovec_string(static_cast<char*>(iov.iov_base), - iov.iov_len); - *result_listener << iovec_string << " @offset " << offset; - return false; - } - offset += iov.iov_len; - } - return true; -} - } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index c4f8bff08..b0a07a064 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -215,7 +215,8 @@ TEST_F(IoctlTest, FIOASYNCSelfTarget2) { auto mask_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - pid_t pid = getpid(); + pid_t pid = -1; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); int set = 1; diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 072230d85..9cb4566db 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -26,25 +26,6 @@ namespace gvisor { namespace testing { -// Possible values of the "st" field in a /proc/net/{tcp,udp} entry. Source: -// Linux kernel, include/net/tcp_states.h. -enum { - TCP_ESTABLISHED = 1, - TCP_SYN_SENT, - TCP_SYN_RECV, - TCP_FIN_WAIT1, - TCP_FIN_WAIT2, - TCP_TIME_WAIT, - TCP_CLOSE, - TCP_CLOSE_WAIT, - TCP_LAST_ACK, - TCP_LISTEN, - TCP_CLOSING, - TCP_NEW_SYN_RECV, - - TCP_MAX_STATES -}; - // Extracts the IP address from an inet sockaddr in network byte order. uint32_t IPFromInetSockaddr(const struct sockaddr* addr); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 512de5ee0..8cf08991b 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -37,6 +37,7 @@ #include <map> #include <memory> #include <ostream> +#include <regex> #include <string> #include <unordered_set> #include <utility> @@ -51,6 +52,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/capability_util.h" @@ -1988,6 +1990,44 @@ TEST(Proc, GetdentsEnoent) { SyscallFailsWithErrno(ENOENT)); } +void CheckSyscwFromIOFile(const std::string& path, const std::string& regex) { + std::string output; + ASSERT_NO_ERRNO(GetContents(path, &output)); + ASSERT_THAT(output, ContainsRegex(absl::StrCat("syscw:\\s+", regex, "\n"))); +} + +// Checks that there is variable accounting of IO between threads/tasks. +TEST(Proc, PidTidIOAccounting) { + absl::Notification notification; + + // Run a thread with a bunch of writes. Check that io account records exactly + // the number of write calls. File open/close is there to prevent buffering. + ScopedThread writer([¬ification] { + const int num_writes = 100; + for (int i = 0; i < num_writes; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + ASSERT_NO_ERRNO(SetContents(path.path(), "a")); + } + notification.Notify(); + const std::string& writer_dir = + absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); + + CheckSyscwFromIOFile(writer_dir, std::to_string(num_writes)); + }); + + // Run a thread and do no writes. Check that no writes are recorded. + ScopedThread noop([¬ification] { + notification.WaitForNotification(); + const std::string& noop_dir = + absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); + + CheckSyscwFromIOFile(noop_dir, "0"); + }); + + writer.Join(); + noop.Join(); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index 2659f6a98..5b6e3e3cd 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <netinet/tcp.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc index f06f1a24b..786b4b4af 100644 --- a/test/syscalls/linux/proc_net_udp.cc +++ b/test/syscalls/linux/proc_net_udp.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <netinet/tcp.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc index 9658f7d42..491d5f40f 100644 --- a/test/syscalls/linux/readv_common.cc +++ b/test/syscalls/linux/readv_common.cc @@ -19,12 +19,53 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +// MatchesStringLength checks that a tuple argument of (struct iovec *, int) +// corresponding to an iovec array and its length, contains data that matches +// the string length strlen. +MATCHER_P(MatchesStringLength, strlen, "") { + struct iovec* iovs = arg.first; + int niov = arg.second; + int offset = 0; + for (int i = 0; i < niov; i++) { + offset += iovs[i].iov_len; + } + if (offset != static_cast<int>(strlen)) { + *result_listener << offset; + return false; + } + return true; +} + +// MatchesStringValue checks that a tuple argument of (struct iovec *, int) +// corresponding to an iovec array and its length, contains data that matches +// the string value str. +MATCHER_P(MatchesStringValue, str, "") { + struct iovec* iovs = arg.first; + int len = strlen(str); + int niov = arg.second; + int offset = 0; + for (int i = 0; i < niov; i++) { + struct iovec iov = iovs[i]; + if (len < offset) { + *result_listener << "strlen " << len << " < offset " << offset; + return false; + } + if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { + absl::string_view iovec_string(static_cast<char*>(iov.iov_base), + iov.iov_len); + *result_listener << iovec_string << " @offset " << offset; + return false; + } + offset += iov.iov_len; + } + return true; +} + extern const char kReadvTestData[] = "127.0.0.1 localhost" "" diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc index 9b6972201..dd6fb7008 100644 --- a/test/syscalls/linux/readv_socket.cc +++ b/test/syscalls/linux/readv_socket.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/test_util.h" @@ -28,9 +27,30 @@ namespace testing { namespace { -class ReadvSocketTest : public SocketTest { +class ReadvSocketTest : public ::testing::Test { + public: void SetUp() override { - SocketTest::SetUp(); + test_unix_stream_socket_[0] = -1; + test_unix_stream_socket_[1] = -1; + test_unix_dgram_socket_[0] = -1; + test_unix_dgram_socket_[1] = -1; + test_unix_seqpacket_socket_[0] = -1; + test_unix_seqpacket_socket_[1] = -1; + + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT( + socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT( write(test_unix_stream_socket_[1], kReadvTestData, kReadvTestDataSize), SyscallSucceedsWithValue(kReadvTestDataSize)); @@ -40,11 +60,22 @@ class ReadvSocketTest : public SocketTest { ASSERT_THAT(write(test_unix_seqpacket_socket_[1], kReadvTestData, kReadvTestDataSize), SyscallSucceedsWithValue(kReadvTestDataSize)); - // FIXME(b/69821513): Enable when possible. - // ASSERT_THAT(write(test_tcp_socket_[1], kReadvTestData, - // kReadvTestDataSize), - // SyscallSucceedsWithValue(kReadvTestDataSize)); } + + void TearDown() override { + close(test_unix_stream_socket_[0]); + close(test_unix_stream_socket_[1]); + + close(test_unix_dgram_socket_[0]); + close(test_unix_dgram_socket_[1]); + + close(test_unix_seqpacket_socket_[0]); + close(test_unix_seqpacket_socket_[1]); + } + + int test_unix_stream_socket_[2]; + int test_unix_dgram_socket_[2]; + int test_unix_seqpacket_socket_[2]; }; TEST_F(ReadvSocketTest, ReadOneBufferPerByte_StreamSocket) { diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc new file mode 100644 index 000000000..106c045e3 --- /dev/null +++ b/test/syscalls/linux/rseq.cc @@ -0,0 +1,198 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <signal.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/rseq/test.h" +#include "test/syscalls/linux/rseq/uapi.h" +#include "test/util/logging.h" +#include "test/util/multiprocess_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Syscall test for rseq (restartable sequences). +// +// We must be very careful about how these tests are written. Each thread may +// only have one struct rseq registration, which may be done automatically at +// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does +// not do so). +// +// Testing of rseq is thus done primarily in a child process with no +// registration. This means exec'ing a nostdlib binary, as rseq registration can +// only be cleared by execve (or knowing the old rseq address), and glibc (based +// on the current unmerged patches) register rseq before calling main()). + +int RSeq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { + return syscall(kRseqSyscall, rseq, rseq_len, flags, sig); +} + +// Returns true if this kernel supports the rseq syscall. +PosixErrorOr<bool> RSeqSupported() { + // We have to be careful here, there are three possible cases: + // + // 1. rseq is not supported -> ENOSYS + // 2. rseq is supported and not registered -> success, but we should + // unregister. + // 3. rseq is supported and registered -> EINVAL (most likely). + + // The only validation done on new registrations is that rseq is aligned and + // writable. + rseq rseq = {}; + int ret = RSeq(&rseq, sizeof(rseq), 0, 0); + if (ret == 0) { + // Successfully registered, rseq is supported. Unregister. + ret = RSeq(&rseq, sizeof(rseq), kRseqFlagUnregister, 0); + if (ret != 0) { + return PosixError(errno); + } + return true; + } + + switch (errno) { + case ENOSYS: + // Not supported. + return false; + case EINVAL: + // Supported, but already registered. EINVAL returned because we provided + // a different address. + return true; + default: + // Unknown error. + return PosixError(errno); + } +} + +constexpr char kRseqBinary[] = "test/syscalls/linux/rseq/rseq"; + +void RunChildTest(std::string test_case, int want_status) { + std::string path = RunfilePath(kRseqBinary); + + pid_t child_pid = -1; + int execve_errno = 0; + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path, test_case}, {}, &child_pid, &execve_errno)); + + ASSERT_GT(child_pid, 0); + ASSERT_EQ(execve_errno, 0); + + int status = 0; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); + ASSERT_EQ(status, want_status); +} + +// Test that rseq must be aligned. +TEST(RseqTest, Unaligned) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnaligned, 0); +} + +// Sanity test that registration works. +TEST(RseqTest, Register) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestRegister, 0); +} + +// Registration can't be done twice. +TEST(RseqTest, DoubleRegister) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestDoubleRegister, 0); +} + +// Registration can be done again after unregister. +TEST(RseqTest, RegisterUnregister) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestRegisterUnregister, 0); +} + +// The pointer to rseq must match on register/unregister. +TEST(RseqTest, UnregisterDifferentPtr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnregisterDifferentPtr, 0); +} + +// The signature must match on register/unregister. +TEST(RseqTest, UnregisterDifferentSignature) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnregisterDifferentSignature, 0); +} + +// The CPU ID is initialized. +TEST(RseqTest, CPU) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestCPU, 0); +} + +// Critical section is eventually aborted. +TEST(RseqTest, Abort) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbort, 0); +} + +// Abort may be before the critical section. +TEST(RseqTest, AbortBefore) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortBefore, 0); +} + +// Signature must match. +TEST(RseqTest, AbortSignature) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortSignature, SIGSEGV); +} + +// Abort must not be in the critical section. +TEST(RseqTest, AbortPreCommit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortPreCommit, SIGSEGV); +} + +// rseq.rseq_cs is cleared on abort. +TEST(RseqTest, AbortClearsCS) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortClearsCS, 0); +} + +// rseq.rseq_cs is cleared on abort outside of critical section. +TEST(RseqTest, InvalidAbortClearsCS) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestInvalidAbortClearsCS, 0); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD new file mode 100644 index 000000000..5cfe4e56f --- /dev/null +++ b/test/syscalls/linux/rseq/BUILD @@ -0,0 +1,59 @@ +# This package contains a standalone rseq test binary. This binary must not +# depend on libc, which might use rseq itself. + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") +load("@rules_cc//cc:defs.bzl", "cc_library") + +package(licenses = ["notice"]) + +genrule( + name = "rseq_binary", + srcs = [ + "critical.h", + "critical.S", + "rseq.cc", + "syscalls.h", + "start.S", + "test.h", + "types.h", + "uapi.h", + ], + outs = ["rseq"], + cmd = " ".join([ + "$(CC)", + "$(CC_FLAGS) ", + "-I.", + "-Wall", + "-Werror", + "-O2", + "-std=c++17", + "-static", + "-nostdlib", + "-ffreestanding", + "-o", + "$(location rseq)", + "$(location critical.S)", + "$(location rseq.cc)", + "$(location start.S)", + ]), + toolchains = [ + ":no_pie_cc_flags", + "@bazel_tools//tools/cpp:current_cc_toolchain", + ], + visibility = ["//:sandbox"], +) + +cc_flags_supplier( + name = "no_pie_cc_flags", + features = ["-pie"], +) + +cc_library( + name = "lib", + testonly = 1, + hdrs = [ + "test.h", + "uapi.h", + ], + visibility = ["//:sandbox"], +) diff --git a/test/syscalls/linux/rseq/critical.S b/test/syscalls/linux/rseq/critical.S new file mode 100644 index 000000000..8c0687e6d --- /dev/null +++ b/test/syscalls/linux/rseq/critical.S @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Restartable sequences critical sections. + +// Loops continuously until aborted. +// +// void rseq_loop(struct rseq* r, struct rseq_cs* cs) + + .text + .globl rseq_loop + .type rseq_loop, @function + +rseq_loop: + jmp begin + + // Abort block before the critical section. + // Abort signature is 4 nops for simplicity. + .byte 0x90, 0x90, 0x90, 0x90 + .globl rseq_loop_early_abort +rseq_loop_early_abort: + ret + +begin: + // r->rseq_cs = cs + movq %rsi, 8(%rdi) + + // N.B. rseq_cs will be cleared by any preempt, even outside the critical + // section. Thus it must be set in or immediately before the critical section + // to ensure it is not cleared before the section begins. + .globl rseq_loop_start +rseq_loop_start: + jmp rseq_loop_start + + // "Pre-commit": extra instructions inside the critical section. These are + // used as the abort point in TestAbortPreCommit, which is not valid. + .globl rseq_loop_pre_commit +rseq_loop_pre_commit: + // Extra abort signature + nop for TestAbortPostCommit. + .byte 0x90, 0x90, 0x90, 0x90 + nop + + // "Post-commit": never reached in this case. + .globl rseq_loop_post_commit +rseq_loop_post_commit: + + // Abort signature is 4 nops for simplicity. + .byte 0x90, 0x90, 0x90, 0x90 + + .globl rseq_loop_abort +rseq_loop_abort: + ret + + .size rseq_loop,.-rseq_loop + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/critical.h b/test/syscalls/linux/rseq/critical.h new file mode 100644 index 000000000..ac987a25e --- /dev/null +++ b/test/syscalls/linux/rseq/critical.h @@ -0,0 +1,39 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ + +#include "test/syscalls/linux/rseq/types.h" +#include "test/syscalls/linux/rseq/uapi.h" + +constexpr uint32_t kRseqSignature = 0x90909090; + +extern "C" { + +extern void rseq_loop(struct rseq* r, struct rseq_cs* cs); +extern void* rseq_loop_early_abort; +extern void* rseq_loop_start; +extern void* rseq_loop_pre_commit; +extern void* rseq_loop_post_commit; +extern void* rseq_loop_abort; + +extern int rseq_getpid(struct rseq* r, struct rseq_cs* cs); +extern void* rseq_getpid_start; +extern void* rseq_getpid_post_commit; +extern void* rseq_getpid_abort; + +} // extern "C" + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc new file mode 100644 index 000000000..f036db26d --- /dev/null +++ b/test/syscalls/linux/rseq/rseq.cc @@ -0,0 +1,366 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/rseq/critical.h" +#include "test/syscalls/linux/rseq/syscalls.h" +#include "test/syscalls/linux/rseq/test.h" +#include "test/syscalls/linux/rseq/types.h" +#include "test/syscalls/linux/rseq/uapi.h" + +namespace gvisor { +namespace testing { + +extern "C" int main(int argc, char** argv, char** envp); + +// Standalone initialization before calling main(). +extern "C" void __init(uintptr_t* sp) { + int argc = sp[0]; + char** argv = reinterpret_cast<char**>(&sp[1]); + char** envp = &argv[argc + 1]; + + // Call main() and exit. + sys_exit_group(main(argc, argv, envp)); + + // sys_exit_group does not return +} + +int strcmp(const char* s1, const char* s2) { + const unsigned char* p1 = reinterpret_cast<const unsigned char*>(s1); + const unsigned char* p2 = reinterpret_cast<const unsigned char*>(s2); + + while (*p1 == *p2) { + if (!*p1) { + return 0; + } + ++p1; + ++p2; + } + return static_cast<int>(*p1) - static_cast<int>(*p2); +} + +int sys_rseq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { + return raw_syscall(kRseqSyscall, rseq, rseq_len, flags, sig); +} + +// Test that rseq must be aligned. +int TestUnaligned() { + constexpr uintptr_t kRequiredAlignment = alignof(rseq); + + char buf[2 * kRequiredAlignment] = {}; + uintptr_t ptr = reinterpret_cast<uintptr_t>(&buf[0]); + if ((ptr & (kRequiredAlignment - 1)) == 0) { + // buf is already aligned. Misalign it. + ptr++; + } + + int ret = sys_rseq(reinterpret_cast<rseq*>(ptr), sizeof(rseq), 0, 0); + if (sys_errno(ret) != EINVAL) { + return 1; + } + return 0; +} + +// Sanity test that registration works. +int TestRegister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + return 0; +}; + +// Registration can't be done twice. +int TestDoubleRegister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { + return 1; + } + + return 0; +}; + +// Registration can be done again after unregister. +int TestRegisterUnregister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); + sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + return 0; +}; + +// The pointer to rseq must match on register/unregister. +int TestUnregisterDifferentPtr() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + struct rseq r2 = {}; + if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); + sys_errno(ret) != EINVAL) { + return 1; + } + + return 0; +}; + +// The signature must match on register/unregister. +int TestUnregisterDifferentSignature() { + constexpr int kSignature = 0; + + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); + sys_errno(ret) != EPERM) { + return 1; + } + + return 0; +}; + +// The CPU ID is initialized. +int TestCPU() { + struct rseq r = {}; + r.cpu_id = kRseqCPUIDUninitialized; + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (__atomic_load_n(&r.cpu_id, __ATOMIC_RELAXED) < 0) { + return 1; + } + if (__atomic_load_n(&r.cpu_id_start, __ATOMIC_RELAXED) < 0) { + return 1; + } + + return 0; +}; + +// Critical section is eventually aborted. +int TestAbort() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + return 0; +}; + +// Abort may be before the critical section. +int TestAbortBefore() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_early_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + return 0; +}; + +// Signature must match. +int TestAbortSignature() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. This should SIGSEGV on abort. + rseq_loop(&r, &cs); + + return 1; +}; + +// Abort must not be in the critical section. +int TestAbortPreCommit() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_pre_commit); + + // Loops until abort. This should SIGSEGV on abort. + rseq_loop(&r, &cs); + + return 1; +}; + +// rseq.rseq_cs is cleared on abort. +int TestAbortClearsCS() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + if (__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { + return 1; + } + + return 0; +}; + +// rseq.rseq_cs is cleared on abort outside of critical section. +int TestInvalidAbortClearsCS() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + __atomic_store_n(&r.rseq_cs, &cs, __ATOMIC_RELAXED); + + // When the next abort condition occurs, the kernel will clear cs once it + // determines we aren't in the critical section. + while (1) { + if (!__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { + break; + } + } + + return 0; +}; + +// Exit codes: +// 0 - Pass +// 1 - Fail +// 2 - Missing argument +// 3 - Unknown test case +extern "C" int main(int argc, char** argv, char** envp) { + if (argc != 2) { + // Usage: rseq <test case> + return 2; + } + + if (strcmp(argv[1], kRseqTestUnaligned) == 0) { + return TestUnaligned(); + } + if (strcmp(argv[1], kRseqTestRegister) == 0) { + return TestRegister(); + } + if (strcmp(argv[1], kRseqTestDoubleRegister) == 0) { + return TestDoubleRegister(); + } + if (strcmp(argv[1], kRseqTestRegisterUnregister) == 0) { + return TestRegisterUnregister(); + } + if (strcmp(argv[1], kRseqTestUnregisterDifferentPtr) == 0) { + return TestUnregisterDifferentPtr(); + } + if (strcmp(argv[1], kRseqTestUnregisterDifferentSignature) == 0) { + return TestUnregisterDifferentSignature(); + } + if (strcmp(argv[1], kRseqTestCPU) == 0) { + return TestCPU(); + } + if (strcmp(argv[1], kRseqTestAbort) == 0) { + return TestAbort(); + } + if (strcmp(argv[1], kRseqTestAbortBefore) == 0) { + return TestAbortBefore(); + } + if (strcmp(argv[1], kRseqTestAbortSignature) == 0) { + return TestAbortSignature(); + } + if (strcmp(argv[1], kRseqTestAbortPreCommit) == 0) { + return TestAbortPreCommit(); + } + if (strcmp(argv[1], kRseqTestAbortClearsCS) == 0) { + return TestAbortClearsCS(); + } + if (strcmp(argv[1], kRseqTestInvalidAbortClearsCS) == 0) { + return TestInvalidAbortClearsCS(); + } + + return 3; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rseq/start.S b/test/syscalls/linux/rseq/start.S new file mode 100644 index 000000000..b9611b276 --- /dev/null +++ b/test/syscalls/linux/rseq/start.S @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + + .text + .align 4 + .type _start,@function + .globl _start + +_start: + movq %rsp,%rdi + call __init + hlt + + .size _start,.-_start + .section .note.GNU-stack,"",@progbits + + .text + .globl raw_syscall + .type raw_syscall, @function + +raw_syscall: + mov %rdi,%rax // syscall # + mov %rsi,%rdi // arg0 + mov %rdx,%rsi // arg1 + mov %rcx,%rdx // arg2 + mov %r8,%r10 // arg3 (goes in r10 instead of rcx for system calls) + mov %r9,%r8 // arg4 + mov 0x8(%rsp),%r9 // arg5 + syscall + ret + + .size raw_syscall,.-raw_syscall + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h new file mode 100644 index 000000000..e5299c188 --- /dev/null +++ b/test/syscalls/linux/rseq/syscalls.h @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ + +#include "test/syscalls/linux/rseq/types.h" + +#ifdef __x86_64__ +// Syscall numbers. +constexpr int kGetpid = 39; +constexpr int kExitGroup = 231; +#else +#error "Unknown architecture" +#endif + +namespace gvisor { +namespace testing { + +// Standalone system call interfaces. +// Note that these are all "raw" system call interfaces which encode +// errors by setting the return value to a small negative number. +// Use sys_errno() to check system call return values for errors. + +// Maximum Linux error number. +constexpr int kMaxErrno = 4095; + +// Errno values. +#define EPERM 1 +#define EFAULT 14 +#define EBUSY 16 +#define EINVAL 22 + +// Get the error number from a raw system call return value. +// Returns a positive error number or 0 if there was no error. +static inline int sys_errno(uintptr_t rval) { + if (rval >= static_cast<uintptr_t>(-kMaxErrno)) { + return -static_cast<int>(rval); + } + return 0; +} + +extern "C" uintptr_t raw_syscall(int number, ...); + +static inline void sys_exit_group(int status) { + raw_syscall(kExitGroup, status); +} +static inline int sys_getpid() { + return static_cast<int>(raw_syscall(kGetpid)); +} + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h new file mode 100644 index 000000000..3b7bb74b1 --- /dev/null +++ b/test/syscalls/linux/rseq/test.h @@ -0,0 +1,43 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ + +namespace gvisor { +namespace testing { + +// Test cases supported by rseq binary. + +inline constexpr char kRseqTestUnaligned[] = "unaligned"; +inline constexpr char kRseqTestRegister[] = "register"; +inline constexpr char kRseqTestDoubleRegister[] = "double-register"; +inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; +inline constexpr char kRseqTestUnregisterDifferentPtr[] = + "unregister-different-ptr"; +inline constexpr char kRseqTestUnregisterDifferentSignature[] = + "unregister-different-signature"; +inline constexpr char kRseqTestCPU[] = "cpu"; +inline constexpr char kRseqTestAbort[] = "abort"; +inline constexpr char kRseqTestAbortBefore[] = "abort-before"; +inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; +inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; +inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; +inline constexpr char kRseqTestInvalidAbortClearsCS[] = + "invalid-abort-clears-cs"; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ diff --git a/test/syscalls/linux/rseq/types.h b/test/syscalls/linux/rseq/types.h new file mode 100644 index 000000000..b6afe9817 --- /dev/null +++ b/test/syscalls/linux/rseq/types.h @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ + +using size_t = __SIZE_TYPE__; +using uintptr_t = __UINTPTR_TYPE__; + +using uint8_t = __UINT8_TYPE__; +using uint16_t = __UINT16_TYPE__; +using uint32_t = __UINT32_TYPE__; +using uint64_t = __UINT64_TYPE__; + +using int8_t = __INT8_TYPE__; +using int16_t = __INT16_TYPE__; +using int32_t = __INT32_TYPE__; +using int64_t = __INT64_TYPE__; + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h new file mode 100644 index 000000000..e3ff0579a --- /dev/null +++ b/test/syscalls/linux/rseq/uapi.h @@ -0,0 +1,54 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ + +// User-kernel ABI for restartable sequences. + +// Standard types. +// +// N.B. This header will be included in targets that do have the standard +// library, so we can't shadow the standard type names. +using __u32 = __UINT32_TYPE__; +using __u64 = __UINT64_TYPE__; + +#ifdef __x86_64__ +// Syscall numbers. +constexpr int kRseqSyscall = 334; +#else +#error "Unknown architecture" +#endif // __x86_64__ + +struct rseq_cs { + __u32 version; + __u32 flags; + __u64 start_ip; + __u64 post_commit_offset; + __u64 abort_ip; +} __attribute__((aligned(4 * sizeof(__u64)))); + +// N.B. alignment is enforced by the kernel. +struct rseq { + __u32 cpu_id_start; + __u32 cpu_id; + struct rseq_cs* rseq_cs; + __u32 flags; +} __attribute__((aligned(4 * sizeof(__u64)))); + +constexpr int kRseqFlagUnregister = 1 << 0; + +constexpr int kRseqCPUIDUninitialized = -1; + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc index 6fd3989a4..a778fa639 100644 --- a/test/syscalls/linux/sigaltstack.cc +++ b/test/syscalls/linux/sigaltstack.cc @@ -95,13 +95,7 @@ TEST(SigaltstackTest, ResetByExecve) { auto const cleanup_sigstack = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - std::string full_path; - char* test_src = getenv("TEST_SRCDIR"); - if (test_src) { - full_path = JoinPath(test_src, "../../linux/sigaltstack_check"); - } - - ASSERT_FALSE(full_path.empty()); + std::string full_path = RunfilePath("test/syscalls/linux/sigaltstack_check"); pid_t child_pid = -1; int execve_errno = 0; diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index e4641c62e..34b1058a9 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -97,12 +97,22 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { sockets_to_close_.erase(socket_id); } - // Bind a socket with the reuse option and bind_to_device options. Checks + // SetDevice changes the bind_to_device option. It does not bind or re-bind. + void SetDevice(int socket_id, int device_id) { + auto socket_fd = sockets_to_close_[socket_id]->get(); + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + } + + // Bind a socket with the reuse options and bind_to_device options. Checks // that all steps succeed and that the bind command's error matches want. // Sets the socket_id to uniquely identify the socket bound if it is not // nullptr. - void BindSocket(bool reuse, int device_id = 0, int want = 0, - int *socket_id = nullptr) { + void BindSocket(bool reuse_port, bool reuse_addr, int device_id = 0, + int want = 0, int *socket_id = nullptr) { next_socket_id_++; sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket_fd = sockets_to_close_[next_socket_id_]->get(); @@ -110,13 +120,20 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { *socket_id = next_socket_id_; } - // If reuse is indicated, do that. - if (reuse) { + // If reuse_port is indicated, do that. + if (reuse_port) { EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); } + // If reuse_addr is indicated, do that. + if (reuse_addr) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + // If the device is non-zero, bind to that device. if (device_id != 0) { string device_name; @@ -182,129 +199,308 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { }; TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 3)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindToDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 1)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 2)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 2)); } TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithReuse) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); + BindSocket(/* reusePort */ true, /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); } TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse_port */ true, /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 999, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 999, 0)); } TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindAndRelease) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); int to_release; - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); // Release the bind to device 0 and try again. ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 345)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 345)); } TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + BindSocket(/* reusePort */ false, /* reuse_addr */ true)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, + CannotBindTo0AfterMixingReuseAddrAndNotReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReusePortThenReuseAddrReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +// This behavior seems like a bug? +TEST_P(BindToDeviceSequenceTest, + BindReuseAddrThenReuseAddrReusePortThenReuseAddr) { + // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); +} + +// Repro test for gvisor.dev/issue/1217. Not replicated in ports_test.go as this +// test is different from the others and wouldn't fit well there. +TEST_P(BindToDeviceSequenceTest, BindAndReleaseDifferentDevice) { + int to_release; + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, EADDRINUSE)); + // Change the device. Since the socket was already bound, this should have no + // effect. + SetDevice(to_release, 2); + // Release the bind to device 3 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); } INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 96a1731cf..fa4358ae4 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -635,7 +635,9 @@ INSTANTIATE_TEST_SUITE_P( using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; -TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { +// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is +// saved/restored. +TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -643,6 +645,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { sockaddr_storage listen_addr = listener.addr; sockaddr_storage conn_addr = connector.addr; constexpr int kThreadCount = 3; + constexpr int kConnectAttempts = 4096; // Create the listening socket. FileDescriptor listener_fds[kThreadCount]; @@ -657,7 +660,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { ASSERT_THAT( bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), SyscallSucceeds()); - ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + ASSERT_THAT(listen(fd, kConnectAttempts / 3), SyscallSucceeds()); // On the first bind we need to determine which port was bound. if (i != 0) { @@ -676,7 +679,6 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); } - constexpr int kConnectAttempts = 10000; std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); std::unique_ptr<ScopedThread> listen_thread[kThreadCount]; int accept_counts[kThreadCount] = {}; diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index a37b49447..c74273436 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,6 +24,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -789,5 +791,26 @@ TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { EXPECT_EQ(get, kTCPLingerTimeout); } +TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ScopedThread t([&]() { + // Close one end to trigger sending of a FIN. + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds()); + char buf[3]; + ASSERT_THAT(read(sockets->second_fd(), buf, 3), + SyscallSucceedsWithValue(3)); + absl::SleepFor(absl::Milliseconds(50)); + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + }); + + absl::SleepFor(absl::Milliseconds(50)); + // Send some data then close. + constexpr char kStr[] = "abc"; + ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3)); + t.Join(); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc index 1159c5229..a16899493 100644 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -149,6 +149,35 @@ TEST_P(UnixSocketPairCmsgTest, BadFDPass) { SyscallFailsWithErrno(EBADF)); } +TEST_P(UnixSocketPairCmsgTest, ShortCmsg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sent_fd = -1; + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(sent_fd))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = 1; + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EINVAL)); +} + // BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. // The difference is that when calling recvmsg, no space for FDs is provided, // only space for the cmsg header. diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index b6090ac66..dc35c2f50 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -527,6 +527,45 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) { SyscallFailsWithErrno(ENOTCONN)); } +TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { + struct sockaddr_storage baddr = {}; + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addr_in->sin_family = AF_INET; + addr_in->sin_port = port; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + addr_in->sin6_scope_id = 0; + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } + ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), + SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + // If the socket is bound to ANY and connected to a loopback address, + // getsockname() has to return the loopback address. + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { struct sockaddr_storage baddr = {}; socklen_t addrlen; @@ -617,6 +656,9 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { } TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); @@ -634,6 +676,9 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { } TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); @@ -839,6 +884,10 @@ TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { } TEST_P(UdpSocketTest, ReadShutdown) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + char received[512]; EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); @@ -861,6 +910,10 @@ TEST_P(UdpSocketTest, ReadShutdown) { } TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + char received[512]; EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); @@ -1150,6 +1203,10 @@ TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { } TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + int v = -1; socklen_t optlen = sizeof(v); ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), @@ -1159,6 +1216,10 @@ TEST_P(UdpSocketTest, SoTimestampOffByDefault) { } TEST_P(UdpSocketTest, SoTimestamp) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1202,6 +1263,9 @@ TEST_P(UdpSocketTest, WriteShutdownNotConnected) { } TEST_P(UdpSocketTest, TimestampIoctl) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1220,7 +1284,10 @@ TEST_P(UdpSocketTest, TimestampIoctl) { ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); } -TEST_P(UdpSocketTest, TimetstampIoctlNothingRead) { +TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1231,6 +1298,10 @@ TEST_P(UdpSocketTest, TimetstampIoctlNothingRead) { // Test that the timestamp accessed via SIOCGSTAMP is still accessible after // SO_TIMESTAMP is enabled and used to retrieve a timestamp. TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1265,7 +1336,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { msg.msg_controllen = sizeof(cmsgbuf); ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg = CMSG_FIRSTHDR(&msg); ASSERT_NE(cmsg, nullptr); // The ioctl should return the exact same values as before. @@ -1275,5 +1345,154 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { ASSERT_EQ(tv.tv_usec, tv2.tv_usec); } +// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on +// outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SetAndReceiveTOS) { + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + ASSERT_THAT( + setsockopt(s_, recv_level, recv_type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Set socket TOS. + int sent_level = recv_level; + int sent_type = IP_TOS; + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + } + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + ASSERT_THAT( + setsockopt(t_, sent_level, sent_type, &sent_tos, sizeof(sent_tos)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_type == IPV6_TCLASS) { + cmsg_data_len = sizeof(int); + } + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = received_cmsgbuf.size(); + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the +// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SendAndReceiveTOS) { + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + int recv_opt = kSockOptOn; + ASSERT_THAT( + setsockopt(s_, recv_level, recv_type, &recv_opt, sizeof(recv_opt)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + int sent_level = recv_level; + int sent_type = IP_TOS; + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + cmsg_data_len = sizeof(int); + } + std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + sent_msg.msg_control = &sent_cmsgbuf[0]; + sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + + // Manually add control message. + struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); + sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); + sent_cmsg->cmsg_level = sent_level; + sent_cmsg->cmsg_type = sent_type; + *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; + + ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index accf46347..b9fd885ff 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -46,6 +46,7 @@ var ( debug = flag.Bool("debug", false, "enable debug logs") strace = flag.Bool("strace", false, "enable strace logs") platform = flag.String("platform", "ptrace", "platform to run on") + network = flag.String("network", "none", "network stack to run on (sandbox, host, none)") useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") @@ -137,7 +138,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { args := []string{ "-root", rootDir, - "-network=none", + "-network", *network, "-log-format=text", "-TESTONLY-unsafe-nonroot=true", "-net-raw=true", @@ -335,10 +336,11 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { }) } - // Set environment variable that indicates we are - // running in gVisor and with the given platform. + // Set environment variables that indicate we are + // running in gVisor with the given platform and network. platformVar := "TEST_ON_GVISOR" - env := append(os.Environ(), platformVar+"="+*platform) + networkVar := "GVISOR_NETWORK" + env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) // Remove env variables that cause the gunit binary to write output // files, since they will stomp on eachother, and on the output files diff --git a/test/util/BUILD b/test/util/BUILD index 4526bb3f1..cbc728159 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -237,6 +237,7 @@ cc_library( ] + select_for_linux( [ "test_util_impl.cc", + "test_util_runfiles.cc", ], ), hdrs = ["test_util.h"], @@ -245,6 +246,7 @@ cc_library( ":logging", ":posix_error", ":save_util", + "@bazel_tools//tools/cpp/runfiles", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 88b1e7911..042cec94a 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -105,6 +105,15 @@ PosixErrorOr<struct stat> Stat(absl::string_view path) { return stat_buf; } +PosixErrorOr<struct stat> Lstat(absl::string_view path) { + struct stat stat_buf; + int res = lstat(std::string(path).c_str(), &stat_buf); + if (res < 0) { + return PosixError(errno, absl::StrCat("lstat ", path)); + } + return stat_buf; +} + PosixErrorOr<struct stat> Fstat(int fd) { struct stat stat_buf; int res = fstat(fd, &stat_buf); @@ -127,7 +136,7 @@ PosixErrorOr<bool> Exists(absl::string_view path) { } PosixErrorOr<bool> IsDirectory(absl::string_view path) { - ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Stat(path)); + ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Lstat(path)); if (S_ISDIR(stat_buf.st_mode)) { return true; } diff --git a/test/util/test_util.cc b/test/util/test_util.cc index 9cb050735..848504c88 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -41,6 +41,7 @@ namespace gvisor { namespace testing { #define TEST_ON_GVISOR "TEST_ON_GVISOR" +#define GVISOR_NETWORK "GVISOR_NETWORK" bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; } @@ -60,6 +61,11 @@ Platform GvisorPlatform() { abort(); } +bool IsRunningWithHostinet() { + char* env = getenv(GVISOR_NETWORK); + return env && strcmp(env, "host") == 0; +} + // Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations // %ebx contains the address of the global offset table. %rbx is occasionally // used to address stack variables in presence of dynamic allocas. diff --git a/test/util/test_util.h b/test/util/test_util.h index dc30575b8..b3235c7e3 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -220,6 +220,7 @@ enum class Platform { }; bool IsRunningOnGvisor(); Platform GvisorPlatform(); +bool IsRunningWithHostinet(); #ifdef __linux__ void SetupGvisorDeathTest(); @@ -764,6 +765,12 @@ MATCHER_P2(EquivalentWithin, target, tolerance, return Equivalent(arg, target, tolerance); } +// Returns the absolute path to the a data dependency. 'path' is the runfile +// location relative to workspace root. +#ifdef __linux__ +std::string RunfilePath(std::string path); +#endif + void TestInit(int* argc, char*** argv); } // namespace testing diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc new file mode 100644 index 000000000..7210094eb --- /dev/null +++ b/test/util/test_util_runfiles.cc @@ -0,0 +1,46 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <iostream> +#include <string> + +#include "test/util/fs_util.h" +#include "test/util/test_util.h" +#include "tools/cpp/runfiles/runfiles.h" + +namespace gvisor { +namespace testing { + +std::string RunfilePath(std::string path) { + static const bazel::tools::cpp::runfiles::Runfiles* const runfiles = [] { + std::string error; + auto* runfiles = + bazel::tools::cpp::runfiles::Runfiles::CreateForTest(&error); + if (runfiles == nullptr) { + std::cerr << "Unable to find runfiles: " << error << std::endl; + } + return runfiles; + }(); + + if (!runfiles) { + // Can't find runfiles? This probably won't work, but __main__/path is our + // best guess. + return JoinPath("__main__", path); + } + + return runfiles->Rlocation(JoinPath("__main__", path)); +} + +} // namespace testing +} // namespace gvisor |