summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2020-02-06 17:07:04 -0800
committerKevin Krakauer <krakauer@google.com>2020-02-06 17:07:04 -0800
commitd98287f5eb40a9c91668b7511824c05d542e0599 (patch)
treef8430747db6e1c02fe0eb45a7dad0899d06bd072
parentbf0ea204e9415a181c63ee10078cca753df14f7e (diff)
parent16561e461e82f8d846ef1f3ada990270ef39ccc6 (diff)
Merge branch 'master' into tcp-matchers-submit
-rw-r--r--benchmarks/BUILD8
-rw-r--r--benchmarks/README.md21
-rw-r--r--benchmarks/harness/BUILD21
-rw-r--r--benchmarks/harness/__init__.py36
-rw-r--r--benchmarks/harness/machine.py43
-rw-r--r--benchmarks/harness/machine_producers/BUILD1
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py114
-rw-r--r--benchmarks/harness/ssh_connection.py25
-rw-r--r--benchmarks/runner/__init__.py75
-rw-r--r--benchmarks/runner/commands.py70
-rw-r--r--benchmarks/tcp/tcp_proxy.go2
-rw-r--r--pkg/metric/metric.go1
-rw-r--r--pkg/p9/BUILD3
-rw-r--r--pkg/p9/client.go9
-rw-r--r--pkg/pool/BUILD25
-rw-r--r--pkg/pool/pool.go (renamed from pkg/p9/pool.go)26
-rw-r--r--pkg/pool/pool_test.go (renamed from pkg/p9/pool_test.go)8
-rw-r--r--pkg/sentry/arch/arch_x86.go4
-rw-r--r--pkg/sentry/arch/signal_amd64.go2
-rw-r--r--pkg/sentry/fs/file_overlay_test.go1
-rw-r--r--pkg/sentry/fs/proc/README.md4
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/kernel.go3
-rw-r--r--pkg/sentry/kernel/kernel_opts.go20
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go5
-rw-r--r--pkg/sentry/socket/hostinet/sockopt_impl.go27
-rw-r--r--pkg/sentry/socket/netstack/netstack.go7
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go111
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go22
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go44
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go186
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go278
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/ndp_test.go82
-rw-r--r--pkg/tcpip/stack/nic.go84
-rw-r--r--pkg/tcpip/stack/nic_test.go62
-rw-r--r--pkg/tcpip/stack/stack_test.go115
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go9
-rw-r--r--runsc/boot/filter/BUILD1
-rw-r--r--runsc/boot/filter/config.go13
-rw-r--r--runsc/boot/filter/config_profile.go34
-rw-r--r--runsc/container/console_test.go5
-rw-r--r--runsc/dockerutil/dockerutil.go11
-rw-r--r--runsc/testutil/BUILD5
-rw-r--r--runsc/testutil/testutil.go61
-rw-r--r--runsc/testutil/testutil_runfiles.go75
-rw-r--r--test/image/image_test.go8
-rw-r--r--test/iptables/README.md2
-rw-r--r--test/syscalls/build_defs.bzl35
-rw-r--r--test/syscalls/linux/32bit.cc2
-rw-r--r--test/syscalls/linux/chroot.cc2
-rw-r--r--test/syscalls/linux/concurrency.cc3
-rw-r--r--test/syscalls/linux/exec_proc_exe_workload.cc6
-rw-r--r--test/syscalls/linux/fork.cc5
-rw-r--r--test/syscalls/linux/mmap.cc8
-rw-r--r--test/syscalls/linux/open_create.cc1
-rw-r--r--test/syscalls/linux/preadv.cc1
-rw-r--r--test/syscalls/linux/proc.cc46
-rw-r--r--test/syscalls/linux/readv.cc4
-rw-r--r--test/syscalls/linux/rseq.cc2
-rw-r--r--test/syscalls/linux/select.cc2
-rw-r--r--test/syscalls/linux/shm.cc2
-rw-r--r--test/syscalls/linux/sigprocmask.cc2
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc4
-rw-r--r--test/syscalls/linux/symlink.cc2
-rw-r--r--test/syscalls/linux/tcp_socket.cc12
-rw-r--r--test/syscalls/linux/time.cc1
-rw-r--r--test/syscalls/linux/tkill.cc2
-rw-r--r--test/util/temp_path.cc1
-rw-r--r--tools/build/tags.bzl4
-rw-r--r--tools/defs.bzl17
-rw-r--r--tools/images/defs.bzl5
-rw-r--r--tools/installers/BUILD7
-rwxr-xr-xtools/installers/head.sh2
77 files changed, 1358 insertions, 616 deletions
diff --git a/benchmarks/BUILD b/benchmarks/BUILD
index 1455c6c5b..43614cf5d 100644
--- a/benchmarks/BUILD
+++ b/benchmarks/BUILD
@@ -3,8 +3,16 @@ package(licenses = ["notice"])
py_binary(
name = "benchmarks",
srcs = ["run.py"],
+ data = [
+ "//tools/images:ubuntu1604",
+ "//tools/images:zone",
+ ],
main = "run.py",
python_version = "PY3",
srcs_version = "PY3",
+ tags = [
+ "local",
+ "manual",
+ ],
deps = ["//benchmarks/runner"],
)
diff --git a/benchmarks/README.md b/benchmarks/README.md
index ff21614c5..975321c99 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -26,6 +26,8 @@ For configuring the environment manually, consult the
## Running benchmarks
+### Locally
+
Run the following from the benchmarks directory:
```bash
@@ -44,7 +46,7 @@ runtime, runc. Running on another installed runtime, like say runsc, is as
simple as:
```bash
-bazel run :benchmakrs -- run-local startup --runtime=runsc
+bazel run :benchmarks -- run-local startup --runtime=runsc
```
There is help: ``bash bash bazel run :benchmarks -- --help bazel
@@ -104,6 +106,23 @@ Or with different parameters:
bazel run :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu
```
+### On Google Compute Engine (GCE)
+
+Benchmarks may be run on GCE in an automated way. The default project configured
+for `gcloud` will be used.
+
+An additional parameter `installers` may be provided to ensure that the latest
+runtime is installed from the workspace. See the files in `tools/installers` for
+supported install targets.
+
+```bash
+bazel run :benchmarks -- run-gcp --installers=head --runtime=runsc sysbench.cpu
+```
+
+When running on GCE, the scripts generate a per run SSH key, which is added to
+your project. The key is set to expire in GCE after 60 minutes and is stored in
+a temporary directory on the local machine running the scripts.
+
## Writing benchmarks
To write new benchmarks, you should familiarize yourself with the structure of
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
index 52d4e42f8..4d03e3a06 100644
--- a/benchmarks/harness/BUILD
+++ b/benchmarks/harness/BUILD
@@ -1,3 +1,4 @@
+load("//tools:defs.bzl", "pkg_tar")
load("//tools:defs.bzl", "py_library", "py_requirement")
package(
@@ -5,9 +6,29 @@ package(
licenses = ["notice"],
)
+pkg_tar(
+ name = "installers",
+ srcs = [
+ "//tools/installers:head",
+ "//tools/installers:master",
+ "//tools/installers:runsc",
+ ],
+ mode = "0755",
+)
+
+filegroup(
+ name = "files",
+ srcs = [
+ ":installers",
+ ],
+)
+
py_library(
name = "harness",
srcs = ["__init__.py"],
+ data = [
+ ":files",
+ ],
)
py_library(
diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py
index 61fd25f73..15aa2a69a 100644
--- a/benchmarks/harness/__init__.py
+++ b/benchmarks/harness/__init__.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,18 +15,48 @@
import getpass
import os
+import subprocess
+import tempfile
# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a
# format string that accepts a single string parameter.
-LOCAL_WORKLOADS_PATH = os.path.join(
- os.path.dirname(__file__), "../workloads/{}/tar.tar")
+LOCAL_WORKLOADS_PATH = os.path.dirname(__file__) + "/../workloads/{}/tar.tar"
# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the
# remote host. This is a format string that accepts a single string parameter.
REMOTE_WORKLOADS_PATH = "workloads/{}"
+# INSTALLER_ROOT is the set of files that needs to be copied.
+INSTALLER_ARCHIVE = os.readlink(os.path.join(
+ os.path.dirname(__file__), "installers.tar"))
+
+# SSH_KEY_DIR holds SSH_PRIVATE_KEY for this run. bm-tools paramiko requires
+# keys generated with the '-t rsa -m PEM' options from ssh-keygen. This is
+# abstracted away from the user.
+SSH_KEY_DIR = tempfile.TemporaryDirectory()
+SSH_PRIVATE_KEY = "key"
+
# DEFAULT_USER is the default user running this script.
DEFAULT_USER = getpass.getuser()
# DEFAULT_USER_HOME is the home directory of the user running the script.
DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else ""
+
+# Default directory to remotely installer "installer" targets.
+REMOTE_INSTALLERS_PATH = "installers"
+
+
+def make_key():
+ """Wraps a valid ssh key in a temporary directory."""
+ path = os.path.join(SSH_KEY_DIR.name, SSH_PRIVATE_KEY)
+ if not os.path.exists(path):
+ cmd = "ssh-keygen -t rsa -m PEM -b 4096 -f {key} -q -N".format(
+ key=path).split(" ")
+ cmd.append("")
+ subprocess.run(cmd, check=True)
+ return path
+
+
+def delete_key():
+ """Deletes temporary directory containing private key."""
+ SSH_KEY_DIR.cleanup()
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
index 2df4c9e31..3d32d3dda 100644
--- a/benchmarks/harness/machine.py
+++ b/benchmarks/harness/machine.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,10 +29,11 @@ to run contianers.
"""
import logging
+import os
import re
import subprocess
import time
-from typing import Tuple
+from typing import List, Tuple
import docker
@@ -201,6 +202,7 @@ class RemoteMachine(Machine):
self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs)
self._tunnel.connect()
self._docker_client = self._tunnel.get_docker_client()
+ self._has_installers = False
def run(self, cmd: str) -> Tuple[str, str]:
return self._ssh_connection.run(cmd)
@@ -210,14 +212,45 @@ class RemoteMachine(Machine):
stdout, stderr = self._ssh_connection.run("cat '{}'".format(path))
return stdout + stderr
+ def install(self,
+ installer: str,
+ results: List[bool] = None,
+ index: int = -1):
+ """Method unique to RemoteMachine to handle installation of installers.
+
+ Handles installers, which install things that may change between runs (e.g.
+ runsc). Usually called from gcloud_producer, which expects this method to
+ to store results.
+
+ Args:
+ installer: the installer target to run.
+ results: Passed by the caller of where to store success.
+ index: Index for this method to store the result in the passed results
+ list.
+ """
+ # This generates a tarball of the full installer root (which will generate
+ # be the full bazel root directory) and sends it over.
+ if not self._has_installers:
+ archive = self._ssh_connection.send_installers()
+ self.run("tar -xvf {archive} -C {dir}".format(
+ archive=archive, dir=harness.REMOTE_INSTALLERS_PATH))
+ self._has_installers = True
+
+ # Execute the remote installer.
+ self.run("sudo {dir}/{file}".format(
+ dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+ if results:
+ results[index] = True
+
def pull(self, workload: str) -> str:
# Push to the remote machine and build.
logging.info("Building %s@%s remotely...", workload, self._name)
remote_path = self._ssh_connection.send_workload(workload)
+ remote_dir = os.path.dirname(remote_path)
# Workloads are all tarballs.
- self.run("tar -xvf {remote_path}/tar.tar -C {remote_path}".format(
- remote_path=remote_path))
- self.run("docker build --tag={} {}".format(workload, remote_path))
+ self.run("tar -xvf {remote_path} -C {remote_dir}".format(
+ remote_path=remote_path, remote_dir=remote_dir))
+ self.run("docker build --tag={} {}".format(workload, remote_dir))
return workload # Workload is the tag.
def container(self, image: str, **kwargs) -> container.Container:
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
index 48ea0ef39..3711a397f 100644
--- a/benchmarks/harness/machine_producers/BUILD
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -76,5 +76,6 @@ py_test(
python_version = "PY3",
tags = [
"local",
+ "manual",
],
)
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
index e0b77d52b..513d16e4f 100644
--- a/benchmarks/harness/machine_producers/gcloud_producer.py
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,12 +46,11 @@ class GCloudProducer(machine_producer.MachineProducer):
Produces Machine objects backed by GCP instances.
Attributes:
- project: The GCP project name under which to create the machines.
- ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys.
image: image name as a string.
- image_project: image project as a string.
- machine_type: type of GCP to create. e.g. n1-standard-4
zone: string to a valid GCP zone.
+ machine_type: type of GCP to create (e.g. n1-standard-4).
+ installers: list of installers post-boot.
+ ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys.
ssh_user: string of user name for ssh_key
ssh_password: string of password for ssh key
mock: a mock printer which will print mock data if required. Mock data is
@@ -60,21 +59,19 @@ class GCloudProducer(machine_producer.MachineProducer):
"""
def __init__(self,
- project: str,
- ssh_key_file: str,
image: str,
- image_project: str,
- machine_type: str,
zone: str,
+ machine_type: str,
+ installers: List[str],
+ ssh_key_file: str,
ssh_user: str,
ssh_password: str,
mock: gcloud_mock_recorder.MockPrinter = None):
- self.project = project
- self.ssh_key_file = ssh_key_file
self.image = image
- self.image_project = image_project
- self.machine_type = machine_type
self.zone = zone
+ self.machine_type = machine_type
+ self.installers = installers
+ self.ssh_key_file = ssh_key_file
self.ssh_user = ssh_user
self.ssh_password = ssh_password
self.mock = mock
@@ -87,10 +84,34 @@ class GCloudProducer(machine_producer.MachineProducer):
"Cannot ask for {num} machines!".format(num=num_machines))
with self.condition:
names = self._get_unique_names(num_machines)
- self._build_instances(names)
- instances = self._start_command(names)
+ instances = self._build_instances(names)
self._add_ssh_key_to_instances(names)
- return self._machines_from_instances(instances)
+ machines = self._machines_from_instances(instances)
+
+ # Install all bits in lock-step.
+ #
+ # This will perform paralell installations for however many machines we
+ # have, but it's easy to track errors because if installing (a, b, c), we
+ # won't install "c" until "b" is installed on all machines.
+ for installer in self.installers:
+ threads = [None] * len(machines)
+ results = [False] * len(machines)
+ for i in range(len(machines)):
+ threads[i] = threading.Thread(
+ target=machines[i].install, args=(installer, results, i))
+ threads[i].start()
+ for thread in threads:
+ thread.join()
+ for result in results:
+ if not result:
+ raise NotImplementedError(
+ "Installers failed on at least one machine!")
+
+ # Add this user to each machine's docker group.
+ for m in machines:
+ m.run("sudo setfacl -m user:$USER:rw /var/run/docker.sock")
+
+ return machines
def release_machines(self, machine_list: List[machine.Machine]):
"""Releases the requested number of machines, deleting the instances."""
@@ -123,15 +144,7 @@ class GCloudProducer(machine_producer.MachineProducer):
def _get_unique_names(self, num_names) -> List[str]:
"""Returns num_names unique names based on data from the GCP project."""
- curr_machines = self._list_machines()
- curr_names = set([machine["name"] for machine in curr_machines])
- ret = []
- while len(ret) < num_names:
- new_name = "machine-" + str(uuid.uuid4())
- if new_name not in curr_names:
- ret.append(new_name)
- curr_names.update(new_name)
- return ret
+ return ["machine-" + str(uuid.uuid4()) for _ in range(0, num_names)]
def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]:
"""Creates instances using gcloud command.
@@ -151,34 +164,9 @@ class GCloudProducer(machine_producer.MachineProducer):
"_build_instances cannot create instances without names.")
cmd = "gcloud compute instances create".split(" ")
cmd.extend(names)
- cmd.extend(
- "--preemptible --image={image} --zone={zone} --machine-type={machine_type}"
- .format(
- image=self.image, zone=self.zone,
- machine_type=self.machine_type).split(" "))
- if self.image_project:
- cmd.append("--image-project={project}".format(project=self.image_project))
- res = self._run_command(cmd)
- return json.loads(res.stdout)
-
- def _start_command(self, names):
- """Starts instances using gcloud command.
-
- Runs the command `gcloud compute instances start` on list of instances by
- name and returns json data on started instances on success.
-
- Args:
- names: list of names of instances to start.
-
- Returns:
- List of json data describing started machines.
- """
- if not names:
- raise ValueError("_start_command cannot start empty instance list.")
- cmd = "gcloud compute instances start".split(" ")
- cmd.extend(names)
- cmd.append("--zone={zone}".format(zone=self.zone))
- cmd.append("--project={project}".format(project=self.project))
+ cmd.append("--image=" + self.image)
+ cmd.append("--zone=" + self.zone)
+ cmd.append("--machine-type=" + self.machine_type)
res = self._run_command(cmd)
return json.loads(res.stdout)
@@ -186,7 +174,7 @@ class GCloudProducer(machine_producer.MachineProducer):
"""Adds ssh key to instances by calling gcloud ssh command.
Runs the command `gcloud compute ssh instance_name` on list of images by
- name. Tries to ssh into given instance
+ name. Tries to ssh into given instance.
Args:
names: list of machine names to which to add the ssh-key
@@ -202,30 +190,18 @@ class GCloudProducer(machine_producer.MachineProducer):
cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
cmd.append("--zone={zone}".format(zone=self.zone))
cmd.append("--command=uname")
+ cmd.append("--ssh-key-expire-after=60m")
timeout = datetime.timedelta(seconds=5 * 60)
start = datetime.datetime.now()
while datetime.datetime.now() <= timeout + start:
try:
self._run_command(cmd)
break
- except subprocess.CalledProcessError as e:
+ except subprocess.CalledProcessError:
if datetime.datetime.now() > timeout + start:
raise TimeoutError(
"Could not SSH into instance after 5 min: {name}".format(
name=name))
- # 255 is the returncode for ssh connection refused.
- elif e.returncode == 255:
-
- continue
- else:
- raise e
-
- def _list_machines(self) -> List[Dict[str, Any]]:
- """Runs `list` gcloud command and returns list of Machine data."""
- cmd = "gcloud compute instances list --project {project}".format(
- project=self.project).split(" ")
- res = self._run_command(cmd)
- return json.loads(res.stdout)
def _run_command(self,
cmd: List[str],
@@ -261,7 +237,7 @@ class GCloudProducer(machine_producer.MachineProducer):
self.mock.record(res)
if res.returncode != 0:
raise subprocess.CalledProcessError(
- cmd=res.args,
+ cmd=" ".join(res.args),
output=res.stdout,
stderr=res.stderr,
returncode=res.returncode)
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
index e0bf258f1..a50e34293 100644
--- a/benchmarks/harness/ssh_connection.py
+++ b/benchmarks/harness/ssh_connection.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
"""SSHConnection handles the details of SSH connections."""
+
import os
import warnings
@@ -24,18 +25,24 @@ from benchmarks import harness
warnings.filterwarnings(action="ignore", module=".*paramiko.*")
-def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str):
+def send_one_file(client: paramiko.SSHClient, path: str,
+ remote_dir: str) -> str:
"""Sends a single file via an SSH client.
Args:
client: The existing SSH client.
path: The local path.
remote_dir: The remote directory.
+
+ Returns:
+ :return: The remote path as a string.
"""
filename = path.split("/").pop()
- client.exec_command("mkdir -p " + remote_dir)
+ if remote_dir != ".":
+ client.exec_command("mkdir -p " + remote_dir)
with client.open_sftp() as ftp_client:
ftp_client.put(path, os.path.join(remote_dir, filename))
+ return os.path.join(remote_dir, filename)
class SSHConnection:
@@ -103,6 +110,12 @@ class SSHConnection:
The remote path.
"""
with self._client() as client:
- send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
- harness.REMOTE_WORKLOADS_PATH.format(name))
- return harness.REMOTE_WORKLOADS_PATH.format(name)
+ return send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
+ harness.REMOTE_WORKLOADS_PATH.format(name))
+
+ def send_installers(self) -> str:
+ with self._client() as client:
+ return send_one_file(
+ client,
+ path=harness.INSTALLER_ARCHIVE,
+ remote_dir=harness.REMOTE_INSTALLERS_PATH)
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
index ba80d83d7..ba27dc69f 100644
--- a/benchmarks/runner/__init__.py
+++ b/benchmarks/runner/__init__.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,10 @@
import copy
import csv
-import json
import logging
-import os
import pkgutil
import pydoc
import re
-import subprocess
import sys
import types
from typing import List
@@ -123,57 +120,29 @@ def run_mock(ctx, **kwargs):
@runner.command("run-gcp", commands.GCPCommand)
@click.pass_context
-def run_gcp(ctx, project: str, ssh_key_file: str, image: str,
- image_project: str, machine_type: str, zone: str, ssh_user: str,
- ssh_password: str, **kwargs):
+def run_gcp(ctx, image_file: str, zone_file: str, machine_type: str,
+ installers: List[str], **kwargs):
"""Runs all benchmarks on GCP instances."""
- if not ssh_user:
- ssh_user = harness.DEFAULT_USER
-
- # Get the default project if one was not provided.
- if not project:
- sub = subprocess.run(
- "gcloud config get-value project".split(" "), stdout=subprocess.PIPE)
- if sub.returncode:
- raise ValueError(
- "Cannot get default project from gcloud. Is it configured>")
- project = sub.stdout.decode("utf-8").strip("\n")
-
- if not image_project:
- image_project = project
-
- # Check that the ssh-key exists and is readable.
- if not os.access(ssh_key_file, os.R_OK):
- raise ValueError(
- "ssh key given `{ssh_key}` is does not exist or is not readable."
- .format(ssh_key=ssh_key_file))
-
- # Check that the image exists.
- sub = subprocess.run(
- "gcloud compute images describe {image} --project {image_project} --format=json"
- .format(image=image, image_project=image_project).split(" "),
- stdout=subprocess.PIPE)
- if sub.returncode or "READY" not in json.loads(sub.stdout)["status"]:
- raise ValueError(
- "given image was not found or is not ready: {image} {image_project}."
- .format(image=image, image_project=image_project))
-
- # Check and set zone to default.
- if not zone:
- sub = subprocess.run(
- "gcloud config get-value compute/zone".split(" "),
- stdout=subprocess.PIPE)
- if sub.returncode:
- raise ValueError(
- "Default zone is not set in gcloud. Set one or pass a zone with the --zone flag."
- )
- zone = sub.stdout.decode("utf-8").strip("\n")
-
- producer = gcloud_producer.GCloudProducer(project, ssh_key_file, image,
- image_project, machine_type, zone,
- ssh_user, ssh_password)
- run(ctx, producer, **kwargs)
+ # Resolve all files.
+ image = open(image_file).read().rstrip()
+ zone = open(zone_file).read().rstrip()
+
+ key_file = harness.make_key()
+
+ producer = gcloud_producer.GCloudProducer(
+ image,
+ zone,
+ machine_type,
+ installers,
+ ssh_key_file=key_file,
+ ssh_user=harness.DEFAULT_USER,
+ ssh_password="")
+
+ try:
+ run(ctx, producer, **kwargs)
+ finally:
+ harness.delete_key()
def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int,
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
index 7ab12fac6..0fccb2fad 100644
--- a/benchmarks/runner/commands.py
+++ b/benchmarks/runner/commands.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,9 @@ def run_mock(**kwargs):
# mock implementation
"""
-import click
+import os
-from benchmarks import harness
+import click
class RunCommand(click.core.Command):
@@ -90,46 +90,40 @@ class GCPCommand(RunCommand):
"""GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method.
Attributes:
- project: GCP project
- ssh_key_path: path to the ssh-key to use for the run
- image: name of the image to build machines from
- image_project: GCP project under which to find image
- zone: a GCP zone (e.g. us-west1-b)
- ssh_user: username to use for the ssh-key
- ssh_password: password to use for the ssh-key
+ image_file: name of the image to build machines from
+ zone_file: a GCP zone (e.g. us-west1-b)
+ installers: named installers for post-create
+ machine_type: type of machine to create (e.g. n1-standard-4)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- project = click.core.Option(
- ("--project",),
- help="Project to run on if not default value given by 'gcloud config get-value project'."
+ image_file = click.core.Option(
+ ("--image_file",),
+ help="The file containing the image for VMs.",
+ default=os.path.join(
+ os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"),
+ )
+ zone_file = click.core.Option(
+ ("--zone_file",),
+ help="The file containing the GCP zone.",
+ default=os.path.join(
+ os.path.dirname(__file__), "../../tools/images/zone.txt"),
+ )
+ installers = click.core.Option(
+ ("--installers",),
+ help="The set of installers to use.",
+ multiple=True,
+ )
+ machine_type = click.core.Option(
+ ("--machine_type",),
+ help="Type to make all machines.",
+ default="n1-standard-4",
)
- ssh_key_path = click.core.Option(
- ("--ssh-key-file",),
- help="Path to a valid ssh private key to use. See README on generating a valid ssh key. Set to ~/.ssh/benchmark-tools by default.",
- default=harness.DEFAULT_USER_HOME + "/.ssh/benchmark-tools")
- image = click.core.Option(("--image",),
- help="The image on which to build VMs.",
- default="bm-tools-testing")
- image_project = click.core.Option(
- ("--image_project",),
- help="The project under which the image to be used is listed.",
- default="")
- machine_type = click.core.Option(("--machine_type",),
- help="Type to make all machines.",
- default="n1-standard-4")
- zone = click.core.Option(("--zone",),
- help="The GCP zone to run on.",
- default="")
- ssh_user = click.core.Option(("--ssh-user",),
- help="User for the ssh key.",
- default=harness.DEFAULT_USER)
- ssh_password = click.core.Option(("--ssh-password",),
- help="Password for the ssh key.",
- default="")
self.params.extend([
- project, ssh_key_path, image, image_project, machine_type, zone,
- ssh_user, ssh_password
+ image_file,
+ zone_file,
+ machine_type,
+ installers,
])
diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go
index 72ada5700..73b7c4f5b 100644
--- a/benchmarks/tcp/tcp_proxy.go
+++ b/benchmarks/tcp/tcp_proxy.go
@@ -274,7 +274,7 @@ func (n netstackImpl) listen(port int) (net.Listener, error) {
NIC: nicID,
Port: uint16(port),
}
- listener, err := gonet.NewListener(n.s, addr, ipv4.ProtocolNumber)
+ listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, err
}
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index 93d4f2b8c..006fcd9ab 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -46,7 +46,6 @@ var (
//
// TODO(b/67298402): Support non-cumulative metrics.
// TODO(b/67298427): Support metric fields.
-//
type Uint64Metric struct {
// value is the actual value of the metric. It must be accessed
// atomically.
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index 4ccc1de86..8904afad9 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -16,7 +16,6 @@ go_library(
"messages.go",
"p9.go",
"path_tree.go",
- "pool.go",
"server.go",
"transport.go",
"transport_flipcall.go",
@@ -27,6 +26,7 @@ go_library(
"//pkg/fdchannel",
"//pkg/flipcall",
"//pkg/log",
+ "//pkg/pool",
"//pkg/sync",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
@@ -41,7 +41,6 @@ go_test(
"client_test.go",
"messages_test.go",
"p9_test.go",
- "pool_test.go",
"transport_test.go",
"version_test.go",
],
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 4045e41fa..a6f493b82 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -22,6 +22,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -74,10 +75,10 @@ type Client struct {
socket *unet.Socket
// tagPool is the collection of available tags.
- tagPool pool
+ tagPool pool.Pool
// fidPool is the collection of available fids.
- fidPool pool
+ fidPool pool.Pool
// messageSize is the maximum total size of a message.
messageSize uint32
@@ -155,8 +156,8 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
}
c := &Client{
socket: socket,
- tagPool: pool{start: 1, limit: uint64(NoTag)},
- fidPool: pool{start: 1, limit: uint64(NoFID)},
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
pending: make(map[Tag]*response),
recvr: make(chan bool, 1),
messageSize: messageSize,
diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD
new file mode 100644
index 000000000..7b1c6b75b
--- /dev/null
+++ b/pkg/pool/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "pool",
+ srcs = [
+ "pool.go",
+ ],
+ deps = [
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "pool_test",
+ size = "small",
+ srcs = [
+ "pool_test.go",
+ ],
+ library = ":pool",
+)
diff --git a/pkg/p9/pool.go b/pkg/pool/pool.go
index 2b14a5ce3..a1b2e0cfe 100644
--- a/pkg/p9/pool.go
+++ b/pkg/pool/pool.go
@@ -12,33 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"gvisor.dev/gvisor/pkg/sync"
)
-// pool is a simple allocator.
-//
-// It is used for both tags and FIDs.
-type pool struct {
+// Pool is a simple allocator.
+type Pool struct {
mu sync.Mutex
// cache is the set of returned values.
cache []uint64
- // start is the starting value (if needed).
- start uint64
+ // Start is the starting value (if needed).
+ Start uint64
// max is the current maximum issued.
max uint64
- // limit is the upper limit.
- limit uint64
+ // Limit is the upper limit.
+ Limit uint64
}
// Get gets a value from the pool.
-func (p *pool) Get() (uint64, bool) {
+func (p *Pool) Get() (uint64, bool) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -50,18 +48,18 @@ func (p *pool) Get() (uint64, bool) {
}
// Over the limit?
- if p.start == p.limit {
+ if p.Start == p.Limit {
return 0, false
}
// Generate a new value.
- v := p.start
- p.start++
+ v := p.Start
+ p.Start++
return v, true
}
// Put returns a value to the pool.
-func (p *pool) Put(v uint64) {
+func (p *Pool) Put(v uint64) {
p.mu.Lock()
p.cache = append(p.cache, v)
p.mu.Unlock()
diff --git a/pkg/p9/pool_test.go b/pkg/pool/pool_test.go
index e4746b8da..d928439c1 100644
--- a/pkg/p9/pool_test.go
+++ b/pkg/pool/pool_test.go
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"testing"
)
func TestPoolUnique(t *testing.T) {
- p := pool{start: 1, limit: 3}
+ p := Pool{Start: 1, Limit: 3}
got := make(map[uint64]bool)
for {
@@ -39,7 +39,7 @@ func TestPoolUnique(t *testing.T) {
}
func TestExausted(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
for i := 0; i < 499; i++ {
_, ok := p.Get()
if !ok {
@@ -54,7 +54,7 @@ func TestExausted(t *testing.T) {
}
func TestPoolRecycle(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
n1, _ := p.Get()
p.Put(n1)
n2, _ := p.Get()
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index a18093155..3db8bd34b 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -114,6 +114,10 @@ func newX86FPStateSlice() []byte {
size, align := cpuid.HostFeatureSet().ExtendedStateSize()
capacity := size
// Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
if capacity < 4096 {
capacity = 4096
}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 81b92bb43..6fb756f0e 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -55,7 +55,7 @@ type SignalContext64 struct {
Trapno uint64
Oldmask linux.SignalSet
Cr2 uint64
- // Pointer to a struct _fpstate.
+ // Pointer to a struct _fpstate. See b/33003106#comment8.
Fpstate uint64
Reserved [8]uint64
}
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
index 02538bb4f..a76d87e3a 100644
--- a/pkg/sentry/fs/file_overlay_test.go
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -177,6 +177,7 @@ func TestReaddirRevalidation(t *testing.T) {
// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with
// a frozen dirent tree does not make Readdir calls to the underlying files.
+// This is a regression test for b/114808269.
func TestReaddirOverlayFrozen(t *testing.T) {
ctx := contexttest.Context(t)
diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md
index 5d4ec6c7b..6667a0916 100644
--- a/pkg/sentry/fs/proc/README.md
+++ b/pkg/sentry/fs/proc/README.md
@@ -11,6 +11,8 @@ inconsistency, please file a bug.
The following files are implemented:
+<!-- mdformat off(don't wrap the table) -->
+
| File /proc/ | Content |
| :------------------------ | :---------------------------------------------------- |
| [cpuinfo](#cpuinfo) | Info about the CPU |
@@ -22,6 +24,8 @@ The following files are implemented:
| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus |
| [version](#version) | Kernel version |
+<!-- mdformat on -->
+
### cpuinfo
```bash
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a27628c0a..2231d6973 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -91,6 +91,7 @@ go_library(
"fs_context.go",
"ipc_namespace.go",
"kernel.go",
+ "kernel_opts.go",
"kernel_state.go",
"pending_signals.go",
"pending_signals_list.go",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index dcd6e91c4..3ee760ba2 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -235,6 +235,9 @@ type Kernel struct {
// events. This is initialized lazily on the first unimplemented
// syscall.
unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"`
+
+ // SpecialOpts contains special kernel options.
+ SpecialOpts
}
// InitKernelArgs holds arguments to Init.
diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go
new file mode 100644
index 000000000..2e66ec587
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_opts.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// SpecialOpts contains non-standard options for the kernel.
+//
+// +stateify savable
+type SpecialOpts struct{}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 5a07d5d0e..023bad156 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -10,6 +10,7 @@ go_library(
"save_restore.go",
"socket.go",
"socket_unsafe.go",
+ "sockopt_impl.go",
"stack.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 34f63986f..de76388ac 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -285,7 +285,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
}
// Whitelist options and constrain option length.
- var optlen int
+ optlen := getSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
switch name {
@@ -330,7 +330,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Whitelist options and constrain option length.
- var optlen int
+ optlen := setSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
switch name {
@@ -353,6 +353,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
optlen = sizeofInt32
}
}
+
if optlen == 0 {
// Pretend to accept socket options we don't understand. This seems
// dangerous, but it's what netstack does...
diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go
new file mode 100644
index 000000000..8a783712e
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/sockopt_impl.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+func getSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
+
+func setSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 049d04bf2..ed2fbcceb 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2229,11 +2229,16 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
var copied int
// Copy as many views as possible into the user-provided buffer.
- for dst.NumBytes() != 0 {
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
err = s.fetchReadView()
if err != nil {
break
}
+ if dst.NumBytes() == 0 {
+ break
+ }
var n int
var e error
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 711969b9b..6e0db2741 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
-// A Listener is a wrapper around a tcpip endpoint that implements
+// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
// net.Listener.
-type Listener struct {
+type TCPListener struct {
stack *stack.Stack
ep tcpip.Endpoint
wq *waiter.Queue
cancel chan struct{}
}
-// NewListener creates a new Listener.
-func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
- // Create TCP endpoint, bind it, then start listening.
+// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
+func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
+ return &TCPListener{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ cancel: make(chan struct{}),
+ }
+}
+
+// ListenTCP creates a new TCPListener.
+func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
+ // Create a TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
@@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr
}
}
- return &Listener{
- stack: s,
- ep: ep,
- wq: &wq,
- cancel: make(chan struct{}),
- }, nil
+ return NewTCPListener(s, &wq, ep), nil
}
// Close implements net.Listener.Close.
-func (l *Listener) Close() error {
+func (l *TCPListener) Close() error {
l.ep.Close()
return nil
}
// Shutdown stops the HTTP server.
-func (l *Listener) Shutdown() {
+func (l *TCPListener) Shutdown() {
l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
close(l.cancel) // broadcast cancellation
}
// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
+func (l *TCPListener) Addr() net.Addr {
a, err := l.ep.GetLocalAddress()
if err != nil {
return nil
@@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error {
return nil
}
-// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
// interface.
-type Conn struct {
+type TCPConn struct {
deadlineTimer
wq *waiter.Queue
@@ -228,9 +233,9 @@ type Conn struct {
read buffer.View
}
-// NewConn creates a new Conn.
-func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
- c := &Conn{
+// NewTCPConn creates a new TCPConn.
+func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
+ c := &TCPConn{
wq: wq,
ep: ep,
}
@@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
}
// Accept implements net.Conn.Accept.
-func (l *Listener) Accept() (net.Conn, error) {
+func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept()
if err == tcpip.ErrWouldBlock {
@@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) {
}
}
- return NewConn(wq, n), nil
+ return NewTCPConn(wq, n), nil
}
type opErrorer interface {
@@ -323,7 +328,7 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a
}
// Read implements net.Conn.Read.
-func (c *Conn) Read(b []byte) (int, error) {
+func (c *TCPConn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
@@ -352,7 +357,7 @@ func (c *Conn) Read(b []byte) (int, error) {
}
// Write implements net.Conn.Write.
-func (c *Conn) Write(b []byte) (int, error) {
+func (c *TCPConn) Write(b []byte) (int, error) {
deadline := c.writeCancel()
// Check if deadlineTimer has already expired.
@@ -431,7 +436,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
// Close implements net.Conn.Close.
-func (c *Conn) Close() error {
+func (c *TCPConn) Close() error {
c.ep.Close()
return nil
}
@@ -440,7 +445,7 @@ func (c *Conn) Close() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
-func (c *Conn) CloseRead() error {
+func (c *TCPConn) CloseRead() error {
if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -451,7 +456,7 @@ func (c *Conn) CloseRead() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
-func (c *Conn) CloseWrite() error {
+func (c *TCPConn) CloseWrite() error {
if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -459,7 +464,7 @@ func (c *Conn) CloseWrite() error {
}
// LocalAddr implements net.Conn.LocalAddr.
-func (c *Conn) LocalAddr() net.Addr {
+func (c *TCPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
@@ -468,7 +473,7 @@ func (c *Conn) LocalAddr() net.Addr {
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *Conn) RemoteAddr() net.Addr {
+func (c *TCPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -476,7 +481,7 @@ func (c *Conn) RemoteAddr() net.Addr {
return fullToTCPAddr(a)
}
-func (c *Conn) newOpError(op string, err error) *net.OpError {
+func (c *TCPConn) newOpError(op string, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "tcp",
@@ -494,14 +499,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
-// DialTCP creates a new TCP Conn connected to the specified address.
-func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+// DialTCP creates a new TCPConn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCP Conn connected to the specified address
+// DialContextTCP creates a new TCPConn connected to the specified address
// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -543,12 +548,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
}
- return NewConn(&wq, ep), nil
+ return NewTCPConn(&wq, ep), nil
}
-// A PacketConn is a wrapper around a tcpip endpoint that implements
-// net.PacketConn.
-type PacketConn struct {
+// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
+// net.Conn and net.PacketConn.
+type UDPConn struct {
deadlineTimer
stack *stack.Stack
@@ -556,9 +561,9 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
- c := &PacketConn{
+// NewUDPConn creates a new UDPConn.
+func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
+ c := &UDPConn{
stack: s,
ep: ep,
wq: wq,
@@ -567,12 +572,12 @@ func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketC
return c
}
-// DialUDP creates a new PacketConn.
+// DialUDP creates a new UDPConn.
//
// If laddr is nil, a local address is automatically chosen.
//
-// If raddr is nil, the PacketConn is left unconnected.
-func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+// If raddr is nil, the UDPConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
@@ -591,7 +596,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := NewPacketConn(s, &wq, ep)
+ c := NewUDPConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -608,11 +613,11 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
return c, nil
}
-func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+func (c *UDPConn) newOpError(op string, err error) *net.OpError {
return c.newRemoteOpError(op, nil, err)
}
-func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "udp",
@@ -623,7 +628,7 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *PacketConn) RemoteAddr() net.Addr {
+func (c *UDPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -632,13 +637,13 @@ func (c *PacketConn) RemoteAddr() net.Addr {
}
// Read implements net.Conn.Read
-func (c *PacketConn) Read(b []byte) (int, error) {
+func (c *UDPConn) Read(b []byte) (int, error) {
bytesRead, _, err := c.ReadFrom(b)
return bytesRead, err
}
// ReadFrom implements net.PacketConn.ReadFrom.
-func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
@@ -650,12 +655,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return copy(b, read), fullToUDPAddr(addr), nil
}
-func (c *PacketConn) Write(b []byte) (int, error) {
+func (c *UDPConn) Write(b []byte) (int, error) {
return c.WriteTo(b, nil)
}
// WriteTo implements net.PacketConn.WriteTo.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
deadline := c.writeCancel()
// Check if deadline has already expired.
@@ -713,13 +718,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// Close implements net.PacketConn.Close.
-func (c *PacketConn) Close() error {
+func (c *UDPConn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.PacketConn.LocalAddr.
-func (c *PacketConn) LocalAddr() net.Addr {
+func (c *UDPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ee077ae83..ea0a0409a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -41,7 +41,7 @@ const (
)
func TestTimeouts(t *testing.T) {
- nc := NewConn(nil, nil)
+ nc := NewTCPConn(nil, nil)
dlfs := []struct {
name string
f func(time.Time) error
@@ -132,7 +132,7 @@ func TestCloseReader(t *testing.T) {
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -168,7 +168,7 @@ func TestCloseReader(t *testing.T) {
sender.close()
}
-// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
// using tcp.Forwarder.
func TestCloseReaderWithForwarder(t *testing.T) {
s, err := newLoopbackStack()
@@ -192,7 +192,7 @@ func TestCloseReaderWithForwarder(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
@@ -238,7 +238,7 @@ func TestCloseRead(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -257,7 +257,7 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseRead(); err != nil {
t.Errorf("c.CloseRead() = %v", err)
@@ -291,7 +291,7 @@ func TestCloseWrite(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
n, e := c.Read(make([]byte, 256))
if n != 0 || e != io.EOF {
@@ -309,7 +309,7 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseWrite(); err != nil {
t.Errorf("c.CloseWrite() = %v", err)
@@ -353,7 +353,7 @@ func TestUDPForwarder(t *testing.T) {
}
defer ep.Close()
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -396,7 +396,7 @@ func TestDeadlineChange(t *testing.T) {
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -541,7 +541,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
addr := tcpip.FullAddress{NICID, ip, 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 60817d36d..45dc757c7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "log"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -194,7 +196,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// TODO(b/148429853): Properly process the NS message and do Neighbor
// Unreachability Detection.
for {
- opt, done, _ := it.Next()
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
if done {
break
}
@@ -253,21 +259,25 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
}
na := header.NDPNeighborAdvert(h.NDPPayload())
+ it, err := na.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
targetAddr := na.TargetAddress()
stack := r.Stack()
rxNICID := r.NICID()
- isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
+ if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil {
// We will only get an error if rxNICID is unrecognized,
// which should not happen. For now short-circuit this
// packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
+ } else if isTentative {
// We just got an NA from a node that owns an address we
// are performing DAD on, implying the address is not
// unique. In this case we let the stack know so it can
@@ -283,13 +293,29 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// scenario is beyond the scope of RFC 4862. As such, we simply
// ignore such a scenario for now and proceed as normal.
//
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
+ //
// TODO(b/143147598): Handle the scenario described above. Also
// inform the netstack integration that a duplicate address was
// detected outside of DAD.
+ //
+ // TODO(b/148429853): Properly process the NA message and do Neighbor
+ // Unreachability Detection.
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress)
- if targetAddr != r.RemoteAddress {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
+ switch opt := opt.(type) {
+ case header.NDPTargetLinkLayerAddressOption:
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress())
+ }
}
case header.ICMPv6EchoRequest:
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d0e930e20..50c4b6474 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -121,21 +121,60 @@ func TestICMPCounts(t *testing.T) {
}
defer r.Release()
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
- typ header.ICMPv6Type
- size int
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
}{
- {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize},
- {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize},
- {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize},
- {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize},
- {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
- {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize},
- {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
- {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
- {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize},
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
}
handleIPv6Payload := func(hdr buffer.Prependable) {
@@ -154,10 +193,13 @@ func TestICMPCounts(t *testing.T) {
}
for _, typ := range types {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
handleIPv6Payload(hdr)
}
@@ -372,97 +414,104 @@ func TestLinkResolution(t *testing.T) {
}
func TestICMPChecksumValidationSimple(t *testing.T) {
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "DstUnreachable",
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.DstUnreachable
},
},
{
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "PacketTooBig",
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.PacketTooBig
},
},
{
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "TimeExceeded",
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.TimeExceeded
},
},
{
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "ParamProblem",
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.ParamProblem
},
},
{
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoRequest",
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoRequest
},
},
{
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoReply",
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoReply
},
},
{
- "RouterSolicit",
- header.ICMPv6RouterSolicit,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
},
{
- "RouterAdvert",
- header.ICMPv6RouterAdvert,
- header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterAdvert
},
},
{
- "NeighborSolicit",
- header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborSolicitMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborSolicit
},
},
{
- "NeighborAdvert",
- header.ICMPv6NeighborAdvert,
- header.ICMPv6NeighborAdvertSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborAdvert
},
},
{
- "RedirectMsg",
- header.ICMPv6RedirectMsg,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RedirectMsg
},
},
@@ -494,16 +543,19 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
)
}
- handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ handleIPv6Payload := func(checksum bool) {
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size),
+ PayloadLength: uint16(typ.size + extraDataLen),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
@@ -528,7 +580,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
// Without setting checksum, the incoming packet should
// be invalid.
- handleIPv6Payload(typ.typ, typ.size, false)
+ handleIPv6Payload(false)
if got := invalid.Value(); got != 1 {
t.Fatalf("got invalid = %d, want = 1", got)
}
@@ -538,7 +590,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
// When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, true)
+ handleIPv6Payload(true)
if got := typStat.Value(); got != 1 {
t.Fatalf("got %s = %d, want = 1", typ.name, got)
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index bd732f93f..c9395de52 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -70,76 +70,29 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
return s, ep
}
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an
-// NDP NS message with the Source Link Layer Address option results in a
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(0, 1280, linkAddr0)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
- if linkAddr != linkAddr1 {
- t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1)
- }
-}
-
-// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving
-// an NDP NS message with an invalid Source Link Layer Address option does not
-// result in a new entry in the link address cache for the sender of the
-// message.
-func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
- const nicID = 1
-
tests := []struct {
- name string
- optsBuf []byte
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
}{
{
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
name: "Too Small",
- optsBuf: []byte{1, 1, 1, 2, 3, 4, 5},
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
},
{
name: "Invalid Length",
- optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6},
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
},
}
@@ -186,20 +139,138 @@ func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
})
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
}
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option results in a
+// new entry in the link address cache for the target of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
}
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
}
- if linkAddr != "" {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr)
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
}
})
}
@@ -238,27 +309,59 @@ func TestHopLimitValidation(t *testing.T) {
})
}
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
- {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- }},
- {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- }},
- {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- }},
- {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- }},
- {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- }},
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
}
for _, typ := range types {
@@ -270,10 +373,13 @@ func TestHopLimitValidation(t *testing.T) {
invalid := stats.Invalid
typStat := typ.statCounter(stats)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f5b750046..705cf01ee 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -78,11 +78,15 @@ go_test(
go_test(
name = "stack_test",
size = "small",
- srcs = ["linkaddrcache_test.go"],
+ srcs = [
+ "linkaddrcache_test.go",
+ "nic_test.go",
+ ],
library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
],
)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 8af8565f7..1e575bdaf 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -478,13 +478,17 @@ func TestDADFail(t *testing.T) {
{
"RxAdvert",
func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(pkt.NDPPayload())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1535,7 +1539,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
// Checks to see if list contains an IPv6 address, item.
-func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
protocolAddress := tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: item,
@@ -1661,7 +1665,7 @@ func TestAutoGenAddr(t *testing.T) {
// with non-zero lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -1677,10 +1681,10 @@ func TestAutoGenAddr(t *testing.T) {
// Receive an RA with prefix2 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
t.Fatalf("Should have %s in the list of addresses", addr2)
}
@@ -1701,10 +1705,10 @@ func TestAutoGenAddr(t *testing.T) {
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
t.Fatalf("Should have %s in the list of addresses", addr2)
}
}
@@ -1849,7 +1853,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Receive PI for prefix1.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
expectPrimaryAddr(addr1)
@@ -1857,7 +1861,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Deprecate addr for prefix1 immedaitely.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
// addr should still be the primary endpoint as there are no other addresses.
@@ -1875,7 +1879,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Receive PI for prefix2.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -1883,7 +1887,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Deprecate addr for prefix2 immedaitely.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr1 should be the primary endpoint now since addr2 is deprecated but
@@ -1978,7 +1982,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Receive PI for prefix2.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -1986,10 +1990,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Receive a PI for prefix1.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr1)
@@ -2005,10 +2009,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be deprecated.
expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr2 should be the primary endpoint now since addr1 is deprecated but
@@ -2045,10 +2049,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be deprecated.
expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr2 should be the primary endpoint now since it is not deprecated.
@@ -2059,10 +2063,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be invalidated.
expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -2108,10 +2112,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should not have %s in the list of addresses", addr2)
}
// Should not have any primary endpoints.
@@ -2596,7 +2600,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -2609,7 +2613,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
default:
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -2620,7 +2624,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
t.Fatal("unexpectedly received an auto gen addr event")
case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
}
@@ -2698,17 +2702,17 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
const validLifetimeSecondPrefix1 = 1
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
// Receive an RA with prefix2 in a PI with a large valid lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
@@ -2721,10 +2725,10 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
}
@@ -3010,16 +3014,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
nicinfo := s.NICInfo()
nic1Addrs := nicinfo[nicID1].ProtocolAddresses
nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !contains(nic1Addrs, e1Addr1) {
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
}
- if !contains(nic1Addrs, e1Addr2) {
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
}
- if !contains(nic2Addrs, e2Addr1) {
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
}
- if !contains(nic2Addrs, e2Addr2) {
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
@@ -3098,16 +3102,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
nicinfo = s.NICInfo()
nic1Addrs = nicinfo[nicID1].ProtocolAddresses
nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if contains(nic1Addrs, e1Addr1) {
+ if containsV6Addr(nic1Addrs, e1Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
}
- if contains(nic1Addrs, e1Addr2) {
+ if containsV6Addr(nic1Addrs, e1Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
}
- if contains(nic2Addrs, e2Addr1) {
+ if containsV6Addr(nic2Addrs, e2Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
}
- if contains(nic2Addrs, e2Addr2) {
+ if containsV6Addr(nic2Addrs, e2Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7dad9a8cb..682e9c416 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,6 +16,7 @@ package stack
import (
"log"
+ "reflect"
"sort"
"strings"
"sync/atomic"
@@ -39,6 +40,7 @@ type NIC struct {
mu struct {
sync.RWMutex
+ enabled bool
spoofing bool
promiscuous bool
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
@@ -56,6 +58,14 @@ type NIC struct {
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
+
+ DisabledRx DirectionStats
+}
+
+func makeNICStats() NICStats {
+ var s NICStats
+ tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
}
// DirectionStats includes packet and byte counts.
@@ -99,16 +109,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
name: name,
linkEP: ep,
context: ctx,
- stats: NICStats{
- Tx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- Rx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- },
+ stats: makeNICStats(),
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
@@ -137,14 +138,30 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// enable enables the NIC. enable will attach the link to its LinkEndpoint and
// join the IPv6 All-Nodes Multicast address (ff02::1).
func (n *NIC) enable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if n.mu.enabled {
+ return nil
+ }
+
+ n.mu.enabled = true
+
n.attachLinkEndpoint()
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if err := n.AddAddress(tcpip.ProtocolAddress{
+ if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: header.IPv4ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint); err != nil {
+ }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -166,8 +183,22 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
- n.mu.Lock()
- defer n.mu.Unlock()
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the NIC was
+ // last enabled, other devices may have acquired the same addresses.
+ for _, r := range n.mu.endpoints {
+ addr := r.ep.ID().LocalAddress
+ if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
+ continue
+ }
+
+ r.setKind(permanentTentative)
+ if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ return err
+ }
+ }
if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
return err
@@ -633,7 +664,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
// If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process.
+ // mark it as tentative so it goes through the DAD process if the NIC is
+ // enabled. If the NIC is not enabled, DAD will be started when the NIC is
+ // enabled.
if isIPv6Unicast && kind == permanent {
kind = permanentTentative
}
@@ -668,8 +701,8 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.insertPrimaryEndpointLocked(ref, peb)
- // If we are adding a tentative IPv6 address, start DAD.
- if isIPv6Unicast && kind == permanentTentative {
+ // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
+ if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
@@ -700,9 +733,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- // TODO(b/140898488): Should tentative addresses be
- // returned?
+ case permanentExpired, temporary:
continue
}
addrs = append(addrs, tcpip.ProtocolAddress{
@@ -1016,11 +1047,23 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ // If the NIC is not yet enabled, don't receive any packets.
+ if !enabled {
+ n.mu.RUnlock()
+
+ n.stats.DisabledRx.Packets.Increment()
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ return
+ }
+
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
+ n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
@@ -1032,7 +1075,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
// Are any packet sockets listening for this network protocol?
- n.mu.RLock()
packetEPs := n.mu.packetEPs[protocol]
// Check whether there are packet sockets listening for every protocol.
// If we received a packet with protocol EthernetProtocolAll, then the
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
new file mode 100644
index 000000000..edaee3b86
--- /dev/null
+++ b/pkg/tcpip/stack/nic_test.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
+ // When the NIC is disabled, the only field that matters is the stats field.
+ // This test is limited to stats counter checks.
+ nic := NIC{
+ stats: makeNICStats(),
+ }
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 834fe9487..243868f3a 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -2561,3 +2561,118 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
})
}
}
+
+// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
+// was disabled have DAD performed on them when the NIC is enabled.
+func TestDoDADWhenNICEnabled(t *testing.T) {
+ t.Parallel()
+
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(dadTransmits, 1280, linkAddr1)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 128,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ // Address should be in the list of all addresses.
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should be tentative so it should not be a main address.
+ got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Enabling the NIC should start DAD for the address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != nicID {
+ t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID)
+ }
+ if e.addr != addr.AddressWithPrefix.Address {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr.AddressWithPrefix.Address)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+
+ // Enabling the NIC again should be a no-op.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0fa141d58..0e944712f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1124,6 +1124,10 @@ type ReadErrors struct {
// InvalidEndpointState is the number of times we found the endpoint state
// to be unexpected.
InvalidEndpointState StatCounter
+
+ // NotConnected is the number of times we tried to read but found that the
+ // endpoint was not connected.
+ NotConnected StatCounter
}
// WriteErrors collects packet write errors from an endpoint write call.
@@ -1166,7 +1170,9 @@ type TransportEndpointStats struct {
// marker interface.
func (*TransportEndpointStats) IsEndpointStats() {}
-func fillIn(v reflect.Value) {
+// InitStatCounters initializes v's fields with nil StatCounter fields to new
+// StatCounters.
+func InitStatCounters(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
if s, ok := v.Addr().Interface().(**StatCounter); ok {
@@ -1174,14 +1180,14 @@ func fillIn(v reflect.Value) {
*s = new(StatCounter)
}
} else {
- fillIn(v)
+ InitStatCounters(v)
}
}
}
// FillIn returns a copy of s with nil fields initialized to new StatCounters.
func (s Stats) FillIn() Stats {
- fillIn(reflect.ValueOf(&s).Elem())
+ InitStatCounters(reflect.ValueOf(&s).Elem())
return s
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5a8e15ee..f2be0e651 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1003,8 +1003,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ e.stats.ReadErrors.NotConnected.Increment()
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
}
v, err := e.readLocked()
@@ -2166,6 +2166,9 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
e.isRegistered = true
e.setEndpointState(StateListen)
+ // The channel may be non-nil when we're restoring the endpoint, and it
+ // may be pre-populated with some previously accepted (but not Accepted)
+ // endpoints.
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 2c1505067..cc118c993 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5405,12 +5405,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- // Expect InvalidEndpointState errors on a read at this point.
- if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+ t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
}
- if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
- t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
}
if err := ep.Listen(10); err != nil {
diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD
index ce30f6c53..ed18f0047 100644
--- a/runsc/boot/filter/BUILD
+++ b/runsc/boot/filter/BUILD
@@ -8,6 +8,7 @@ go_library(
"config.go",
"config_amd64.go",
"config_arm64.go",
+ "config_profile.go",
"extra_filters.go",
"extra_filters_msan.go",
"extra_filters_race.go",
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index f8d351c7b..c69f4c602 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -536,16 +536,3 @@ func controlServerFilters(fd int) seccomp.SyscallRules {
},
}
}
-
-// profileFilters returns extra syscalls made by runtime/pprof package.
-func profileFilters() seccomp.SyscallRules {
- return seccomp.SyscallRules{
- syscall.SYS_OPENAT: []seccomp.Rule{
- {
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
- },
- },
- }
-}
diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go
new file mode 100644
index 000000000..194952a7b
--- /dev/null
+++ b/runsc/boot/filter/config_profile.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// profileFilters returns extra syscalls made by runtime/pprof package.
+func profileFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_OPENAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
+ },
+ },
+ }
+}
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 060b63bf3..c2518d52b 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -196,7 +196,10 @@ func TestJobControlSignalExec(t *testing.T) {
defer ptyMaster.Close()
defer ptySlave.Close()
- // Exec bash and attach a terminal.
+ // Exec bash and attach a terminal. Note that occasionally /bin/sh
+ // may be a different shell or have a different configuration (such
+ // as disabling interactive mode and job control). Since we want to
+ // explicitly test interactive mode, use /bin/bash. See b/116981926.
execArgs := &control.ExecArgs{
Filename: "/bin/bash",
// Don't let bash execute from profile or rc files, otherwise
diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go
index 9b6346ca2..1ff5e8cc3 100644
--- a/runsc/dockerutil/dockerutil.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -143,8 +143,11 @@ func PrepareFiles(names ...string) (string, error) {
return "", fmt.Errorf("os.Chmod(%q, 0777) failed: %v", dir, err)
}
for _, name := range names {
- src := getLocalPath(name)
- dst := path.Join(dir, name)
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ return "", fmt.Errorf("testutil.Preparefiles(%q) failed: %v", name, err)
+ }
+ dst := path.Join(dir, path.Base(name))
if err := testutil.Copy(src, dst); err != nil {
return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
}
@@ -152,10 +155,6 @@ func PrepareFiles(names ...string) (string, error) {
return dir, nil
}
-func getLocalPath(file string) string {
- return path.Join(".", file)
-}
-
// do executes docker command.
func do(args ...string) (string, error) {
log.Printf("Running: docker %s\n", args)
diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD
index f845120b0..945405303 100644
--- a/runsc/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -5,7 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
testonly = 1,
- srcs = ["testutil.go"],
+ srcs = [
+ "testutil.go",
+ "testutil_runfiles.go",
+ ],
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index edf2e809a..92d677e71 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -79,60 +79,6 @@ func ConfigureExePath() error {
return nil
}
-// FindFile searchs for a file inside the test run environment. It returns the
-// full path to the file. It fails if none or more than one file is found.
-func FindFile(path string) (string, error) {
- wd, err := os.Getwd()
- if err != nil {
- return "", err
- }
-
- // The test root is demarcated by a path element called "__main__". Search for
- // it backwards from the working directory.
- root := wd
- for {
- dir, name := filepath.Split(root)
- if name == "__main__" {
- break
- }
- if len(dir) == 0 {
- return "", fmt.Errorf("directory __main__ not found in %q", wd)
- }
- // Remove ending slash to loop around.
- root = dir[:len(dir)-1]
- }
-
- // Annoyingly, bazel adds the build type to the directory path for go
- // binaries, but not for c++ binaries. We use two different patterns to
- // to find our file.
- patterns := []string{
- // Try the obvious path first.
- filepath.Join(root, path),
- // If it was a go binary, use a wildcard to match the build
- // type. The pattern is: /test-path/__main__/directories/*/file.
- filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
- }
-
- for _, p := range patterns {
- matches, err := filepath.Glob(p)
- if err != nil {
- // "The only possible returned error is ErrBadPattern,
- // when pattern is malformed." -godoc
- return "", fmt.Errorf("error globbing %q: %v", p, err)
- }
- switch len(matches) {
- case 0:
- // Try the next pattern.
- case 1:
- // We found it.
- return matches[0], nil
- default:
- return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
- }
- }
- return "", fmt.Errorf("file %q not found", path)
-}
-
// TestConfig returns the default configuration to use in tests. Note that
// 'RootDir' must be set by caller if required.
func TestConfig() *boot.Config {
@@ -173,6 +119,13 @@ func NewSpecWithArgs(args ...string) *specs.Spec {
Capabilities: specutils.AllCapabilities(),
},
Mounts: []specs.Mount{
+ // Hide the host /etc to avoid any side-effects.
+ // For example, bash reads /etc/passwd and if it is
+ // very big, tests can fail by timeout.
+ {
+ Type: "tmpfs",
+ Destination: "/etc",
+ },
// Root is readonly, but many tests want to write to tmpdir.
// This creates a writable mount inside the root. Also, when tmpdir points
// to "/tmp", it makes the the actual /tmp to be mounted and not a tmpfs
diff --git a/runsc/testutil/testutil_runfiles.go b/runsc/testutil/testutil_runfiles.go
new file mode 100644
index 000000000..ece9ea9a1
--- /dev/null
+++ b/runsc/testutil/testutil_runfiles.go
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// FindFile searchs for a file inside the test run environment. It returns the
+// full path to the file. It fails if none or more than one file is found.
+func FindFile(path string) (string, error) {
+ wd, err := os.Getwd()
+ if err != nil {
+ return "", err
+ }
+
+ // The test root is demarcated by a path element called "__main__". Search for
+ // it backwards from the working directory.
+ root := wd
+ for {
+ dir, name := filepath.Split(root)
+ if name == "__main__" {
+ break
+ }
+ if len(dir) == 0 {
+ return "", fmt.Errorf("directory __main__ not found in %q", wd)
+ }
+ // Remove ending slash to loop around.
+ root = dir[:len(dir)-1]
+ }
+
+ // Annoyingly, bazel adds the build type to the directory path for go
+ // binaries, but not for c++ binaries. We use two different patterns to
+ // to find our file.
+ patterns := []string{
+ // Try the obvious path first.
+ filepath.Join(root, path),
+ // If it was a go binary, use a wildcard to match the build
+ // type. The pattern is: /test-path/__main__/directories/*/file.
+ filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
+ }
+
+ for _, p := range patterns {
+ matches, err := filepath.Glob(p)
+ if err != nil {
+ // "The only possible returned error is ErrBadPattern,
+ // when pattern is malformed." -godoc
+ return "", fmt.Errorf("error globbing %q: %v", p, err)
+ }
+ switch len(matches) {
+ case 0:
+ // Try the next pattern.
+ case 1:
+ // We found it.
+ return matches[0], nil
+ default:
+ return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
+ }
+ }
+ return "", fmt.Errorf("file %q not found", path)
+}
diff --git a/test/image/image_test.go b/test/image/image_test.go
index d0dcb1861..0a1e19d6f 100644
--- a/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -107,7 +107,7 @@ func TestHttpd(t *testing.T) {
}
d := dockerutil.MakeDocker("http-test")
- dir, err := dockerutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -139,7 +139,7 @@ func TestNginx(t *testing.T) {
}
d := dockerutil.MakeDocker("net-test")
- dir, err := dockerutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -183,7 +183,7 @@ func TestMysql(t *testing.T) {
}
client := dockerutil.MakeDocker("mysql-client-test")
- dir, err := dockerutil.PrepareFiles("mysql.sql")
+ dir, err := dockerutil.PrepareFiles("test/image/mysql.sql")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -283,7 +283,7 @@ func TestRuby(t *testing.T) {
}
d := dockerutil.MakeDocker("ruby-test")
- dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh")
+ dir, err := dockerutil.PrepareFiles("test/image/ruby.rb", "test/image/ruby.sh")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
diff --git a/test/iptables/README.md b/test/iptables/README.md
index 8f61b4c41..c2b934e1f 100644
--- a/test/iptables/README.md
+++ b/test/iptables/README.md
@@ -28,7 +28,7 @@ Your test is now runnable with bazel!
Build the testing Docker container:
```bash
-$ bazel run //test/iptables/runner-image -- --norun
+$ bazel run //test/iptables/runner:runner-image -- --norun
```
Run an individual test via:
diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl
index 1df761dd0..cbab85ef7 100644
--- a/test/syscalls/build_defs.bzl
+++ b/test/syscalls/build_defs.bzl
@@ -2,8 +2,6 @@
load("//tools:defs.bzl", "loopback")
-# syscall_test is a macro that will create targets to run the given test target
-# on the host (native) and runsc.
def syscall_test(
test,
shard_count = 5,
@@ -13,6 +11,19 @@ def syscall_test(
add_uds_tree = False,
add_hostinet = False,
tags = None):
+ """syscall_test is a macro that will create targets for all platforms.
+
+ Args:
+ test: the test target.
+ shard_count: shards for defined tests.
+ size: the defined test size.
+ use_tmpfs: use tmpfs in the defined tests.
+ add_overlay: add an overlay test.
+ add_uds_tree: add a UDS test.
+ add_hostinet: add a hostinet test.
+ tags: starting test tags.
+ """
+
_syscall_test(
test = test,
shard_count = shard_count,
@@ -111,6 +122,19 @@ def _syscall_test(
# all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
tags += [full_platform, "file_" + file_access]
+ # Hash this target into one of 15 buckets. This can be used to
+ # randomly split targets between different workflows.
+ hash15 = hash(native.package_name() + name) % 15
+ tags.append("hash15:" + str(hash15))
+
+ # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
+ # we figure out how to request ipv4 sockets on Guitar machines.
+ if network == "host":
+ tags.append("noguitar")
+
+ # Disable off-host networking.
+ tags.append("requires-net:loopback")
+
# Add tag to prevent the tests from running in a Bazel sandbox.
# TODO(b/120560048): Make the tests run without this tag.
tags.append("no-sandbox")
@@ -118,8 +142,11 @@ def _syscall_test(
# TODO(b/112165693): KVM tests are tagged "manual" to until the platform is
# more stable.
if platform == "kvm":
- tags += ["manual"]
- tags += ["requires-kvm"]
+ tags.append("manual")
+ tags.append("requires-kvm")
+
+ # TODO(b/112165693): Remove when tests pass reliably.
+ tags.append("notap")
args = [
# Arguments are passed directly to syscall_test_runner binary.
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
index 9883aef61..c47a05181 100644
--- a/test/syscalls/linux/32bit.cc
+++ b/test/syscalls/linux/32bit.cc
@@ -155,7 +155,7 @@ TEST(Syscall32Bit, Syscall) {
case PlatformSupport::Ignored:
// See above.
EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
- ::testing::KilledBySignal(SIGILL), "");
+ ::testing::KilledBySignal(SIGSEGV), "");
break;
case PlatformSupport::Allowed:
diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc
index 0a2d44a2c..85ec013d5 100644
--- a/test/syscalls/linux/chroot.cc
+++ b/test/syscalls/linux/chroot.cc
@@ -167,7 +167,7 @@ TEST(ChrootTest, DotDotFromOpenFD) {
}
// Test that link resolution in a chroot can escape the root by following an
-// open proc fd.
+// open proc fd. Regression test for b/32316719.
TEST(ChrootTest, ProcFdLinkResolutionInChroot) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc
index f41f99900..7cd6a75bd 100644
--- a/test/syscalls/linux/concurrency.cc
+++ b/test/syscalls/linux/concurrency.cc
@@ -46,7 +46,8 @@ TEST(ConcurrencyTest, SingleProcessMultithreaded) {
}
// Test that multiple threads in this process continue to execute in parallel,
-// even if an unrelated second process is spawned.
+// even if an unrelated second process is spawned. Regression test for
+// b/32119508.
TEST(ConcurrencyTest, MultiProcessMultithreaded) {
// In PID 1, start TIDs 1 and 2, and put both to sleep.
//
diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc
index b790fe5be..2989379b7 100644
--- a/test/syscalls/linux/exec_proc_exe_workload.cc
+++ b/test/syscalls/linux/exec_proc_exe_workload.cc
@@ -21,6 +21,12 @@
#include "test/util/posix_error.h"
int main(int argc, char** argv, char** envp) {
+ // This is annoying. Because remote build systems may put these binaries
+ // in a content-addressable-store, you may wind up with /proc/self/exe
+ // pointing to some random path (but with a sensible argv[0]).
+ //
+ // Therefore, this test simply checks that the /proc/self/exe
+ // is absolute and *doesn't* match argv[1].
std::string exe =
gvisor::testing::ProcessExePath(getpid()).ValueOrDie();
if (exe[0] != '/') {
diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc
index 906f3358d..ff8bdfeb0 100644
--- a/test/syscalls/linux/fork.cc
+++ b/test/syscalls/linux/fork.cc
@@ -271,7 +271,7 @@ TEST_F(ForkTest, Alarm) {
EXPECT_EQ(0, alarmed);
}
-// Child cannot affect parent private memory.
+// Child cannot affect parent private memory. Regression test for b/24137240.
TEST_F(ForkTest, PrivateMemory) {
std::atomic<uint32_t> local(0);
@@ -298,6 +298,9 @@ TEST_F(ForkTest, PrivateMemory) {
}
// Kernel-accessed buffers should remain coherent across COW.
+//
+// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates
+// differently. Regression test for b/33811887.
TEST_F(ForkTest, COWSegment) {
constexpr int kBufSize = 1024;
char* read_buf = private_;
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index 1c4d9f1c7..11fb1b457 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -1418,7 +1418,7 @@ TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) {
//
// On most platforms this is trivial, but when the file is mapped via the sentry
// page cache (which does not yet support writing to shared mappings), a bug
-// caused reads to fail unnecessarily on such mappings.
+// caused reads to fail unnecessarily on such mappings. See b/28913513.
TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
uintptr_t addr;
size_t len = strlen(kFileContents);
@@ -1435,7 +1435,7 @@ TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
// Tests that EFAULT is returned when invoking a syscall that requires the OS to
// read past end of file (resulting in a fault in sentry context in the gVisor
-// case).
+// case). See b/28913513.
TEST_F(MMapFileTest, InternalSigBus) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
@@ -1578,7 +1578,7 @@ TEST_F(MMapFileTest, Bug38498194) {
}
// Tests that reading from a file to a memory mapping of the same file does not
-// deadlock.
+// deadlock. See b/34813270.
TEST_F(MMapFileTest, SelfRead) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
@@ -1590,7 +1590,7 @@ TEST_F(MMapFileTest, SelfRead) {
}
// Tests that writing to a file from a memory mapping of the same file does not
-// deadlock.
+// deadlock. Regression test for b/34813270.
TEST_F(MMapFileTest, SelfWrite) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 431733dbe..902d0a0dc 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -132,6 +132,7 @@ TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) {
}
// A file originally created RW, but opened RO can later be opened RW.
+// Regression test for b/65385065.
TEST(CreateTest, OpenCreateROThenRW) {
TempPath file(NewTempAbsPath());
diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc
index f7ea44054..5b0743fe9 100644
--- a/test/syscalls/linux/preadv.cc
+++ b/test/syscalls/linux/preadv.cc
@@ -37,6 +37,7 @@ namespace testing {
namespace {
+// Stress copy-on-write. Attempts to reproduce b/38430174.
TEST(PreadvTest, MMConcurrencyStress) {
// Fill a one-page file with zeroes (the contents don't really matter).
const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 169b723eb..a23fdb58d 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -1352,13 +1352,19 @@ TEST(ProcPidSymlink, SubprocessZombied) {
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
- // 4.17 & gVisor: Syscall succeeds and returns 1
+ //
+ // ~4.3: Syscall fails with EACCES.
+ // 4.17 & gVisor: Syscall succeeds and returns 1.
+ //
// EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
+ //
+ // ~4.3: Syscall fails with EACCES.
+ // 4.17 & gVisor: Syscall succeeds and returns 1.
+ //
// EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
}
@@ -1431,8 +1437,12 @@ TEST(ProcPidFile, SubprocessRunning) {
TEST(ProcPidFile, SubprocessZombie) {
char buf[1];
- // 4.17: Succeeds and returns 1
- // gVisor: Succeeds and returns 0
+ // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent
+ // behavior on different kernels.
+ //
+ // ~4.3: Succeds and returns 0.
+ // 4.17: Succeeds and returns 1.
+ // gVisor: Succeeds and returns 0.
EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds());
EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)),
@@ -1458,7 +1468,10 @@ TEST(ProcPidFile, SubprocessZombie) {
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
+ //
+ // ~4.3: Fails and returns EACCES.
// gVisor & 4.17: Succeeds and returns 1.
+ //
// EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
}
@@ -1467,9 +1480,12 @@ TEST(ProcPidFile, SubprocessZombie) {
TEST(ProcPidFile, SubprocessExited) {
char buf[1];
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels.
+ //
+ // ~4.3: Fails and returns ESRCH.
// gVisor: Fails with ESRCH.
// 4.17: Succeeds and returns 1.
+ //
// EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)),
// SyscallFailsWithErrno(ESRCH));
@@ -1641,7 +1657,7 @@ TEST(ProcTask, KilledThreadsDisappear) {
EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
TaskFiles(initial, {child1.Tid()})));
- // Stat child1's task file.
+ // Stat child1's task file. Regression test for b/32097707.
struct stat statbuf;
const std::string child1_task_file =
absl::StrCat("/proc/self/task/", child1.Tid());
@@ -1669,7 +1685,7 @@ TEST(ProcTask, KilledThreadsDisappear) {
EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
"/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()})));
- // Stat child1's task file again. This time it should fail.
+ // Stat child1's task file again. This time it should fail. See b/32097707.
EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf),
SyscallFailsWithErrno(ENOENT));
@@ -1824,7 +1840,7 @@ TEST(ProcSysVmOvercommitMemory, HasNumericValue) {
}
// Check that link for proc fd entries point the target node, not the
-// symlink itself.
+// symlink itself. Regression test for b/31155070.
TEST(ProcTaskFd, FstatatFollowsSymlink) {
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const FileDescriptor fd =
@@ -1883,6 +1899,20 @@ TEST(ProcMounts, IsSymlink) {
EXPECT_EQ(link, "self/mounts");
}
+TEST(ProcSelfMountinfo, RequiredFieldsArePresent) {
+ auto mountinfo =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo"));
+ EXPECT_THAT(
+ mountinfo,
+ AllOf(
+ // Root mount.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"),
+ // Proc mount - always rw.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)")));
+}
+
// Check that /proc/self/mounts looks something like a real mounts file.
TEST(ProcSelfMounts, RequiredFieldsArePresent) {
auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts"));
diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc
index 4069cbc7e..baaf9f757 100644
--- a/test/syscalls/linux/readv.cc
+++ b/test/syscalls/linux/readv.cc
@@ -254,7 +254,9 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) {
// This test depends on the maximum extent of a single readv() syscall, so
// we can't tolerate interruption from saving.
TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) {
- // Ensure that we won't be interrupted by ITIMER_PROF.
+ // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly
+ // important in environments where automated profiling tools may start
+ // ITIMER_PROF automatically.
struct itimerval itv = {};
auto const cleanup_itimer =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv));
diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc
index 106c045e3..4bfb1ff56 100644
--- a/test/syscalls/linux/rseq.cc
+++ b/test/syscalls/linux/rseq.cc
@@ -36,7 +36,7 @@ namespace {
// We must be very careful about how these tests are written. Each thread may
// only have one struct rseq registration, which may be done automatically at
// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does
-// not do so).
+// not do so, but other libraries do).
//
// Testing of rseq is thus done primarily in a child process with no
// registration. This means exec'ing a nostdlib binary, as rseq registration can
diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc
index 424e2a67f..be2364fb8 100644
--- a/test/syscalls/linux/select.cc
+++ b/test/syscalls/linux/select.cc
@@ -146,7 +146,7 @@ TEST_F(SelectTest, IgnoreBitsAboveNfds) {
// This test illustrates Linux's behavior of 'select' calls passing after
// setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on
-// this behavior.
+// this behavior. See b/122318458.
TEST_F(SelectTest, SetrlimitCallNOFILE) {
fd_set read_set;
FD_ZERO(&read_set);
diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc
index 7ba752599..c7fdbb924 100644
--- a/test/syscalls/linux/shm.cc
+++ b/test/syscalls/linux/shm.cc
@@ -473,7 +473,7 @@ TEST(ShmTest, PartialUnmap) {
}
// Check that sentry does not panic when asked for a zero-length private shm
-// segment.
+// segment. Regression test for b/110694797.
TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) {
EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _));
}
diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc
index 654c6a47f..a603fc1d1 100644
--- a/test/syscalls/linux/sigprocmask.cc
+++ b/test/syscalls/linux/sigprocmask.cc
@@ -237,7 +237,7 @@ TEST_F(SigProcMaskTest, SignalHandler) {
}
// Check that sigprocmask correctly handles aliasing of the set and oldset
-// pointers.
+// pointers. Regression test for b/30502311.
TEST_F(SigProcMaskTest, AliasedSets) {
sigset_t mask;
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
index 276a94eb8..884319e1d 100644
--- a/test/syscalls/linux/socket_unix_non_stream.cc
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -109,7 +109,7 @@ PosixErrorOr<std::vector<Mapping>> CreateFragmentedRegion(const int size,
}
// A contiguous iov that is heavily fragmented in FileMem can still be sent
-// successfully.
+// successfully. See b/115833655.
TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -165,7 +165,7 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
}
// A contiguous iov that is heavily fragmented in FileMem can still be received
-// into successfully.
+// into successfully. Regression test for b/115833655.
TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
index b249ff91f..03ee1250d 100644
--- a/test/syscalls/linux/symlink.cc
+++ b/test/syscalls/linux/symlink.cc
@@ -38,7 +38,7 @@ mode_t FilePermission(const std::string& path) {
}
// Test that name collisions are checked on the new link path, not the source
-// path.
+// path. Regression test for b/31782115.
TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) {
const std::string srcname = NewTempAbsPath();
const std::string newname = NewTempAbsPath();
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 525ccbd88..c4591a3b9 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -244,7 +244,8 @@ TEST_P(TcpSocketTest, ZeroWriteAllowed) {
}
// Test that a non-blocking write with a buffer that is larger than the send
-// buffer size will not actually write the whole thing at once.
+// buffer size will not actually write the whole thing at once. Regression test
+// for b/64438887.
TEST_P(TcpSocketTest, NonblockingLargeWrite) {
// Set the FD to O_NONBLOCK.
int opts;
@@ -1339,6 +1340,15 @@ TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) {
EXPECT_EQ(get, kTCPDeferAccept);
}
+TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char buf[1];
+ EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc
index c7eead17e..1ccb95733 100644
--- a/test/syscalls/linux/time.cc
+++ b/test/syscalls/linux/time.cc
@@ -62,6 +62,7 @@ TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) {
::testing::KilledBySignal(SIGSEGV), "");
}
+// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2.
int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) {
constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000;
return reinterpret_cast<int (*)(struct timeval*, struct timezone*)>(
diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc
index bae377c69..8d8ebbb24 100644
--- a/test/syscalls/linux/tkill.cc
+++ b/test/syscalls/linux/tkill.cc
@@ -54,7 +54,7 @@ void SigHandler(int sig, siginfo_t* info, void* context) {
TEST_CHECK(info->si_code == SI_TKILL);
}
-// Test with a real signal.
+// Test with a real signal. Regression test for b/24790092.
TEST(TkillTest, ValidTIDAndRealSignal) {
struct sigaction sa;
sa.sa_sigaction = SigHandler;
diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc
index 35aacb172..9c10b6674 100644
--- a/test/util/temp_path.cc
+++ b/test/util/temp_path.cc
@@ -77,6 +77,7 @@ std::string NewTempAbsPath() {
std::string NewTempRelPath() { return NextTempBasename(); }
std::string GetAbsoluteTestTmpdir() {
+ // Note that TEST_TMPDIR is guaranteed to be set.
char* env_tmpdir = getenv("TEST_TMPDIR");
std::string tmp_dir =
env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp";
diff --git a/tools/build/tags.bzl b/tools/build/tags.bzl
index e99c87f81..a6db44e47 100644
--- a/tools/build/tags.bzl
+++ b/tools/build/tags.bzl
@@ -33,4 +33,8 @@ go_suffixes = [
"_wasm_unsafe",
"_linux",
"_linux_unsafe",
+ "_opts",
+ "_opts_unsafe",
+ "_impl",
+ "_impl_unsafe",
]
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 5d5fa134a..c03b557ae 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -73,6 +73,16 @@ def calculate_sets(srcs):
result[target].append(file)
return result
+def go_imports(name, src, out):
+ """Simplify a single Go source file by eliminating unused imports."""
+ native.genrule(
+ name = name,
+ srcs = [src],
+ outs = [out],
+ tools = ["@org_golang_x_tools//cmd/goimports:goimports"],
+ cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
+ )
+
def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs):
"""Wraps the standard go_library and does stateification and marshalling.
@@ -107,10 +117,15 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
state_sets = calculate_sets(srcs)
for (suffix, srcs) in state_sets.items():
go_stateify(
- name = name + suffix + "_state_autogen",
+ name = name + suffix + "_state_autogen_with_imports",
srcs = srcs,
imports = imports,
package = name,
+ out = name + suffix + "_state_autogen_with_imports.go",
+ )
+ go_imports(
+ name = name + suffix + "_state_autogen",
+ src = name + suffix + "_state_autogen_with_imports.go",
out = name + suffix + "_state_autogen.go",
)
all_srcs = all_srcs + [
diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl
index 32235813a..de365d153 100644
--- a/tools/images/defs.bzl
+++ b/tools/images/defs.bzl
@@ -57,7 +57,10 @@ def _vm_image_impl(ctx):
command = argv,
input_manifests = runfiles_manifests,
)
- return [DefaultInfo(files = depset([ctx.outputs.out]))]
+ return [DefaultInfo(
+ files = depset([ctx.outputs.out]),
+ runfiles = ctx.runfiles(files = [ctx.outputs.out]),
+ )]
_vm_image = rule(
attrs = {
diff --git a/tools/installers/BUILD b/tools/installers/BUILD
index 01bc4de8c..d78a265ca 100644
--- a/tools/installers/BUILD
+++ b/tools/installers/BUILD
@@ -5,10 +5,15 @@ package(
licenses = ["notice"],
)
+filegroup(
+ name = "runsc",
+ srcs = ["//runsc"],
+)
+
sh_binary(
name = "head",
srcs = ["head.sh"],
- data = ["//runsc"],
+ data = [":runsc"],
)
sh_binary(
diff --git a/tools/installers/head.sh b/tools/installers/head.sh
index 4435cb27a..9de8f138c 100755
--- a/tools/installers/head.sh
+++ b/tools/installers/head.sh
@@ -15,7 +15,7 @@
# limitations under the License.
# Install our runtime.
-third_party/gvisor/runsc/runsc install
+$(dirname $0)/runsc install
# Restart docker.
service docker restart || true