summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--WORKSPACE21
-rw-r--r--benchmarks/BUILD1
-rw-r--r--benchmarks/harness/BUILD5
-rw-r--r--benchmarks/harness/machine_producers/BUILD40
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py250
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer_test.py48
-rw-r--r--benchmarks/harness/machine_producers/mock_producer.py23
-rw-r--r--benchmarks/harness/machine_producers/testdata/get_five.json211
-rw-r--r--benchmarks/harness/machine_producers/testdata/get_one.json145
-rw-r--r--go.mod2
-rw-r--r--go.sum2
-rw-r--r--kokoro/issue_reviver.cfg15
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/rseq.go130
-rw-r--r--pkg/sentry/arch/arch.go6
-rw-r--r--pkg/sentry/arch/arch_amd64.go4
-rw-r--r--pkg/sentry/fs/file.go7
-rw-r--r--pkg/sentry/fs/fsutil/BUILD2
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go14
-rw-r--r--pkg/sentry/fs/splice.go5
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go10
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go8
-rw-r--r--pkg/sentry/fsimpl/memfs/regular_file.go154
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go102
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go147
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD (renamed from pkg/sentry/fsimpl/memfs/BUILD)32
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go (renamed from pkg/sentry/fsimpl/memfs/benchmark_test.go)12
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go (renamed from pkg/sentry/fsimpl/memfs/directory.go)2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go (renamed from pkg/sentry/fsimpl/memfs/filesystem.go)8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go (renamed from pkg/sentry/fsimpl/memfs/named_pipe.go)2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go (renamed from pkg/sentry/fsimpl/memfs/pipe_test.go)6
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go357
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go224
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go (renamed from pkg/sentry/fsimpl/memfs/symlink.go)2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go (renamed from pkg/sentry/fsimpl/memfs/memfs.go)46
-rw-r--r--pkg/sentry/kernel/rseq.go383
-rw-r--r--pkg/sentry/kernel/shm/shm.go85
-rw-r--r--pkg/sentry/kernel/task.go43
-rw-r--r--pkg/sentry/kernel/task_clone.go7
-rw-r--r--pkg/sentry/kernel/task_exec.go6
-rw-r--r--pkg/sentry/kernel/task_run.go16
-rw-r--r--pkg/sentry/kernel/task_start.go10
-rw-r--r--pkg/sentry/kernel/thread_group.go18
-rw-r--r--pkg/sentry/mm/procfs.go12
-rw-r--r--pkg/sentry/platform/ptrace/stub_amd64.s29
-rw-r--r--pkg/sentry/platform/ptrace/stub_arm64.s30
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go20
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go2
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s67
-rw-r--r--pkg/sentry/socket/control/control.go6
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/socket.go29
-rw-r--r--pkg/sentry/socket/netstack/netstack.go54
-rw-r--r--pkg/sentry/socket/unix/io.go13
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go24
-rw-r--r--pkg/sentry/socket/unix/unix.go23
-rw-r--r--pkg/sentry/strace/BUILD3
-rw-r--r--pkg/sentry/strace/linux64_amd64.go (renamed from pkg/sentry/strace/linux64.go)19
-rw-r--r--pkg/sentry/strace/linux64_arm64.go323
-rw-r--r--pkg/sentry/strace/syscalls.go9
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go2
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_rseq.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go7
-rw-r--r--pkg/tcpip/BUILD8
-rw-r--r--pkg/tcpip/header/BUILD1
-rw-r--r--pkg/tcpip/header/ipv6.go45
-rw-r--r--pkg/tcpip/header/ipv6_test.go163
-rw-r--r--pkg/tcpip/network/arp/arp.go6
-rw-r--r--pkg/tcpip/network/ip_test.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go18
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go14
-rw-r--r--pkg/tcpip/stack/BUILD24
-rw-r--r--pkg/tcpip/stack/ndp.go568
-rw-r--r--pkg/tcpip/stack/ndp_test.go905
-rw-r--r--pkg/tcpip/stack/nic.go140
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go6
-rw-r--r--pkg/tcpip/stack/stack.go154
-rw-r--r--pkg/tcpip/stack/stack_test.go293
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go92
-rw-r--r--pkg/tcpip/stack/transport_test.go14
-rw-r--r--pkg/tcpip/tcpip.go33
-rw-r--r--pkg/tcpip/timer.go161
-rw-r--r--pkg/tcpip/timer_test.go236
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go14
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go22
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go44
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go169
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go57
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go16
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go89
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go38
-rw-r--r--runsc/boot/network.go17
-rwxr-xr-xscripts/issue_reviver.sh27
-rw-r--r--test/iptables/README.md2
-rw-r--r--test/syscalls/linux/BUILD1
-rw-r--r--test/syscalls/linux/inotify.cc28
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc10
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h6
-rw-r--r--test/syscalls/linux/partial_bad_buffer.cc138
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc25
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc27
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc33
-rw-r--r--test/syscalls/linux/socket_non_stream.cc113
-rw-r--r--test/syscalls/linux/socket_non_stream_blocking.cc37
-rw-r--r--test/syscalls/linux/socket_stream.cc55
-rw-r--r--tools/issue_reviver/BUILD12
-rw-r--r--tools/issue_reviver/github/BUILD17
-rw-r--r--tools/issue_reviver/github/github.go164
-rw-r--r--tools/issue_reviver/main.go89
-rw-r--r--tools/issue_reviver/reviver/BUILD19
-rw-r--r--tools/issue_reviver/reviver/reviver.go192
-rw-r--r--tools/issue_reviver/reviver/reviver_test.go88
118 files changed, 6454 insertions, 1340 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 4b5a3bfe2..e2afc073c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -290,6 +290,27 @@ go_repository(
version = "v1.3.1",
)
+go_repository(
+ name = "com_github_google_go-github",
+ importpath = "github.com/google/go-github",
+ sum = "h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=",
+ version = "v17.0.0",
+)
+
+go_repository(
+ name = "org_golang_x_oauth2",
+ importpath = "golang.org/x/oauth2",
+ sum = "h1:pE8b58s1HRDMi8RDc79m0HISf9D4TzseP40cEA6IGfs=",
+ version = "v0.0.0-20191202225959-858c2ad4c8b6",
+)
+
+go_repository(
+ name = "com_github_google_go-querystring",
+ importpath = "github.com/google/go-querystring",
+ sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=",
+ version = "v1.0.0",
+)
+
# System Call test dependencies.
http_archive(
name = "com_google_absl",
diff --git a/benchmarks/BUILD b/benchmarks/BUILD
index dbadeeaf2..1455c6c5b 100644
--- a/benchmarks/BUILD
+++ b/benchmarks/BUILD
@@ -5,5 +5,6 @@ py_binary(
srcs = ["run.py"],
main = "run.py",
python_version = "PY3",
+ srcs_version = "PY3",
deps = ["//benchmarks/runner"],
)
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
index 9546220c4..081a74243 100644
--- a/benchmarks/harness/BUILD
+++ b/benchmarks/harness/BUILD
@@ -24,6 +24,7 @@ py_library(
name = "container",
srcs = ["container.py"],
deps = [
+ "//benchmarks/workloads",
requirement("asn1crypto", False),
requirement("chardet", False),
requirement("certifi", False),
@@ -45,6 +46,7 @@ py_library(
"//benchmarks/harness:container",
"//benchmarks/harness:ssh_connection",
"//benchmarks/harness:tunnel_dispatcher",
+ "//benchmarks/harness/machine_mocks",
requirement("asn1crypto", False),
requirement("chardet", False),
requirement("certifi", False),
@@ -53,6 +55,7 @@ py_library(
requirement("idna", False),
requirement("ptyprocess", False),
requirement("requests", False),
+ requirement("six", False),
requirement("urllib3", False),
requirement("websocket-client", False),
],
@@ -64,7 +67,7 @@ py_library(
deps = [
"//benchmarks/harness",
requirement("bcrypt", False),
- requirement("cffi", False),
+ requirement("cffi", True),
requirement("paramiko", True),
requirement("cryptography", False),
],
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
index a48da02a1..c4e943882 100644
--- a/benchmarks/harness/machine_producers/BUILD
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -20,6 +20,7 @@ py_library(
srcs = ["mock_producer.py"],
deps = [
"//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_producer",
"//benchmarks/harness/machine_producers:machine_producer",
],
)
@@ -38,3 +39,42 @@ py_library(
name = "gcloud_mock_recorder",
srcs = ["gcloud_mock_recorder.py"],
)
+
+py_library(
+ name = "gcloud_producer",
+ srcs = ["gcloud_producer.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_mock_recorder",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ ],
+)
+
+filegroup(
+ name = "test_data",
+ srcs = [
+ "testdata/get_five.json",
+ "testdata/get_one.json",
+ ],
+)
+
+py_library(
+ name = "gcloud_producer_test_lib",
+ srcs = ["gcloud_producer_test.py"],
+ deps = [
+ "//benchmarks/harness/machine_producers:machine_producer",
+ "//benchmarks/harness/machine_producers:mock_producer",
+ ],
+)
+
+py_test(
+ name = "gcloud_producer_test",
+ srcs = [":gcloud_producer_test_lib"],
+ data = [
+ ":test_data",
+ ],
+ python_version = "PY3",
+ tags = [
+ "local",
+ ],
+)
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
new file mode 100644
index 000000000..4693dd8a2
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -0,0 +1,250 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A machine producer which produces machine objects using `gcloud`.
+
+Machine producers produce valid harness.Machine objects which are backed by
+real machines. This producer produces those machines on the given user's GCP
+account using the `gcloud` tool.
+
+GCloudProducer creates instances on the given GCP account named like:
+`machine-XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX` in a randomized fashion such
+that name collisions with user instances shouldn't happen.
+
+ Typical usage example:
+
+ producer = GCloudProducer(args)
+ machines = producer.get_machines(NUM_MACHINES)
+ # run stuff on machines with machines[i].run(CMD)
+ producer.release_machines(machines)
+"""
+import datetime
+import getpass
+import json
+import subprocess
+import threading
+from typing import List, Dict, Any
+import uuid
+
+from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import machine_producer
+
+DEFAULT_USER = getpass.getuser()
+
+
+class GCloudProducer(machine_producer.MachineProducer):
+ """Implementation of MachineProducer backed by GCP.
+
+ Produces Machine objects backed by GCP instances.
+
+ Attributes:
+ project: The GCP project name under which to create the machines.
+ ssh_key_path: path to a valid ssh key. See README on valid ssh keys.
+ image: image name as a string.
+ image_project: image project as a string.
+ zone: string to a valid GCP zone.
+ ssh_user: string of user name for ssh_key
+ ssh_password: string of password for ssh key
+ mock: a mock printer which will print mock data if required. Mock data is
+ recorded output from subprocess calls (returncode, stdout, args).
+ condition: mutex for this class around machine creation and deletion.
+ """
+
+ def __init__(self,
+ project: str,
+ ssh_key_path: str,
+ image: str,
+ image_project: str,
+ zone: str,
+ ssh_user: str,
+ mock: gcloud_mock_recorder.MockPrinter = None):
+ self.project = project
+ self.ssh_key_path = ssh_key_path
+ self.image = image
+ self.image_project = image_project
+ self.zone = zone
+ self.ssh_user = ssh_user if ssh_user else DEFAULT_USER
+ self.mock = mock
+ self.condition = threading.Condition()
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns requested number of machines backed by GCP instances."""
+ if num_machines <= 0:
+ raise ValueError(
+ "Cannot ask for {num} machines!".format(num=num_machines))
+ with self.condition:
+ names = self._get_unique_names(num_machines)
+ self._build_instances(names)
+ instances = self._start_command(names)
+ self._add_ssh_key_to_instances(names)
+ return self._machines_from_instances(instances)
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ """Releases the requested number of machines, deleting the instances."""
+ if not machine_list:
+ return
+ with self.condition:
+ cmd = "gcloud compute instances delete --quiet".split(" ")
+ names = [str(m) for m in machine_list]
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ self._run_command(cmd)
+
+ def _machines_from_instances(
+ self, instances: List[Dict[str, Any]]) -> List[machine.Machine]:
+ """Creates Machine Objects from json data describing created instances."""
+ machines = []
+ for instance in instances:
+ name = instance["name"]
+ kwargs = {
+ "hostname":
+ instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"],
+ "key_path":
+ self.ssh_key_path,
+ "username":
+ self.ssh_user
+ }
+ machines.append(machine.RemoteMachine(name=name, **kwargs))
+ return machines
+
+ def _get_unique_names(self, num_names) -> List[str]:
+ """Returns num_names unique names based on data from the GCP project."""
+ curr_machines = self._list_machines()
+ curr_names = set([machine["name"] for machine in curr_machines])
+ ret = []
+ while len(ret) < num_names:
+ new_name = "machine-" + str(uuid.uuid4())
+ if new_name not in curr_names:
+ ret.append(new_name)
+ curr_names.update(new_name)
+ return ret
+
+ def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]:
+ """Creates instances using gcloud command.
+
+ Runs the command `gcloud compute instances create` and returns json data
+ on created instances on success. Creates len(names) instances, one for each
+ name.
+
+ Args:
+ names: list of names of instances to create.
+
+ Returns:
+ List of json data describing created machines.
+ """
+ if not names:
+ raise ValueError(
+ "_build_instances cannot create instances without names.")
+ cmd = "gcloud compute instances create".split(" ")
+ cmd.extend(names)
+ cmd.extend("--preemptible --image={image} --zone={zone}".format(
+ image=self.image, zone=self.zone).split(" "))
+ if self.image_project:
+ cmd.append("--image-project={project}".format(project=self.image_project))
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _start_command(self, names):
+ """Starts instances using gcloud command.
+
+ Runs the command `gcloud compute instances start` on list of instances by
+ name and returns json data on started instances on success.
+
+ Args:
+ names: list of names of instances to start.
+
+ Returns:
+ List of json data describing started machines.
+ """
+ if not names:
+ raise ValueError("_start_command cannot start empty instance list.")
+ cmd = "gcloud compute instances start".split(" ")
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--project={project}".format(project=self.project))
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _add_ssh_key_to_instances(self, names: List[str]) -> None:
+ """Adds ssh key to instances by calling gcloud ssh command.
+
+ Runs the command `gcloud compute ssh instance_name` on a list of instances
+ by name. Tries to ssh into each given instance.
+
+ Args:
+ names: list of machine names to which to add the ssh-key
+ self.ssh_key_path.
+
+ Raises:
+ subprocess.CalledProcessError: when underlying subprocess call returns an
+ error other than 255 (Connection closed by remote host).
+ TimeoutError: when ssh into the host still fails with 255 after 5 minutes.
+ """
+ for name in names:
+ cmd = "gcloud compute ssh {name}".format(name=name).split(" ")
+ cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_path))
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--command=uname")
+ timeout = datetime.timedelta(seconds=5 * 60)
+ start = datetime.datetime.now()
+ while datetime.datetime.now() <= timeout + start:
+ try:
+ self._run_command(cmd)
+ break
+ except subprocess.CalledProcessError as e:
+ if datetime.datetime.now() > timeout + start:
+ raise TimeoutError(
+ "Could not SSH into instance after 5 min: {name}".format(
+ name=name))
+ # 255 is the returncode for ssh connection refused.
+ elif e.returncode == 255:
+
+ continue
+ else:
+ raise e
+
+ def _list_machines(self) -> List[Dict[str, Any]]:
+ """Runs `list` gcloud command and returns list of Machine data."""
+ cmd = "gcloud compute instances list --project {project}".format(
+ project=self.project).split(" ")
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _run_command(self, cmd: List[str]) -> subprocess.CompletedProcess:
+ """Runs command as a subprocess.
+
+ Runs command as subprocess and returns the result.
+ If this has a mock recorder, use the record method to record the subprocess
+ call.
+
+ Args:
+ cmd: command to be run as a list of strings.
+
+ Returns:
+ Completed process object to be parsed by caller.
+
+ Raises:
+ CalledProcessError: if subprocess.run returns an error.
+ """
+ cmd = cmd + ["--format=json"]
+ res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ self.mock.record(res)
+ if res.returncode != 0:
+ raise subprocess.CalledProcessError(
+ cmd=res.args,
+ output=res.stdout,
+ stderr=res.stderr,
+ returncode=res.returncode)
+ return res
diff --git a/benchmarks/harness/machine_producers/gcloud_producer_test.py b/benchmarks/harness/machine_producers/gcloud_producer_test.py
new file mode 100644
index 000000000..c8adb2bdc
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer_test.py
@@ -0,0 +1,48 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests GCloudProducer using mock data.
+
+GCloudProducer produces machines using 'get_machines' and 'release_machines'
+methods. The tests check recorded data (jsonified subprocess.CompletedProcess
+objects) of the producer producing one and five machines.
+"""
+import os
+import types
+
+from benchmarks.harness.machine_producers import machine_producer
+from benchmarks.harness.machine_producers import mock_producer
+
+TEST_DIR = os.path.dirname(__file__)
+
+
+def run_get_release(producer: machine_producer.MachineProducer,
+ num_machines: int,
+ validator: types.FunctionType = None):
+ machines = producer.get_machines(num_machines)
+ assert len(machines) == num_machines
+ if validator:
+ validator(machines=machines, cmd="uname -a", workload=None)
+ producer.release_machines(machines)
+
+
+def test_run_one():
+ mock = mock_producer.MockReader(TEST_DIR + "get_one.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 1)
+
+
+def test_run_five():
+ mock = mock_producer.MockReader(TEST_DIR + "get_five.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 5)
diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py
index 4f29ad53f..37e9cb4b7 100644
--- a/benchmarks/harness/machine_producers/mock_producer.py
+++ b/benchmarks/harness/machine_producers/mock_producer.py
@@ -13,9 +13,11 @@
# limitations under the License.
"""Producers of mocks."""
-from typing import List
+from typing import List, Any
from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import gcloud_producer
from benchmarks.harness.machine_producers import machine_producer
@@ -29,3 +31,22 @@ class MockMachineProducer(machine_producer.MachineProducer):
def release_machines(self, machine_list: List[machine.MockMachine]):
"""No-op."""
return
+
+
+class MockGCloudProducer(gcloud_producer.GCloudProducer):
+ """Mocks GCloudProducer for testing purposes."""
+
+ def __init__(self, mock: gcloud_mock_recorder.MockReader, **kwargs):
+ gcloud_producer.GCloudProducer.__init__(
+ self, project="mock", ssh_private_key_path="mock", **kwargs)
+ self.mock = mock
+
+ def _validate_ssh_file(self):
+ pass
+
+ def _run_command(self, cmd):
+ return self.mock.pop(cmd)
+
+ def _machines_from_instances(
+ self, instances: List[Any]) -> List[machine.MockMachine]:
+ return [machine.MockMachine() for _ in instances]
diff --git a/benchmarks/harness/machine_producers/testdata/get_five.json b/benchmarks/harness/machine_producers/testdata/get_five.json
new file mode 100644
index 000000000..32bad1b06
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_five.json
@@ -0,0 +1,211 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--project=project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/machine_producers/testdata/get_one.json b/benchmarks/harness/machine_producers/testdata/get_one.json
new file mode 100644
index 000000000..c359c19c8
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_one.json
@@ -0,0 +1,145 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--project=linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/go.mod b/go.mod
index 304b8bf13..c4687ed02 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/golang/protobuf v1.3.1
github.com/google/btree v1.0.0
github.com/google/go-cmp v0.2.0
+ github.com/google/go-github/v28 v28.1.1
github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d
github.com/kr/pty v1.1.1
@@ -17,5 +18,6 @@ require (
github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936
golang.org/x/net v0.0.0-20190311183353-d8887717615a
+ golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
)
diff --git a/go.sum b/go.sum
index 7a0bc175a..434770beb 100644
--- a/go.sum
+++ b/go.sum
@@ -4,6 +4,7 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
+github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM=
github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
@@ -13,6 +14,7 @@ github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+S
github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
diff --git a/kokoro/issue_reviver.cfg b/kokoro/issue_reviver.cfg
new file mode 100644
index 000000000..2370d9250
--- /dev/null
+++ b/kokoro/issue_reviver.cfg
@@ -0,0 +1,15 @@
+build_file: "repo/scripts/issue_reviver.sh"
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-github-access-token"
+ }
+ }
+}
+
+env_vars {
+ key: "KOKORO_GITHUB_ACCESS_TOKEN"
+ value: "73898_kokoro-github-access-token"
+}
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 9553f164d..716ff22d2 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -41,6 +41,7 @@ go_library(
"poll.go",
"prctl.go",
"ptrace.go",
+ "rseq.go",
"rusage.go",
"sched.go",
"seccomp.go",
diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go
new file mode 100644
index 000000000..76253ba30
--- /dev/null
+++ b/pkg/abi/linux/rseq.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags passed to rseq(2).
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_FLAG_UNREGISTER unregisters the current thread.
+ RSEQ_FLAG_UNREGISTER = 1 << 0
+)
+
+// Critical section flags used in RSeqCriticalSection.Flags and RSeq.Flags.
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT inhibits restart on preemption.
+ RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT = 1 << 0
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL inhibits restart on signal
+ // delivery.
+ RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL = 1 << 1
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE inhibits restart on CPU
+ // migration.
+ RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE = 1 << 2
+)
+
+// RSeqCriticalSection describes a restartable sequences critical section. It
+// is equivalent to struct rseq_cs, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+//
+// +marshal
+type RSeqCriticalSection struct {
+ // Version is the version of this structure. Version 0 is defined here.
+ Version uint32
+
+ // Flags are the critical section flags, defined above.
+ Flags uint32
+
+ // Start is the start address of the critical section.
+ Start uint64
+
+ // PostCommitOffset is the offset from Start of the first instruction
+ // outside of the critical section.
+ PostCommitOffset uint64
+
+ // Abort is the abort address. It must be outside the critical section,
+ // and the 4 bytes prior must match the abort signature.
+ Abort uint64
+}
+
+const (
+ // SizeOfRSeqCriticalSection is the size of RSeqCriticalSection.
+ SizeOfRSeqCriticalSection = 32
+
+ // SizeOfRSeqSignature is the size of the signature immediately
+ // preceding RSeqCriticalSection.Abort.
+ SizeOfRSeqSignature = 4
+)
+
+// Special values for RSeq.CPUID, defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CPU_ID_UNINITIALIZED indicates that this thread has not
+ // performed rseq initialization.
+ RSEQ_CPU_ID_UNINITIALIZED = ^uint32(0) // -1
+
+ // RSEQ_CPU_ID_REGISTRATION_FAILED indicates that rseq initialization
+ // failed.
+ RSEQ_CPU_ID_REGISTRATION_FAILED = ^uint32(1) // -2
+)
+
+// RSeq is the thread-local restartable sequences config/status. It
+// is equivalent to struct rseq, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+type RSeq struct {
+ // CPUIDStart contains the current CPU ID if rseq is initialized.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUIDStart uint32
+
+ // CPUID contains the current CPU ID or one of the CPU ID special
+ // values defined above.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUID uint32
+
+ // RSeqCriticalSection is a pointer to the current RSeqCriticalSection
+ // block, or NULL. It is reset to NULL by the kernel on restart or
+ // non-restarting preempt/signal.
+ //
+ // This field should only be written by the thread which registered
+ // this structure, and must be written atomically.
+ RSeqCriticalSection uint64
+
+ // Flags are the critical section flags that apply to all critical
+ // sections on this thread, defined above.
+ Flags uint32
+}
+
+const (
+ // SizeOfRSeq is the size of RSeq.
+ //
+ // Note that RSeq is naively 24 bytes. However, it has 32-byte
+ // alignment, which in C increases sizeof to 32. That is the size that
+ // the Linux kernel uses.
+ SizeOfRSeq = 32
+
+ // AlignOfRSeq is the standard alignment of RSeq.
+ AlignOfRSeq = 32
+
+ // OffsetOfRSeqCriticalSection is the offset of RSeqCriticalSection in RSeq.
+ OffsetOfRSeqCriticalSection = 8
+)
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 498ca4669..81ec98a77 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -125,9 +125,9 @@ type Context interface {
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
SetTLS(value uintptr) bool
- // SetRSEQInterruptedIP sets the register that contains the old IP when a
- // restartable sequence is interrupted.
- SetRSEQInterruptedIP(value uintptr)
+ // SetOldRSeqInterruptedIP sets the register that contains the old IP
+ // when an "old rseq" restartable sequence is interrupted.
+ SetOldRSeqInterruptedIP(value uintptr)
// StateData returns a pointer to underlying architecture state.
StateData() *State
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 67daa6c24..2aa08b1a9 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -174,8 +174,8 @@ func (c *context64) SetTLS(value uintptr) bool {
return true
}
-// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
-func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
c.Regs.R10 = uint64(value)
}
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index c0a6e884b..a2f966cb6 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -555,6 +555,10 @@ type lockedWriter struct {
//
// This applies only to Write, not WriteAt.
Offset int64
+
+ // Err contains the first error encountered while copying. This is
+ // useful to determine whether Writer or Reader failed during io.Copy.
+ Err error
}
// Write implements io.Writer.Write.
@@ -590,5 +594,8 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
break
}
}
+ if w.Err == nil {
+ w.Err = err
+ }
return written, err
}
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index b2e8d9c77..9ca695a95 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -53,7 +53,7 @@ go_template_instance(
"Key": "uint64",
"Range": "memmap.MappableRange",
"Value": "uint64",
- "Functions": "fileRangeSetFunctions",
+ "Functions": "FileRangeSetFunctions",
},
)
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 0a5466b0a..f52d712e3 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -34,25 +34,25 @@ import (
//
// type FileRangeSet <generated by go_generics>
-// fileRangeSetFunctions implements segment.Functions for FileRangeSet.
-type fileRangeSetFunctions struct{}
+// FileRangeSetFunctions implements segment.Functions for FileRangeSet.
+type FileRangeSetFunctions struct{}
// MinKey implements segment.Functions.MinKey.
-func (fileRangeSetFunctions) MinKey() uint64 {
+func (FileRangeSetFunctions) MinKey() uint64 {
return 0
}
// MaxKey implements segment.Functions.MaxKey.
-func (fileRangeSetFunctions) MaxKey() uint64 {
+func (FileRangeSetFunctions) MaxKey() uint64 {
return math.MaxUint64
}
// ClearValue implements segment.Functions.ClearValue.
-func (fileRangeSetFunctions) ClearValue(_ *uint64) {
+func (FileRangeSetFunctions) ClearValue(_ *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
+func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
if frstart1+mr1.Length() != frstart2 {
return 0, false
}
@@ -60,7 +60,7 @@ func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _
}
// Split implements segment.Functions.Split.
-func (fileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
+func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
return frstart, frstart + (split - mr.Start)
}
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index 311798811..389c330a0 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -167,6 +167,11 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
if !srcPipe && !opts.SrcOffset {
atomic.StoreInt64(&src.offset, src.offset+n)
}
+
+ // Don't report any errors if we have some progress without data loss.
+ if w.Err == nil {
+ err = nil
+ }
}
// Drop locks.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 567523d32..4110649ab 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -29,8 +29,12 @@ package disklayout
// byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()).
const (
- // ExtentStructsSize is the size of all the three extent on-disk structs.
- ExtentStructsSize = 12
+ // ExtentHeaderSize is the size of the header of an extent tree node.
+ ExtentHeaderSize = 12
+
+ // ExtentEntrySize is the size of an entry in an extent tree node.
+ // This size is the same for both leaf and internal nodes.
+ ExtentEntrySize = 12
// ExtentMagic is the magic number which must be present in the header.
ExtentMagic = 0xf30a
@@ -57,7 +61,7 @@ type ExtentNode struct {
Entries []ExtentEntryPair
}
-// ExtentEntry reprsents an extent tree node entry. The entry can either be
+// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
// FileBlock returns the first file block number covered by this entry.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index b0fad9b71..8762b90db 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,7 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentStructsSize)
- assertSize(t, ExtentIdx{}, ExtentStructsSize)
- assertSize(t, Extent{}, ExtentStructsSize)
+ assertSize(t, ExtentHeader{}, ExtentHeaderSize)
+ assertSize(t, ExtentIdx{}, ExtentEntrySize)
+ assertSize(t, Extent{}, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 3d3ebaca6..11dcc0346 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -57,7 +57,7 @@ func newExtentFile(regFile regularFile) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &f.root.Header)
+ binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -67,7 +67,7 @@ func (f *extentFile) buildExtTree() error {
}
f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
- for i, off := uint16(0), disklayout.ExtentStructsSize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if f.root.Header.Height == 0 {
// Leaf node.
@@ -76,7 +76,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry)
+ binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
f.root.Entries[i].Entry = curEntry
}
@@ -105,7 +105,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla
}
entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
- for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if header.Height == 0 {
// Leaf node.
diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go
deleted file mode 100644
index b7f4853b3..000000000
--- a/pkg/sentry/fsimpl/memfs/regular_file.go
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package memfs
-
-import (
- "io"
- "sync"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-type regularFile struct {
- inode inode
-
- mu sync.RWMutex
- data []byte
- // dataLen is len(data), but accessed using atomic memory operations to
- // avoid locking in inode.stat().
- dataLen int64
-}
-
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &regularFile{}
- file.inode.init(file, fs, creds, mode)
- file.inode.nlink = 1 // from parent directory
- return &file.inode
-}
-
-type regularFileFD struct {
- fileDescription
-
- // These are immutable.
- readable bool
- writable bool
-
- // off is the file offset. off is accessed using atomic memory operations.
- // offMu serializes operations that may mutate off.
- off int64
- offMu sync.Mutex
-}
-
-// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
- }
-}
-
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.RLock()
- if offset >= int64(len(f.data)) {
- f.mu.RUnlock()
- return 0, io.EOF
- }
- n, err := dst.CopyOut(ctx, f.data[offset:])
- f.mu.RUnlock()
- return int64(n), err
-}
-
-// Read implements vfs.FileDescriptionImpl.Read.
-func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PRead(ctx, dst, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- srclen := src.NumBytes()
- if srclen == 0 {
- return 0, nil
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.Lock()
- end := offset + srclen
- if end < offset {
- // Overflow.
- f.mu.Unlock()
- return 0, syserror.EFBIG
- }
- if end > f.dataLen {
- f.data = append(f.data, make([]byte, end-f.dataLen)...)
- atomic.StoreInt64(&f.dataLen, end)
- }
- n, err := src.CopyIn(ctx, f.data[offset:end])
- f.mu.Unlock()
- return int64(n), err
-}
-
-// Write implements vfs.FileDescriptionImpl.Write.
-func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// Seek implements vfs.FileDescriptionImpl.Seek.
-func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
- fd.offMu.Lock()
- defer fd.offMu.Unlock()
- switch whence {
- case linux.SEEK_SET:
- // use offset as specified
- case linux.SEEK_CUR:
- offset += fd.off
- case linux.SEEK_END:
- offset += atomic.LoadInt64(&fd.inode().impl.(*regularFile).dataLen)
- default:
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- fd.off = offset
- return offset, nil
-}
-
-// Sync implements vfs.FileDescriptionImpl.Sync.
-func (fd *regularFileFD) Sync(ctx context.Context) error {
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 50b2a832f..d8f92d52f 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -27,7 +27,11 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-const defaultPermission = 0444
+const (
+ defaultPermission = 0444
+ selfName = "self"
+ threadSelfName = "thread-self"
+)
// InoGenerator generates unique inode numbers for a given filesystem.
type InoGenerator interface {
@@ -45,6 +49,11 @@ type tasksInode struct {
inoGen InoGenerator
pidns *kernel.PIDNamespace
+
+ // '/proc/self' and '/proc/thread-self' have custom directory offsets in
+ // Linux. So handle them outside of OrderedChildren.
+ selfSymlink *vfs.Dentry
+ threadSelfSymlink *vfs.Dentry
}
var _ kernfs.Inode = (*tasksInode)(nil)
@@ -54,20 +63,20 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames
contents := map[string]*kernfs.Dentry{
//"cpuinfo": newCPUInfo(ctx, msrc),
//"filesystems": seqfile.NewSeqFileInode(ctx, &filesystemsData{}, msrc),
- "loadavg": newDentry(root, inoGen.NextIno(), defaultPermission, &loadavgData{}),
- "meminfo": newDentry(root, inoGen.NextIno(), defaultPermission, &meminfoData{k: k}),
- "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), defaultPermission, "self/mounts"),
- "self": newSelfSymlink(root, inoGen.NextIno(), defaultPermission, pidns),
- "stat": newDentry(root, inoGen.NextIno(), defaultPermission, &statData{k: k}),
- "thread-self": newThreadSelfSymlink(root, inoGen.NextIno(), defaultPermission, pidns),
+ "loadavg": newDentry(root, inoGen.NextIno(), defaultPermission, &loadavgData{}),
+ "meminfo": newDentry(root, inoGen.NextIno(), defaultPermission, &meminfoData{k: k}),
+ "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), defaultPermission, "self/mounts"),
+ "stat": newDentry(root, inoGen.NextIno(), defaultPermission, &statData{k: k}),
//"uptime": newUptime(ctx, msrc),
//"version": newVersionData(root, inoGen.NextIno(), k),
"version": newDentry(root, inoGen.NextIno(), defaultPermission, &versionData{k: k}),
}
inode := &tasksInode{
- pidns: pidns,
- inoGen: inoGen,
+ pidns: pidns,
+ inoGen: inoGen,
+ selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
}
inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555)
@@ -86,6 +95,13 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
// Try to lookup a corresponding task.
tid, err := strconv.ParseUint(name, 10, 64)
if err != nil {
+ // If it failed to parse, check if it's one of the specially handled files.
+ switch name {
+ case selfName:
+ return i.selfSymlink, nil
+ case threadSelfName:
+ return i.threadSelfSymlink, nil
+ }
return nil, syserror.ENOENT
}
@@ -104,41 +120,81 @@ func (i *tasksInode) Valid(ctx context.Context) bool {
}
// IterDirents implements kernfs.inodeDynamicLookup.
-//
-// TODO(gvisor.dev/issue/1195): Use tgid N offset = TGID_OFFSET + N.
-func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
- var tids []int
+func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+ // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
+ const FIRST_PROCESS_ENTRY = 256
+
+ // Use maxTaskID to shortcut searches that will result in 0 entries.
+ const maxTaskID = kernel.TasksLimit + 1
+ if offset >= maxTaskID {
+ return offset, nil
+ }
+
+ // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories
+ // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by
+ // '/proc/thread-self' and then '/proc/[pid]'.
+ if offset < FIRST_PROCESS_ENTRY {
+ offset = FIRST_PROCESS_ENTRY
+ }
+
+ if offset == FIRST_PROCESS_ENTRY {
+ dirent := vfs.Dirent{
+ Name: selfName,
+ Type: linux.DT_LNK,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ if offset == FIRST_PROCESS_ENTRY+1 {
+ dirent := vfs.Dirent{
+ Name: threadSelfName,
+ Type: linux.DT_LNK,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
- // Collect all tasks. Per linux we only include it in directory listings if
- // it's the leader. But for whatever crazy reason, you can still walk to the
- // given node.
+ // Collect the thread groups whose TGIDs are at or past the specified
+ // offset. Per Linux, a task is only included in directory listings if it's
+ // the leader. But for whatever reason, you can still walk to the given node.
+ var tids []int
+ startTid := offset - FIRST_PROCESS_ENTRY - 2
for _, tg := range i.pidns.ThreadGroups() {
+ tid := i.pidns.IDOfThreadGroup(tg)
+ if int64(tid) < startTid {
+ continue
+ }
if leader := tg.Leader(); leader != nil {
- tids = append(tids, int(i.pidns.IDOfThreadGroup(tg)))
+ tids = append(tids, int(tid))
}
}
if len(tids) == 0 {
return offset, nil
}
- if relOffset >= int64(len(tids)) {
- return offset, nil
- }
sort.Ints(tids)
- for _, tid := range tids[relOffset:] {
+ for _, tid := range tids {
dirent := vfs.Dirent{
Name: strconv.FormatUint(uint64(tid), 10),
Type: linux.DT_DIR,
Ino: i.inoGen.NextIno(),
- NextOff: offset + 1,
+ NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
}
if !cb.Handle(dirent) {
return offset, nil
}
offset++
}
- return offset, nil
+ return maxTaskID, nil
}
// Open implements kernfs.Inode.
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 2560fcef9..ca8c87ec2 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -16,6 +16,7 @@ package proc
import (
"fmt"
+ "math"
"path"
"strconv"
"testing"
@@ -30,6 +31,18 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+var (
+ // Offsets start at 256 by convention. Add 1 to get the next offset.
+ selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1}
+ threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1}
+
+ // /proc/[pid] offsets start at 256+2 (past the two symlinks above). Add
+ // the PID, and then 1 to get the next offset.
+ proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1}
+ proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1}
+ proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1}
+)
+
type testIterDirentsCallback struct {
dirents []vfs.Dirent
}
@@ -59,9 +72,9 @@ func checkTasksStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
"loadavg": {Type: linux.DT_REG},
"meminfo": {Type: linux.DT_REG},
"mounts": {Type: linux.DT_LNK},
- "self": {Type: linux.DT_LNK},
+ "self": selfLink,
"stat": {Type: linux.DT_REG},
- "thread-self": {Type: linux.DT_LNK},
+ "thread-self": threadSelfLink,
"version": {Type: linux.DT_REG},
}
return checkFiles(gots, wants)
@@ -93,6 +106,9 @@ func checkFiles(gots []vfs.Dirent, wants map[string]vfs.Dirent) ([]vfs.Dirent, e
if want.Type != got.Type {
return gots, fmt.Errorf("wrong file type, want: %v, got: %v: %+v", want.Type, got.Type, got)
}
+ if want.NextOff != 0 && want.NextOff != got.NextOff {
+ return gots, fmt.Errorf("wrong dirent offset, want: %v, got: %v: %+v", want.NextOff, got.NextOff, got)
+ }
delete(wants, got.Name)
gots = append(gots[0:i], gots[i+1:]...)
@@ -154,7 +170,7 @@ func TestTasksEmpty(t *testing.T) {
t.Error(err.Error())
}
if len(cb.dirents) != 0 {
- t.Error("found more files than expected: %+v", cb.dirents)
+ t.Errorf("found more files than expected: %+v", cb.dirents)
}
}
@@ -216,6 +232,11 @@ func TestTasks(t *testing.T) {
if !found {
t.Errorf("Additional task ID %d listed: %v", pid, tasks)
}
+ // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the
+ // PID, and adds 1 for the next offset.
+ if want := int64(256 + 2 + pid + 1); d.NextOff != want {
+ t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d)
+ }
}
// Test lookup.
@@ -246,6 +267,126 @@ func TestTasks(t *testing.T) {
}
}
+func TestTasksOffset(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ for i := 0; i < 3; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ if _, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc); err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ }
+
+ for _, tc := range []struct {
+ name string
+ offset int64
+ wants map[string]vfs.Dirent
+ }{
+ {
+ name: "small offset",
+ offset: 100,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "offset at start",
+ offset: 256,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip /proc/self",
+ offset: 257,
+ wants: map[string]vfs.Dirent{
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip symlinks",
+ offset: 258,
+ wants: map[string]vfs.Dirent{
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip first process",
+ offset: 260,
+ wants: map[string]vfs.Dirent{
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "last process",
+ offset: 261,
+ wants: map[string]vfs.Dirent{
+ "3": proc3,
+ },
+ },
+ {
+ name: "after last",
+ offset: 262,
+ wants: nil,
+ },
+ {
+ name: "TaskLimit+1",
+ offset: kernel.TasksLimit + 1,
+ wants: nil,
+ },
+ {
+ name: "max",
+ offset: math.MaxInt64,
+ wants: nil,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ if _, err := fd.Impl().Seek(ctx, tc.offset, linux.SEEK_SET); err != nil {
+ t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ if cb.dirents, err = checkFiles(cb.dirents, tc.wants); err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+ })
+ }
+}
+
func TestTask(t *testing.T) {
ctx, vfsObj, root, err := setup()
if err != nil {
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 5689bed3b..a5b285987 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -1,14 +1,13 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "dentry_list",
out = "dentry_list.go",
- package = "memfs",
+ package = "tmpfs",
prefix = "dentry",
template = "//pkg/ilist:generic_list",
types = {
@@ -18,25 +17,34 @@ go_template_instance(
)
go_library(
- name = "memfs",
+ name = "tmpfs",
srcs = [
"dentry_list.go",
"directory.go",
"filesystem.go",
- "memfs.go",
"named_pipe.go",
"regular_file.go",
"symlink.go",
+ "tmpfs.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs",
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/fspath",
+ "//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
@@ -48,7 +56,7 @@ go_test(
size = "small",
srcs = ["benchmark_test.go"],
deps = [
- ":memfs",
+ ":tmpfs",
"//pkg/abi/linux",
"//pkg/fspath",
"//pkg/refs",
@@ -63,16 +71,20 @@ go_test(
)
go_test(
- name = "memfs_test",
+ name = "tmpfs_test",
size = "small",
- srcs = ["pipe_test.go"],
- embed = [":memfs"],
+ srcs = [
+ "pipe_test.go",
+ "regular_file_test.go",
+ ],
+ embed = [":tmpfs"],
deps = [
"//pkg/abi/linux",
"//pkg/fspath",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index a27876a4e..d88c83499 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -27,7 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -176,10 +176,10 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -367,10 +367,10 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -399,7 +399,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
}
defer mountPoint.DecRef()
// Create and mount the submount.
- if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil {
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
filePathBuilder.WriteString(mountPointName)
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 0bd82e480..887ca2619 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index b063e09a3..26979729e 100644
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"fmt"
@@ -50,7 +50,7 @@ afterSymlink:
return nil, err
}
if nextVFSD == nil {
- // Since the Dentry tree is the sole source of truth for memfs, if it's
+ // Since the Dentry tree is the sole source of truth for tmpfs, if it's
// not in the Dentry tree, it doesn't exist.
return nil, syserror.ENOENT
}
@@ -351,8 +351,8 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32,
fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{})
if flags&linux.O_TRUNC != 0 {
impl.mu.Lock()
- impl.data = impl.data[:0]
- atomic.StoreInt64(&impl.dataLen, 0)
+ impl.data.Truncate(0, impl.memFile)
+ atomic.StoreUint64(&impl.size, 0)
impl.mu.Unlock()
}
return &fd.vfsfd, nil
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index b5a204438..40bde54de 100644
--- a/pkg/sentry/fsimpl/memfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 807c1af7a..70b42a6ec 100644
--- a/pkg/sentry/fsimpl/memfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"bytes"
@@ -152,10 +152,10 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
new file mode 100644
index 000000000..f51e247a7
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -0,0 +1,357 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "io"
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type regularFile struct {
+ inode inode
+
+ // memFile is a platform.File used to allocate pages to this regularFile.
+ memFile *pgalloc.MemoryFile
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // data maps offsets into the file to offsets into memFile that store
+ // the file's data.
+ data fsutil.FileRangeSet
+
+ // size is the size of data, but accessed using atomic memory
+ // operations to avoid locking in inode.stat().
+ size uint64
+
+ // seals represents file seals on this inode.
+ seals uint32
+}
+
+func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
+ file := &regularFile{
+ memFile: fs.memFile,
+ }
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // These are immutable.
+ readable bool
+ writable bool
+
+ // off is the file offset. off is accessed using atomic memory operations.
+ // offMu serializes operations that may mutate off.
+ off int64
+ offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {
+ if fd.writable {
+ fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ }
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putRegularFileReadWriter(rw)
+ return int64(n), err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ srclen := src.NumBytes()
+ if srclen == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ end := offset + srclen
+ if end < offset {
+ // Overflow.
+ return 0, syserror.EFBIG
+ }
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ putRegularFileReadWriter(rw)
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size))
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *regularFileFD) Sync(ctx context.Context) error {
+ return nil
+}
+
+// regularFileReadWriter implements safemem.Reader and safemem.Writer.
+type regularFileReadWriter struct {
+ file *regularFile
+
+ // Offset into the file to read/write at. Note that this may be
+ // different from the FD offset if PRead/PWrite is used.
+ off uint64
+}
+
+var regularFileReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &regularFileReadWriter{}
+ },
+}
+
+func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter {
+ rw := regularFileReadWriterPool.Get().(*regularFileReadWriter)
+ rw.file = file
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putRegularFileReadWriter(rw *regularFileReadWriter) {
+ rw.file = nil
+ regularFileReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.file.mu.RLock()
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= rw.file.size {
+ rw.file.mu.RUnlock()
+ return 0, io.EOF
+ }
+ end := rw.file.size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ rw.file.mu.RUnlock()
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ rw.file.mu.Lock()
+
+ // Compute the range to write (overflow-checked).
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ rw.file.mu.Unlock()
+ return 0, syserror.EPERM
+ case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artificially truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
+ if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.off {
+ // Truncation would result in no data being written.
+ rw.file.mu.Unlock()
+ return 0, syserror.EPERM
+ }
+ }
+
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.off).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += uint64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.off > rw.file.size {
+ atomic.StoreUint64(&rw.file.size, rw.off)
+ }
+
+ rw.file.mu.Unlock()
+ return done, retErr
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
new file mode 100644
index 000000000..3731c5b6f
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -0,0 +1,224 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned err is nil, then cleanup should be called when the FD is no
+// longer needed; on error, the returned FD and cleanup are both nil.
+func newFileFD(ctx context.Context, filename string) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ }
+ root := mntns.Root()
+
+ // Create the file that will be written to and read from.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ FollowFinalSymlink: true,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: 0644,
+ })
+ if err != nil {
+ root.DecRef()
+ mntns.DecRef(vfsObj)
+ return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
+ }
+
+ return fd, func() {
+ root.DecRef()
+ mntns.DecRef(vfsObj)
+ }, nil
+}
+
+// TestSimpleWriteRead tests that we can write some data to a file and read
+// it back, and that Write, Seek and Read each leave the FD offset where
+// expected.
+func TestSimpleWriteRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, "simpleReadWrite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write.
+ data := []byte("foobarbaz")
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+
+ // Seek back to beginning.
+ if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Seek failed: %v", err)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want {
+ t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want)
+ }
+
+ // Read. io.EOF is tolerated here: the read covers the file up to its end,
+ // and only the byte count and contents checked below matter.
+ buf := make([]byte, len(data))
+ n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Read got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(buf), string(data); got != want {
+ t.Errorf("Read got %q want %q", got, want)
+ }
+ // The read must advance the offset by the number of bytes read.
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Read left offset at %d, want %d", got, want)
+ }
+}
+
+func TestPWrite(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, "PRead")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill file with 1k 'a's.
+ data := bytes.Repeat([]byte{'a'}, 1000)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Write "gVisor is awesome" at various offsets.
+ buf := []byte("gVisor is awesome")
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PWrite offset=%d", offset)
+ t.Run(name, func(t *testing.T) {
+ n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{})
+ if err != nil {
+ t.Errorf("fd.PWrite got err %v want nil", err)
+ }
+ if n != int64(len(buf)) {
+ t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf))
+ }
+
+ // Update data to reflect expected file contents.
+ if len(data) < offset+len(buf) {
+ data = append(data, make([]byte, (offset+len(buf))-len(data))...)
+ }
+ copy(data[offset:], buf)
+
+ // Read the whole file and compare with data.
+ readBuf := make([]byte, len(data))
+ n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("fd.PRead failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.PRead got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(readBuf), string(data); got != want {
+ t.Errorf("PRead got %q want %s", got, want)
+ }
+
+ })
+ }
+}
+
+func TestPRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, "PRead")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write 100 sequences of 'gVisor is awesome'.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Read various sizes from various offsets.
+ sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000}
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+
+ for _, size := range sizes {
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PRead offset=%d size=%d", offset, size)
+ t.Run(name, func(t *testing.T) {
+ var (
+ wantRead []byte
+ wantErr error
+ )
+ if offset < len(data) {
+ wantRead = data[offset:]
+ } else if size > 0 {
+ wantErr = io.EOF
+ }
+ if offset+size < len(data) {
+ wantRead = wantRead[:size]
+ }
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{})
+ if err != wantErr {
+ t.Errorf("fd.PRead got err %v want %v", err, wantErr)
+ }
+ if n != int64(len(wantRead)) {
+ t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead))
+ }
+ if got := string(buf[:n]); got != string(wantRead) {
+ t.Errorf("fd.PRead got %q want %q", got, string(wantRead))
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index b2ac2cbeb..5246aca84 100644
--- a/pkg/sentry/fsimpl/memfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 8d0167c93..7be6faa5b 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -12,20 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package memfs provides a filesystem implementation that behaves like tmpfs:
+// Package tmpfs provides a filesystem implementation that behaves like tmpfs:
// the Dentry tree is the sole source of truth for the state of the filesystem.
//
-// memfs is intended primarily to demonstrate filesystem implementation
-// patterns. Real uses cases for an in-memory filesystem should use tmpfs
-// instead.
-//
// Lock order:
//
// filesystem.mu
// regularFileFD.offMu
// regularFile.mu
// inode.mu
-package memfs
+package tmpfs
import (
"fmt"
@@ -36,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -47,6 +44,9 @@ type FilesystemType struct{}
type filesystem struct {
vfsfs vfs.Filesystem
+ // memFile is used to allocate pages for regular files.
+ memFile *pgalloc.MemoryFile
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex
@@ -55,7 +55,13 @@ type filesystem struct {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- var fs filesystem
+ memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
+ if memFileProvider == nil {
+ panic("MemoryFileProviderFromContext returned nil")
+ }
+ fs := filesystem{
+ memFile: memFileProvider.MemoryFile(),
+ }
fs.vfsfs.Init(vfsObj, &fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
return &fs.vfsfs, &root.vfsd, nil
@@ -74,11 +80,11 @@ type dentry struct {
// immutable.
inode *inode
- // memfs doesn't count references on dentries; because the dentry tree is
+ // tmpfs doesn't count references on dentries; because the dentry tree is
// the sole source of truth, it is by definition always consistent with the
// state of the filesystem. However, it does count references on inodes,
// because inode resources are released when all references are dropped.
- // (memfs doesn't really have resources to release, but we implement
+ // (tmpfs doesn't really have resources to release, but we implement
// reference counting because tmpfs regular files will.)
// dentryEntry (ugh) links dentries into their parent directory.childList.
@@ -150,7 +156,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials,
// i.nlink < maxLinks.
func (i *inode) incLinksLocked() {
if i.nlink == 0 {
- panic("memfs.inode.incLinksLocked() called with no existing links")
+ panic("tmpfs.inode.incLinksLocked() called with no existing links")
}
if i.nlink == maxLinks {
panic("memfs.inode.incLinksLocked() called with maximum link count")
@@ -163,14 +169,14 @@ func (i *inode) incLinksLocked() {
// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
func (i *inode) decLinksLocked() {
if i.nlink == 0 {
- panic("memfs.inode.decLinksLocked() called with no existing links")
+ panic("tmpfs.inode.decLinksLocked() called with no existing links")
}
atomic.AddUint32(&i.nlink, ^uint32(0))
}
func (i *inode) incRef() {
if atomic.AddInt64(&i.refs, 1) <= 1 {
- panic("memfs.inode.incRef() called without holding a reference")
+ panic("tmpfs.inode.incRef() called without holding a reference")
}
}
@@ -189,14 +195,14 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
// This is unnecessary; it's mostly to simulate what tmpfs would do.
- if regfile, ok := i.impl.(*regularFile); ok {
- regfile.mu.Lock()
- regfile.data = nil
- atomic.StoreInt64(&regfile.dataLen, 0)
- regfile.mu.Unlock()
+ if regFile, ok := i.impl.(*regularFile); ok {
+ regFile.mu.Lock()
+ regFile.data.DropAll(regFile.memFile)
+ atomic.StoreUint64(&regFile.size, 0)
+ regFile.mu.Unlock()
}
} else if refs < 0 {
- panic("memfs.inode.decRef() called without holding a reference")
+ panic("tmpfs.inode.decRef() called without holding a reference")
}
}
@@ -220,7 +226,7 @@ func (i *inode) statTo(stat *linux.Statx) {
case *regularFile:
stat.Mode |= linux.S_IFREG
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
- stat.Size = uint64(atomic.LoadInt64(&impl.dataLen))
+ stat.Size = uint64(atomic.LoadUint64(&impl.size))
// In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in
// a uint64 accessed using atomic memory operations to avoid taking
// locks).
@@ -261,7 +267,7 @@ func (i *inode) direntType() uint8 {
}
}
-// fileDescription is embedded by memfs implementations of
+// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
vfsfd vfs.FileDescription
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 24ea002ba..b14429854 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -15,17 +15,29 @@
package kernel
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
-// Restartable sequences, as described in https://lwn.net/Articles/650333/.
+// Restartable sequences.
+//
+// We support two different APIs for restartable sequences.
+//
+// 1. The upstream interface added in v4.18.
+// 2. The interface described in https://lwn.net/Articles/650333/.
+//
+// Throughout this file and other parts of the kernel, the latter is referred
+// to as "old rseq". This interface was never merged upstream, but is supported
+// for a limited set of applications that use it regardless.
-// RSEQCriticalRegion describes a restartable sequence critical region.
+// OldRSeqCriticalRegion describes an old rseq critical region.
//
// +stateify savable
-type RSEQCriticalRegion struct {
+type OldRSeqCriticalRegion struct {
// When a task in this thread group has its CPU preempted (as defined by
// platform.ErrContextCPUPreempted) or has a signal delivered to an
// application handler while its instruction pointer is in CriticalSection,
@@ -35,86 +47,359 @@ type RSEQCriticalRegion struct {
Restart usermem.Addr
}
-// RSEQAvailable returns true if t supports restartable sequences.
-func (t *Task) RSEQAvailable() bool {
+// RSeqAvailable returns true if t supports (old and new) restartable sequences.
+func (t *Task) RSeqAvailable() bool {
return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption()
}
-// RSEQCriticalRegion returns a copy of t's thread group's current restartable
-// sequence.
-func (t *Task) RSEQCriticalRegion() RSEQCriticalRegion {
- return *t.tg.rscr.Load().(*RSEQCriticalRegion)
+// SetRSeq registers addr as this thread's rseq structure.
+//
+// It returns EBUSY if addr and signature match an existing registration,
+// EINVAL if they conflict with an existing registration or if addr is
+// misaligned or length is not linux.SizeOfRSeq, and EFAULT if
+// CheckIORange rejects addr or the initial CPU update fails. In the last
+// case the registration is rolled back and SIGSEGV is sent to t.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr != 0 {
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EINVAL
+ }
+ return syserror.EBUSY
+ }
+
+ // rseq must be aligned and correctly sized.
+ if addr&(linux.AlignOfRSeq-1) != 0 {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok {
+ return syserror.EFAULT
+ }
+
+ t.rseqAddr = addr
+ t.rseqSignature = signature
+
+ // Initialize the CPUID.
+ //
+ // Linux implicitly does this on return from userspace, where failure
+ // would cause SIGSEGV.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ // Log addr rather than t.rseqAddr, which has just been reset to 0.
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", addr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return syserror.EFAULT
+ }
+
+ return nil
}
-// SetRSEQCriticalRegion replaces t's thread group's restartable sequence.
+// ClearRSeq unregisters addr as this thread's rseq structure.
//
-// Preconditions: t.RSEQAvailable() == true.
-func (t *Task) SetRSEQCriticalRegion(rscr RSEQCriticalRegion) error {
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr == 0 {
+ return syserror.EINVAL
+ }
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EPERM
+ }
+
+ if err := t.rseqClearCPU(); err != nil {
+ return err
+ }
+
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ if t.oldRSeqCPUAddr == 0 {
+ // rseqCPU no longer needed.
+ t.rseqCPU = -1
+ }
+
+ return nil
+}
+
+// OldRSeqCriticalRegion returns a copy of t's thread group's current
+// old restartable sequence.
+func (t *Task) OldRSeqCriticalRegion() OldRSeqCriticalRegion {
+ return *t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+}
+
+// SetOldRSeqCriticalRegion replaces t's thread group's old restartable
+// sequence.
+//
+// Preconditions: t.RSeqAvailable() == true.
+func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
// These checks are somewhat more lenient than in Linux, which (bizarrely)
- // requires rscr.CriticalSection to be non-empty and rscr.Restart to be
- // outside of rscr.CriticalSection, even if rscr.CriticalSection.Start == 0
+ // requires r.CriticalSection to be non-empty and r.Restart to be
+ // outside of r.CriticalSection, even if r.CriticalSection.Start == 0
// (which disables the critical region).
- if rscr.CriticalSection.Start == 0 {
- rscr.CriticalSection.End = 0
- rscr.Restart = 0
- t.tg.rscr.Store(&rscr)
+ if r.CriticalSection.Start == 0 {
+ r.CriticalSection.End = 0
+ r.Restart = 0
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
- if rscr.CriticalSection.Start >= rscr.CriticalSection.End {
+ if r.CriticalSection.Start >= r.CriticalSection.End {
return syserror.EINVAL
}
- if rscr.CriticalSection.Contains(rscr.Restart) {
+ if r.CriticalSection.Contains(r.Restart) {
return syserror.EINVAL
}
- // TODO(jamieliu): check that rscr.CriticalSection and rscr.Restart are in
- // the application address range, for consistency with Linux
- t.tg.rscr.Store(&rscr)
+ // TODO(jamieliu): check that r.CriticalSection and r.Restart are in
+ // the application address range, for consistency with Linux.
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
-// RSEQCPUAddr returns the address that RSEQ will keep updated with t's CPU
-// number.
+// OldRSeqCPUAddr returns the address that old rseq will keep updated with t's
+// CPU number.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) RSEQCPUAddr() usermem.Addr {
- return t.rseqCPUAddr
+func (t *Task) OldRSeqCPUAddr() usermem.Addr {
+ return t.oldRSeqCPUAddr
}
-// SetRSEQCPUAddr replaces the address that RSEQ will keep updated with t's CPU
-// number.
+// SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with
+// t's CPU number.
//
-// Preconditions: t.RSEQAvailable() == true. The caller must be running on the
+// Preconditions: t.RSeqAvailable() == true. The caller must be running on the
// task goroutine. t's AddressSpace must be active.
-func (t *Task) SetRSEQCPUAddr(addr usermem.Addr) error {
- t.rseqCPUAddr = addr
- if addr != 0 {
- t.rseqCPU = int32(hostcpu.GetCPU())
- if err := t.rseqCopyOutCPU(); err != nil {
- t.rseqCPUAddr = 0
- t.rseqCPU = -1
- return syserror.EINVAL // yes, EINVAL, not err or EFAULT
- }
- } else {
- t.rseqCPU = -1
+func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error {
+ t.oldRSeqCPUAddr = addr
+
+ // Check that addr is writable.
+ //
+ // N.B. rseqUpdateCPU may fail on a bad t.rseqAddr as well. That's
+ // unfortunate, but unlikely in a correct program.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.oldRSeqCPUAddr = 0
+ return syserror.EINVAL // yes, EINVAL, not err or EFAULT
}
return nil
}
// Preconditions: The caller must be running on the task goroutine. t's
// AddressSpace must be active.
-func (t *Task) rseqCopyOutCPU() error {
+func (t *Task) rseqUpdateCPU() error {
+ if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 {
+ t.rseqCPU = -1
+ return nil
+ }
+
+ t.rseqCPU = int32(hostcpu.GetCPU())
+
+ // Update both CPUs, even if one fails.
+ rerr := t.rseqCopyOutCPU()
+ oerr := t.oldRSeqCopyOutCPU()
+
+ if rerr != nil {
+ return rerr
+ }
+ return oerr
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) oldRSeqCopyOutCPU() error {
+ if t.oldRSeqCPUAddr == 0 {
+ return nil
+ }
+
buf := t.CopyScratchBuffer(4)
usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
- _, err := t.CopyOutBytes(t.rseqCPUAddr, buf)
+ _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf)
+ return err
+}
+
+// rseqCopyOutCPU writes t.rseqCPU into the two CPU fields at the start of the
+// userspace linux.RSeq structure at t.rseqAddr. It does nothing when
+// t.rseqAddr is 0.
+//
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqCopyOutCPU() error {
+	if t.rseqAddr == 0 {
+		return nil
+	}
+
+	cpu := uint32(t.rseqCPU)
+	out := t.CopyScratchBuffer(8)
+	// linux.RSeq begins with CPUIDStart followed by CPUID; both receive the
+	// same value.
+	usermem.ByteOrder.PutUint32(out[0:4], cpu) // CPUIDStart
+	usermem.ByteOrder.PutUint32(out[4:8], cpu) // CPUID
+	// N.B. The 8-byte write is not atomic. Because it happens on the task
+	// goroutine, userspace cannot observe an invalid value as long as it
+	// reads each field with a single instruction.
+	_, err := t.CopyOutBytes(t.rseqAddr, out)
+	return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqClearCPU() error {
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
return err
}
+// rseqAddrInterrupt moves t's IP to the abort address if t was interrupted
+// inside the critical section currently published via its linux.RSeq
+// structure.
+//
+// Both RSeq and the RSeqCriticalSection it points to are stored in
+// userspace, so the work proceeds in stages:
+//
+// 1. Copy in the RSeqCriticalSection pointer held in RSeq.
+// 2. Copy in the RSeqCriticalSection itself.
+// 3. Validate the struct version, the address range and the abort address.
+// 4. Validate the abort signature (the 4 bytes preceding the abort IP must
+//    match the registered signature).
+// 5. Clear the RSeqCriticalSection pointer in RSeq.
+// 6. Finally, abort if IP lies within the critical section.
+//
+// A copy or validation failure in any stage delivers SIGSEGV to t.
+//
+// See kernel/rseq.c:rseq_ip_fixup for reference.
+//
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqAddrInterrupt() {
+	if t.rseqAddr == 0 {
+		return
+	}
+
+	// segv delivers SIGSEGV to t; every failure path below ends with it.
+	segv := func() {
+		t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+		t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+	}
+
+	csPtrAddr, ok := t.rseqAddr.AddLength(linux.OffsetOfRSeqCriticalSection)
+	if !ok {
+		// SetRSeq should validate this.
+		panic(fmt.Sprintf("t.rseqAddr (%#x) not large enough", t.rseqAddr))
+	}
+
+	// We only handle 64-bit for now.
+	if t.Arch().Width() != 8 {
+		t.Debugf("Only 64-bit rseq supported.")
+		segv()
+		return
+	}
+
+	// Stage 1: fetch the pointer to the critical section descriptor.
+	ptrBuf := t.CopyScratchBuffer(8)
+	if _, err := t.CopyInBytes(csPtrAddr, ptrBuf); err != nil {
+		t.Debugf("Failed to copy critical section address from %#x for rseq: %v", csPtrAddr, err)
+		segv()
+		return
+	}
+
+	csAddr := usermem.Addr(usermem.ByteOrder.Uint64(ptrBuf))
+	if csAddr == 0 {
+		// Null pointer: nothing to do.
+		return
+	}
+
+	// Stage 2: fetch the descriptor itself.
+	csBuf := t.CopyScratchBuffer(linux.SizeOfRSeqCriticalSection)
+	if _, err := t.CopyInBytes(csAddr, csBuf); err != nil {
+		t.Debugf("Failed to copy critical section from %#x for rseq: %v", csAddr, err)
+		segv()
+		return
+	}
+
+	// Manually marshal RSeqCriticalSection as this is in the hot path when
+	// rseq is enabled. It must be as fast as possible.
+	//
+	// TODO(b/130243041): Replace with go_marshal.
+	var cs linux.RSeqCriticalSection
+	cs.Version = usermem.ByteOrder.Uint32(csBuf[0:4])
+	cs.Flags = usermem.ByteOrder.Uint32(csBuf[4:8])
+	cs.Start = usermem.ByteOrder.Uint64(csBuf[8:16])
+	cs.PostCommitOffset = usermem.ByteOrder.Uint64(csBuf[16:24])
+	cs.Abort = usermem.ByteOrder.Uint64(csBuf[24:32])
+
+	// Stage 3: validate the descriptor.
+	if cs.Version != 0 {
+		t.Debugf("Unknown version in %+v", cs)
+		segv()
+		return
+	}
+
+	critRange, ok := usermem.Addr(cs.Start).ToRange(cs.PostCommitOffset)
+	if !ok {
+		t.Debugf("Invalid start and offset in %+v", cs)
+		segv()
+		return
+	}
+
+	abortIP := usermem.Addr(cs.Abort)
+	if critRange.Contains(abortIP) {
+		t.Debugf("Abort in critical section in %+v", cs)
+		segv()
+		return
+	}
+
+	// Stage 4: verify the signature stored just before the abort IP.
+	sigAddr := abortIP - linux.SizeOfRSeqSignature
+
+	sigBuf := t.CopyScratchBuffer(linux.SizeOfRSeqSignature)
+	if _, err := t.CopyInBytes(sigAddr, sigBuf); err != nil {
+		t.Debugf("Failed to copy critical section signature from %#x for rseq: %v", sigAddr, err)
+		segv()
+		return
+	}
+
+	if got := usermem.ByteOrder.Uint32(sigBuf); got != t.rseqSignature {
+		t.Debugf("Mismatched rseq signature %d != %d", got, t.rseqSignature)
+		segv()
+		return
+	}
+
+	// Stage 5: clear the critical section address.
+	//
+	// NOTE(b/143949567): We don't support any rseq flags, so we always
+	// restart if we are in the critical section, and thus *always* clear
+	// the pointer at csPtrAddr.
+	zeroOpts := usermem.IOOpts{
+		AddressSpaceActive: true,
+	}
+	if _, err := t.MemoryManager().ZeroOut(t, csPtrAddr, int64(t.Arch().Width()), zeroOpts); err != nil {
+		t.Debugf("Failed to clear critical section address from %#x for rseq: %v", csPtrAddr, err)
+		segv()
+		return
+	}
+
+	// Stage 6: only now decide whether a restart is actually required.
+	if critRange.Contains(usermem.Addr(t.Arch().IP())) {
+		t.Arch().SetIP(uintptr(cs.Abort))
+	}
+}
+
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) rseqInterrupt() {
- rscr := t.tg.rscr.Load().(*RSEQCriticalRegion)
- if ip := t.Arch().IP(); rscr.CriticalSection.Contains(usermem.Addr(ip)) {
- t.Debugf("Interrupted RSEQ critical section at %#x; restarting at %#x", ip, rscr.Restart)
- t.Arch().SetIP(uintptr(rscr.Restart))
- t.Arch().SetRSEQInterruptedIP(ip)
+func (t *Task) oldRSeqInterrupt() {
+ r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+ if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) {
+ t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart)
+ t.Arch().SetIP(uintptr(r.Restart))
+ t.Arch().SetOldRSeqInterruptedIP(ip)
}
}
+
+// rseqInterrupt runs restartable sequence interrupt handling for both rseq
+// ABIs: rseqAddrInterrupt is called first, followed by oldRSeqInterrupt.
+//
+// Preconditions: The caller must be running on the task goroutine.
+// NOTE(review): rseqAddrInterrupt additionally documents that t's
+// AddressSpace must be active; confirm that callers guarantee this.
+func (t *Task) rseqInterrupt() {
+ t.rseqAddrInterrupt()
+ t.oldRSeqInterrupt()
+}
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 5bd610f68..19034a21e 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -71,9 +71,20 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
// shms maps segment ids to segments.
+ //
+ // shms holds all referenced segments, which are removed on the last
+ // DecRef. Thus, it cannot itself hold a reference on the Shm.
+ //
+ // Since removal only occurs after the last (unlocked) DecRef, there
+ // exists a short window during which a Shm still exists in shms, but is
+ // unreferenced. Users must use TryIncRef to determine if the Shm is
+ // still valid.
shms map[ID]*Shm
// keysToShms maps segment keys to segments.
+ //
+ // Shms in keysToShms are guaranteed to be referenced, as they are
+ // removed by disassociateKey before the last DecRef.
keysToShms map[Key]*Shm
// Sum of the sizes of all existing segments rounded up to page size, in
@@ -95,10 +106,18 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
}
// FindByID looks up a segment given an ID.
+//
+// FindByID returns a reference on Shm.
func (r *Registry) FindByID(id ID) *Shm {
r.mu.Lock()
defer r.mu.Unlock()
- return r.shms[id]
+ s := r.shms[id]
+ // Take a reference on s. If TryIncRef fails, s has reached the last
+ // DecRef, but hasn't quite been removed from r.shms yet.
+ if s != nil && s.TryIncRef() {
+ return s
+ }
+ return nil
}
// dissociateKey removes the association between a segment and its key,
@@ -119,6 +138,8 @@ func (r *Registry) dissociateKey(s *Shm) {
// FindOrCreate looks up or creates a segment in the registry. It's functionally
// analogous to open(2).
+//
+// FindOrCreate returns a reference on Shm.
func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
// "A new segment was to be created and size is less than SHMMIN or
@@ -166,6 +187,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
return nil, syserror.EEXIST
}
+ shm.IncRef()
return shm, nil
}
@@ -193,7 +215,14 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
// Need to create a new segment.
creator := fs.FileOwnerFromContext(ctx)
perms := fs.FilePermsFromMode(mode)
- return r.newShm(ctx, pid, key, creator, perms, size)
+ s, err := r.newShm(ctx, pid, key, creator, perms, size)
+ if err != nil {
+ return nil, err
+ }
+ // The initial reference is held by s itself. Take another to return to
+ // the caller.
+ s.IncRef()
+ return s, nil
}
// newShm creates a new segment in the registry.
@@ -296,22 +325,26 @@ func (r *Registry) remove(s *Shm) {
// Shm represents a single shared memory segment.
//
-// Shm segment are backed directly by an allocation from platform
-// memory. Segments are always mapped as a whole, greatly simplifying how
-// mappings are tracked. However note that mremap and munmap calls may cause the
-// vma for a segment to become fragmented; which requires special care when
-// unmapping a segment. See mm/shm.go.
+// Shm segments are backed directly by an allocation from platform memory.
+// Segments are always mapped as a whole, greatly simplifying how mappings are
+// tracked. However note that mremap and munmap calls may cause the vma for a
+// segment to become fragmented; which requires special care when unmapping a
+// segment. See mm/shm.go.
//
// Segments persist until they are explicitly marked for destruction via
-// shmctl(SHM_RMID).
+// MarkDestroyed().
//
// Shm implements memmap.Mappable and memmap.MappingIdentity.
//
// +stateify savable
type Shm struct {
- // AtomicRefCount tracks the number of references to this segment from
- // maps. A segment always holds a reference to itself, until it's marked for
+ // AtomicRefCount tracks the number of references to this segment.
+ //
+ // A segment holds a reference to itself until it is marked for
// destruction.
+ //
+ // In addition to direct users, the MemoryManager will hold references
+ // via MappingIdentity.
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
@@ -484,9 +517,8 @@ type AttachOpts struct {
// ConfigureAttach creates an mmap configuration for the segment with the
// requested attach options.
//
-// ConfigureAttach returns with a ref on s on success. The caller should drop
-// this once the map is installed. This reference prevents s from being
-// destroyed before the returned configuration is used.
+// Postconditions: The returned MMapOpts are valid only as long as a reference
+// continues to be held on s.
func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -504,7 +536,6 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac
// in the user namespace that governs its IPC namespace." - man shmat(2)
return memmap.MMapOpts{}, syserror.EACCES
}
- s.IncRef()
return memmap.MMapOpts{
Length: s.size,
Offset: 0,
@@ -549,10 +580,15 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
}
creds := auth.CredentialsFromContext(ctx)
- nattach := uint64(s.ReadRefs())
- // Don't report the self-reference we keep prior to being marked for
- // destruction. However, also don't report a count of -1 for segments marked
- // as destroyed, with no mappings.
+ // Use the reference count as a rudimentary count of the number of
+ // attaches. We exclude:
+ //
+ // 1. The reference the caller holds.
+ // 2. The self-reference held by s prior to destruction.
+ //
+ // Note that this may still overcount by including transient references
+ // used in concurrent calls.
+ nattach := uint64(s.ReadRefs()) - 1
if !s.pendingDestruction {
nattach--
}
@@ -620,18 +656,17 @@ func (s *Shm) MarkDestroyed() {
s.registry.dissociateKey(s)
s.mu.Lock()
- // Only drop the segment's self-reference once, when destruction is
- // requested. Otherwise, repeated calls to shmctl(IPC_RMID) would force a
- // segment to be destroyed prematurely, potentially with active maps to the
- // segment's address range. Remaining references are dropped when the
- // segment is detached or unmaped.
+ defer s.mu.Unlock()
if !s.pendingDestruction {
s.pendingDestruction = true
- s.mu.Unlock() // Must release s.mu before calling s.DecRef.
+ // Drop the self-reference so destruction occurs when all
+ // external references are gone.
+ //
+ // N.B. This cannot be the final DecRef, as the caller also
+ // holds a reference.
s.DecRef()
return
}
- s.mu.Unlock()
}
// checkOwnership verifies whether a segment may be accessed by ctx as an
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index ab0c6c4aa..d25a7903b 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -489,18 +489,43 @@ type Task struct {
// netns is protected by mu. netns is owned by the task goroutine.
netns bool
- // If rseqPreempted is true, before the next call to p.Switch(), interrupt
- // RSEQ critical regions as defined by tg.rseq and write the task
- // goroutine's CPU number to rseqCPUAddr. rseqCPU is the last CPU number
- // written to rseqCPUAddr.
+ // If rseqPreempted is true, before the next call to p.Switch(),
+ // interrupt rseq critical regions as defined by rseqAddr and
+ // tg.oldRSeqCritical and write the task goroutine's CPU number to
+ // rseqAddr/oldRSeqCPUAddr.
//
- // If rseqCPUAddr is 0, rseqCPU is -1.
+ // We support two ABIs for restartable sequences:
//
- // rseqCPUAddr, rseqCPU, and rseqPreempted are exclusive to the task
- // goroutine.
+ // 1. The upstream interface added in v4.18,
+ // 2. An "old" interface never merged upstream. In the implementation,
+ // this is referred to as "old rseq".
+ //
+ // rseqPreempted is exclusive to the task goroutine.
rseqPreempted bool `state:"nosave"`
- rseqCPUAddr usermem.Addr
- rseqCPU int32
+
+ // rseqCPU is the last CPU number written to rseqAddr/oldRSeqCPUAddr.
+ //
+ // If rseq is unused, rseqCPU is -1 for convenient use in
+ // platform.Context.Switch.
+ //
+ // rseqCPU is exclusive to the task goroutine.
+ rseqCPU int32
+
+ // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable.
+ //
+ // oldRSeqCPUAddr is exclusive to the task goroutine.
+ oldRSeqCPUAddr usermem.Addr
+
+ // rseqAddr is a pointer to the userspace linux.RSeq structure.
+ //
+ // rseqAddr is exclusive to the task goroutine.
+ rseqAddr usermem.Addr
+
+ // rseqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ //
+ // rseqSignature is exclusive to the task goroutine.
+ rseqSignature uint32
// copyScratchBuffer is a buffer available to CopyIn/CopyOut
// implementations that require an intermediate buffer to copy data
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 5f3589493..247bd4aba 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -236,7 +236,10 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else if opts.NewPIDNamespace {
pidns = pidns.NewChild(userns)
}
+
tg := t.tg
+ rseqAddr := usermem.Addr(0)
+ rseqSignature := uint32(0)
if opts.NewThreadGroup {
tg.mounts.IncRef()
sh := t.tg.signalHandlers
@@ -244,6 +247,8 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
sh = sh.Fork()
}
tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ rseqAddr = t.rseqAddr
+ rseqSignature = t.rseqSignature
}
cfg := &TaskConfig{
@@ -260,6 +265,8 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
UTSNamespace: utsns,
IPCNamespace: ipcns,
AbstractSocketNamespace: t.abstractSockets,
+ RSeqAddr: rseqAddr,
+ RSeqSignature: rseqSignature,
ContainerID: t.ContainerID(),
}
if opts.NewThreadGroup {
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 90a6190f1..fa6528386 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -190,9 +190,11 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.updateRSSLocked()
// Restartable sequence state is discarded.
t.rseqPreempted = false
- t.rseqCPUAddr = 0
t.rseqCPU = -1
- t.tg.rscr.Store(&RSEQCriticalRegion{})
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+ t.oldRSeqCPUAddr = 0
+ t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
// Remove FDs with the CloseOnExec flag set.
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d97f8c189..6357273d3 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -169,12 +169,22 @@ func (*runApp) execute(t *Task) taskRunState {
// Apply restartable sequences.
if t.rseqPreempted {
t.rseqPreempted = false
- if t.rseqCPUAddr != 0 {
+ if t.rseqAddr != 0 || t.oldRSeqCPUAddr != 0 {
+ // Linux writes the CPU on every preemption. We only do
+ // so if it changed. Thus we may delay delivery of
+ // SIGSEGV if rseqAddr/oldRSeqCPUAddr is invalid.
cpu := int32(hostcpu.GetCPU())
if t.rseqCPU != cpu {
t.rseqCPU = cpu
if err := t.rseqCopyOutCPU(); err != nil {
- t.Warningf("Failed to copy CPU to %#x for RSEQ: %v", t.rseqCPUAddr, err)
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ if err := t.oldRSeqCopyOutCPU(); err != nil {
+ t.Debugf("Failed to copy CPU to %#x for old rseq: %v", t.oldRSeqCPUAddr, err)
t.forceSignal(linux.SIGSEGV, false)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
// Re-enter the task run loop for signal delivery.
@@ -320,7 +330,7 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextCPUPreempted:
- // Ensure that RSEQ critical sections are interrupted and per-thread
+ // Ensure that rseq critical sections are interrupted and per-thread
// CPU values are updated before the next platform.Context.Switch().
t.rseqPreempted = true
return (*runApp)(nil)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 3522a4ae5..58af16ee2 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -79,6 +80,13 @@ type TaskConfig struct {
// AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
AbstractSocketNamespace *AbstractSocketNamespace
+ // RSeqAddr is a pointer to the userspace linux.RSeq structure.
+ RSeqAddr usermem.Addr
+
+ // RSeqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ RSeqSignature uint32
+
// ContainerID is the container the new task belongs to.
ContainerID string
}
@@ -126,6 +134,8 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
futexWaiter: futex.NewWaiter(),
containerID: cfg.ContainerID,
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 0cded73f6..c0197a563 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -238,8 +238,8 @@ type ThreadGroup struct {
// execed is protected by the TaskSet mutex.
execed bool
- // rscr is the thread group's RSEQ critical region.
- rscr atomic.Value `state:".(*RSEQCriticalRegion)"`
+ // oldRSeqCritical is the thread group's old rseq critical region.
+ oldRSeqCritical atomic.Value `state:".(*OldRSeqCriticalRegion)"`
// mounts is the thread group's mount namespace. This does not really
// correspond to a "mount namespace" in Linux, but is more like a
@@ -273,18 +273,18 @@ func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, s
}
tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
tg.timers = make(map[linux.TimerID]*IntervalTimer)
- tg.rscr.Store(&RSEQCriticalRegion{})
+ tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
return tg
}
-// saveRscr is invoked by stateify.
-func (tg *ThreadGroup) saveRscr() *RSEQCriticalRegion {
- return tg.rscr.Load().(*RSEQCriticalRegion)
+// saveOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) saveOldRSeqCritical() *OldRSeqCriticalRegion {
+ return tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
}
-// loadRscr is invoked by stateify.
-func (tg *ThreadGroup) loadRscr(rscr *RSEQCriticalRegion) {
- tg.rscr.Store(rscr)
+// loadOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) loadOldRSeqCritical(r *OldRSeqCriticalRegion) {
+ tg.oldRSeqCritical.Store(r)
}
// SignalHandlers returns the signal handlers used by tg.
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 8c2246bb4..79610acb7 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -66,8 +66,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
}
@@ -81,7 +79,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallMapsEntry)
}
}
@@ -97,8 +94,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaMapsEntryLocked(ctx, vseg),
@@ -116,7 +111,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallMapsEntry),
@@ -187,15 +181,12 @@ func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffe
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
}
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallSmapsEntry)
}
}
@@ -211,8 +202,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaSmapsEntryLocked(ctx, vseg),
@@ -223,7 +212,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallSmapsEntry),
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 64c718d21..16f9c523e 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -64,6 +64,8 @@ begin:
CMPQ AX, $0
JL error
+ MOVQ $0, BX
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -73,23 +75,26 @@ begin:
MOVQ $SIGSTOP, SI
SYSCALL
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+ // The sentry sets BX to 1 when creating stub process.
+ CMPQ BX, $1
+ JE clone
+
+ // Notify the Sentry that syscall exited.
+done:
+ INT $3
+ // Be paranoid.
+ JMP done
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R15 has been updated with the expected PPID.
- JMP begin
+ CMPQ AX, $0
+ JE begin
+ // The clone syscall returns a non-zero value.
+ JMP done
error:
// Exit with -errno.
MOVQ AX, DI
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 2c5e4d5cb..6162df02a 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -59,6 +59,8 @@ begin:
CMP $0x0, R0
BLT error
+ MOVD $0, R9
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -66,22 +68,26 @@ begin:
MOVD $SYS_KILL, R8
MOVD $SIGSTOP, R1
SVC
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+
+ // The sentry sets R9 to 1 when creating stub process.
+ CMP $1, R9
+ BEQ clone
+
+done:
+ // Notify the Sentry that syscall exited.
+ BRK $3
+ B done // Be paranoid.
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R7 has been updated with the expected PPID.
- B begin
+ CMP $0, R0
+ BEQ begin
+
+ // The clone system call returned a non-zero value.
+ B done
error:
// Exit with -errno.
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 821f6848d..20244fd95 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -430,13 +430,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
for {
- // Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ // Execute the syscall instruction. The task has to stop on the
+ // trap instruction which is right after the syscall
+ // instruction.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_CONT, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
+ if sig == syscall.SIGTRAP {
// Reached syscall-enter-stop.
break
} else {
@@ -448,18 +450,6 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
}
- // Complete the actual system call.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens
- // between syscall-enter-stop and syscall-exit-stop; it happens *after*
- // syscall-exit-stop.)" - ptrace(2), "Syscall-stops"
- if sig := t.wait(stopped); sig != (syscallEvent | syscall.SIGTRAP) {
- t.dumpAndPanic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig))
- }
-
// Grab registers.
if err := t.getRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace get regs failed: %v", err))
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 606dc2b1d..e99798c56 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -141,9 +141,11 @@ func (t *thread) adjustInitRegsRip() {
t.initRegs.Rip -= initRegsRipAdjustment
}
-// Pass the expected PPID to the child via R15 when creating stub process
+// Pass the expected PPID to the child via R15 when creating stub process.
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.R15 = uint64(ppid)
+ // Rbx has to be set to 1 when creating stub process.
+ initregs.Rbx = 1
}
// patchSignalInfo patches the signal info to account for hitting the seccomp
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 62a686ee7..7b975137f 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -127,6 +127,8 @@ func (t *thread) adjustInitRegsRip() {
// Pass the expected PPID to the child via X7 when creating stub process
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.Regs[7] = uint64(ppid)
+ // R9 has to be set to 1 when creating stub process.
+ initregs.Regs[9] = 1
}
// patchSignalInfo patches the signal info to account for hitting the seccomp
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 813ef9822..64e9c0845 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -357,6 +357,73 @@ TEXT ·Current(SB),NOSPLIT,$0-8
#define STACK_FRAME_SIZE 16
TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
+ // Step1, save sentry context into memory.
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+
+ WORD $0xd5384003 // MRS SPSR_EL1, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
+ MOVD R30, CPU_REGISTERS+PTRACE_PC(RSV_REG)
+ MOVD RSP, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_SP(RSV_REG)
+
+ MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3
+
+ // Step2, save SP_EL1, PSTATE into kernel temporary stack.
+ // switch to temporary stack.
+ LOAD_KERNEL_STACK(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12
+ STP (R11, R12), 16*0(RSP)
+
+ MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12
+
+ // Step3, test user pagetable.
+ // If user pagetable is empty, trapped in el1_ia.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ // If pagetable is not empty, recover the kernel temporary stack.
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
+ // Step4, load app context pointer.
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP
+
+ // Step5, prepare the environment for container application.
+ // set sp_el0.
+ MOVD PTRACE_SP(RSV_REG_APP), R1
+ WORD $0xd5184101 //MSR R1, SP_EL0
+ // set pc.
+ MOVD PTRACE_PC(RSV_REG_APP), R1
+ MSR R1, ELR_EL1
+ // set pstate.
+ MOVD PTRACE_PSTATE(RSV_REG_APP), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ // RSV_REG & RSV_REG_APP will be loaded at the end.
+ REGISTERS_LOAD(RSV_REG_APP, 0)
+
+ // switch to user pagetable.
+ MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
+ MOVD PTRACE_R9(RSV_REG_APP), RSV_REG_APP
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
ERET()
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index af1a4e95f..4301b697c 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -471,6 +471,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IP:
switch h.Type {
case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
i += AlignUp(length, width)
@@ -481,6 +484,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IPV6:
switch h.Type {
case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += AlignUp(length, width)
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 79589e3c8..136821963 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -22,7 +22,6 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 4a1b87a9a..d2e3644a6 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -500,29 +499,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
trunc := flags&linux.MSG_TRUNC != 0
r := unix.EndpointReader{
+ Ctx: t,
Endpoint: s.ep,
Peek: flags&linux.MSG_PEEK != 0,
}
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
// If MSG_TRUNC is set with a zero byte destination then we still need
// to read the message and discard it, or in the case where MSG_PEEK is
// set, leave it be. In both cases the full message length must be
- // returned. However, the memory manager for the destination will not read
- // the endpoint if the destination is zero length.
- //
- // In order for the endpoint to be read when the destination size is zero,
- // we must cause a read of the endpoint by using a separate fake zero
- // length block sequence and calling the EndpointReader directly.
+ // returned.
if trunc && dst.Addrs.NumBytes() == 0 {
- // Perform a read to a zero byte block sequence. We can ignore the
- // original destination since it was zero bytes. The length returned by
- // ReadToBlocks is ignored and we return the full message length to comply
- // with MSG_TRUNC.
- _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
- return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
}
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -540,7 +539,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 140851c17..764f11a6b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -222,17 +222,25 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOptBool.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -977,13 +985,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
+ if v == 0 {
return []byte{}, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1213,12 +1231,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.V6OnlyOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ var o uint32
+ if v {
+ o = 1
+ }
+ return int32(o), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1427,7 +1448,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
@@ -1621,7 +1655,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 2ec1a662d..2447f24ef 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -83,6 +83,19 @@ type EndpointReader struct {
ControlTrunc bool
}
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 529a7a7a9..37c7ac3c1 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,17 +175,25 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptBool sets a socket option for simple cases when a value has
+ // the bool type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the bool type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
// State returns the current state of the socket, as represented by Linux in
// procfs.
@@ -851,11 +859,19 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 885758054..91effe89a 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -544,8 +544,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if senderRequested {
r.From = &tcpip.FullAddress{}
}
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+
+ }
+
var total int64
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -580,7 +599,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index d46421199..aa1ac720c 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -10,7 +10,8 @@ go_library(
"capability.go",
"clone.go",
"futex.go",
- "linux64.go",
+ "linux64_amd64.go",
+ "linux64_arm64.go",
"open.go",
"poll.go",
"ptrace.go",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64_amd64.go
index e603f858f..1e823b685 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package strace
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
// linuxAMD64 provides a mapping of the Linux amd64 syscalls and their argument
// types for display / formatting.
var linuxAMD64 = SyscallMap{
@@ -365,3 +372,13 @@ var linuxAMD64 = SyscallMap{
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.AMD64,
+ syscalls: linuxAMD64,
+ },
+ )
+}
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
new file mode 100644
index 000000000..c3ac5248d
--- /dev/null
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -0,0 +1,323 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// linuxARM64 provides a mapping of the Linux arm64 syscalls and their argument
+// types for display / formatting.
+var linuxARM64 = SyscallMap{
+ 0: makeSyscallInfo("io_setup", Hex, Hex),
+ 1: makeSyscallInfo("io_destroy", Hex),
+ 2: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 3: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 4: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 5: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 6: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 7: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 8: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 9: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 10: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 11: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 12: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 13: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 14: makeSyscallInfo("removexattr", Path, Path),
+ 15: makeSyscallInfo("lremovexattr", Path, Path),
+ 16: makeSyscallInfo("fremovexattr", FD, Path),
+ 17: makeSyscallInfo("getcwd", PostPath, Hex),
+ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 19: makeSyscallInfo("eventfd2", Hex, Hex),
+ 20: makeSyscallInfo("epoll_create1", Hex),
+ 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 23: makeSyscallInfo("dup", FD),
+ 24: makeSyscallInfo("dup3", FD, FD, Hex),
+ 25: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 26: makeSyscallInfo("inotify_init1", Hex),
+ 27: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 28: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 29: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 30: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 31: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 32: makeSyscallInfo("flock", FD, Hex),
+ 33: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 34: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 35: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 36: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 37: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 38: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 39: makeSyscallInfo("umount2", Path, Hex),
+ 40: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 41: makeSyscallInfo("pivot_root", Path, Path),
+ 42: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ 43: makeSyscallInfo("statfs", Path, Hex),
+ 44: makeSyscallInfo("fstatfs", FD, Hex),
+ 45: makeSyscallInfo("truncate", Path, Hex),
+ 46: makeSyscallInfo("ftruncate", FD, Hex),
+ 47: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 48: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 49: makeSyscallInfo("chdir", Path),
+ 50: makeSyscallInfo("fchdir", FD),
+ 51: makeSyscallInfo("chroot", Path),
+ 52: makeSyscallInfo("fchmod", FD, Mode),
+ 53: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 54: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 55: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 56: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 57: makeSyscallInfo("close", FD),
+ 58: makeSyscallInfo("vhangup"),
+ 59: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 60: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 61: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 62: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 63: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 64: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 65: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 66: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 67: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 68: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 69: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 70: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 71: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 72: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 73: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 74: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 75: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 76: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 77: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 78: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 79: makeSyscallInfo("fstatat", FD, Path, Stat, Hex),
+ 80: makeSyscallInfo("fstat", FD, Stat),
+ 81: makeSyscallInfo("sync"),
+ 82: makeSyscallInfo("fsync", FD),
+ 83: makeSyscallInfo("fdatasync", FD),
+ 84: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 85: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 86: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 87: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 88: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 89: makeSyscallInfo("acct", Hex),
+ 90: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 91: makeSyscallInfo("capset", CapHeader, CapData),
+ 92: makeSyscallInfo("personality", Hex),
+ 93: makeSyscallInfo("exit", Hex),
+ 94: makeSyscallInfo("exit_group", Hex),
+ 95: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 96: makeSyscallInfo("set_tid_address", Hex),
+ 97: makeSyscallInfo("unshare", CloneFlags),
+ 98: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 99: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 100: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 101: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 102: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 103: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 104: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 105: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 106: makeSyscallInfo("delete_module", Hex, Hex),
+ 107: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 108: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 109: makeSyscallInfo("timer_getoverrun", Hex),
+ 110: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 111: makeSyscallInfo("timer_delete", Hex),
+ 112: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 113: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 114: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 115: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 116: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 117: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 118: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 119: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 120: makeSyscallInfo("sched_getscheduler", Hex),
+ 121: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 122: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 123: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 124: makeSyscallInfo("sched_yield"),
+ 125: makeSyscallInfo("sched_get_priority_max", Hex),
+ 126: makeSyscallInfo("sched_get_priority_min", Hex),
+ 127: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 128: makeSyscallInfo("restart_syscall"),
+ 129: makeSyscallInfo("kill", Hex, Signal),
+ 130: makeSyscallInfo("tkill", Hex, Signal),
+ 131: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 132: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 133: makeSyscallInfo("rt_sigsuspend", Hex),
+ 134: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction),
+ 135: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 136: makeSyscallInfo("rt_sigpending", Hex),
+ 137: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 138: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 139: makeSyscallInfo("rt_sigreturn"),
+ 140: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 141: makeSyscallInfo("getpriority", Hex, Hex),
+ 142: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 143: makeSyscallInfo("setregid", Hex, Hex),
+ 144: makeSyscallInfo("setgid", Hex),
+ 145: makeSyscallInfo("setreuid", Hex, Hex),
+ 146: makeSyscallInfo("setuid", Hex),
+ 147: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 148: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 149: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 150: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 151: makeSyscallInfo("setfsuid", Hex),
+ 152: makeSyscallInfo("setfsgid", Hex),
+ 153: makeSyscallInfo("times", Hex),
+ 154: makeSyscallInfo("setpgid", Hex, Hex),
+ 155: makeSyscallInfo("getpgid", Hex),
+ 156: makeSyscallInfo("getsid", Hex),
+ 157: makeSyscallInfo("setsid"),
+ 158: makeSyscallInfo("getgroups", Hex, Hex),
+ 159: makeSyscallInfo("setgroups", Hex, Hex),
+ 160: makeSyscallInfo("uname", Uname),
+ 161: makeSyscallInfo("sethostname", Hex, Hex),
+ 162: makeSyscallInfo("setdomainname", Hex, Hex),
+ 163: makeSyscallInfo("getrlimit", Hex, Hex),
+ 164: makeSyscallInfo("setrlimit", Hex, Hex),
+ 165: makeSyscallInfo("getrusage", Hex, Rusage),
+ 166: makeSyscallInfo("umask", Hex),
+ 167: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 168: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 169: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 170: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 171: makeSyscallInfo("adjtimex", Hex),
+ 172: makeSyscallInfo("getpid"),
+ 173: makeSyscallInfo("getppid"),
+ 174: makeSyscallInfo("getuid"),
+ 175: makeSyscallInfo("geteuid"),
+ 176: makeSyscallInfo("getgid"),
+ 177: makeSyscallInfo("getegid"),
+ 178: makeSyscallInfo("gettid"),
+ 179: makeSyscallInfo("sysinfo", Hex),
+ 180: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 181: makeSyscallInfo("mq_unlink", Hex),
+ 182: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 183: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 184: makeSyscallInfo("mq_notify", Hex, Hex),
+ 185: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 186: makeSyscallInfo("msgget", Hex, Hex),
+ 187: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 188: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 189: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 190: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 191: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 192: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 193: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 194: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 195: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 196: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 197: makeSyscallInfo("shmdt", Hex),
+ 198: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 199: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 200: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 201: makeSyscallInfo("listen", FD, Hex),
+ 202: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 203: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 204: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 205: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 206: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 207: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 208: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
+ 209: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 210: makeSyscallInfo("shutdown", FD, Hex),
+ 211: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 212: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 213: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 214: makeSyscallInfo("brk", Hex),
+ 215: makeSyscallInfo("munmap", Hex, Hex),
+ 216: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 218: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 224: makeSyscallInfo("swapon", Hex, Hex),
+ 225: makeSyscallInfo("swapoff", Hex),
+ 226: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 227: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 228: makeSyscallInfo("mlock", Hex, Hex),
+ 229: makeSyscallInfo("munlock", Hex, Hex),
+ 230: makeSyscallInfo("mlockall", Hex),
+ 231: makeSyscallInfo("munlockall"),
+ 232: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 233: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 234: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 235: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 236: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 237: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 238: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 239: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 241: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 242: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 243: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+
+ 260: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 261: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 262: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 263: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 264: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 265: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 266: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 267: makeSyscallInfo("syncfs", FD),
+ 268: makeSyscallInfo("setns", FD, Hex),
+ 269: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 270: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 271: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 272: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 273: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 274: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 275: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 276: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 277: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 278: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 279: makeSyscallInfo("memfd_create", Path, Hex),
+ 280: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 281: makeSyscallInfo("execveat", FD, Path, Hex, Hex, Hex),
+ 282: makeSyscallInfo("userfaultfd", Hex),
+ 283: makeSyscallInfo("membarrier", Hex),
+ 284: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 285: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 287: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 291: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 292: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 293: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.ARM64,
+ syscalls: linuxARM64})
+}
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index e5d486c4e..24e29a2ba 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -250,14 +250,7 @@ type syscallTable struct {
syscalls SyscallMap
}
-// syscallTables contains all syscall tables.
-var syscallTables = []syscallTable{
- {
- os: abi.Linux,
- arch: arch.AMD64,
- syscalls: linuxAMD64,
- },
-}
+var syscallTables []syscallTable
// Lookup returns the SyscallMap for the OS/Arch combination. The returned map
// must not be changed.
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 6766ba587..a76975cee 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -30,6 +30,7 @@ go_library(
"sys_random.go",
"sys_read.go",
"sys_rlimit.go",
+ "sys_rseq.go",
"sys_rusage.go",
"sys_sched.go",
"sys_seccomp.go",
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index 272ae9991..479c5f6ff 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -377,7 +377,7 @@ var AMD64 = &kernel.SyscallTable{
331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
332: syscalls.Supported("statx", Statx),
333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
- 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+ 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index 3b584eed9..d3f61f5e8 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -307,7 +307,7 @@ var ARM64 = &kernel.SyscallTable{
290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
291: syscalls.Supported("statx", Statx),
292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
- 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+ 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
new file mode 100644
index 000000000..90db10ea6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// RSeq implements syscall rseq(2).
+func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint()
+ flags := args[2].Int()
+ signature := args[3].Uint()
+
+ if !t.RSeqAvailable() {
+ // Event for applications that want rseq on a configuration
+ // that doesn't support them.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ switch flags {
+ case 0:
+ // Register.
+ return 0, nil, t.SetRSeq(addr, length, signature)
+ case linux.RSEQ_FLAG_UNREGISTER:
+ return 0, nil, t.ClearRSeq(addr, length, signature)
+ default:
+ // Unknown flag.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index d57ffb3a1..4a8bc24a2 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -39,10 +39,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
+ defer segment.DecRef()
return uintptr(segment.ID), nil, nil
}
// findSegment retrieves a shm segment by the given id.
+//
+// findSegment returns a reference on Shm.
func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
r := t.IPCNamespace().ShmRegistry()
segment := r.FindByID(id)
@@ -63,6 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
@@ -72,7 +76,6 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- defer segment.DecRef()
addr, err = t.MemoryManager().MMap(t, opts)
return uintptr(addr), nil, err
}
@@ -105,6 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
stat, err := segment.IPCStat(t)
if err == nil {
@@ -128,6 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
switch cmd {
case linux.IPC_SET:
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 65d4d0cd8..e07ebd153 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -10,6 +10,7 @@ go_library(
"packet_buffer_state.go",
"tcpip.go",
"time_unsafe.go",
+ "timer.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
@@ -26,3 +27,10 @@ go_test(
srcs = ["tcpip_test.go"],
embed = [":tcpip"],
)
+
+go_test(
+ name = "timer_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ deps = [":tcpip"],
+)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index f1d837196..f2061c778 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -44,6 +44,7 @@ go_test(
],
deps = [
":header",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index fc671e439..135a60b12 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -15,6 +15,7 @@
package header
import (
+ "crypto/sha256"
"encoding/binary"
"strings"
@@ -102,6 +103,11 @@ const (
// bytes including and after the IIDOffsetInIPv6Address-th byte are
// for the IID.
IIDOffsetInIPv6Address = 8
+
+ // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
+ // for the secret key used to generate an opaque interface identifier as
+ // outlined by RFC 7217.
+ OpaqueIIDSecretKeyMinBytes = 16
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -326,3 +332,42 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+
+// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
+// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
+//
+// The opaque IID is generated from the cryptographic hash of the concatenation
+// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret
+// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
+// MUST be initialized to a pseudo-random number. See RFC 4086 for randomness
+// requirements for security.
+//
+// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
+// array for the buffer will not be allocated.
+func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
+ // As per RFC 7217 section 5, the opaque identifier can be generated as a
+ // cryptographic hash of the concatenation of each of the function parameters.
+ // Note, we omit the optional Network_ID field.
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address]))
+ h.Write([]byte(nicName))
+ h.Write([]byte{dadCounter})
+ h.Write(secretKey)
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ return append(buf, sum[:IIDSize]...)
+}
+
+// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
+// an opaque IID.
+func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))
+}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 42c5c6fc1..1994003ed 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -15,9 +15,12 @@
package header_test
import (
+ "bytes"
+ "crypto/sha256"
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -43,3 +46,163 @@ func TestLinkLocalAddr(t *testing.T) {
t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
}
}
+
+func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ prefix: header.IPv6LinkLocalPrefix.Subnet(),
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ h := sha256.New()
+ h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
+ h.Write([]byte(test.nicName))
+ h.Write([]byte{test.dadCounter})
+ if k := test.secretKey; k != nil {
+ h.Write(k)
+ }
+ var hashSum [sha256.Size]byte
+ h.Sum(hashSum[:0])
+ want := hashSum[:header.IIDSize]
+
+ // Passing a nil buffer should result in a new buffer returned with the
+ // IID.
+ if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+
+ // Passing a buffer with sufficient capacity for the IID should populate
+ // the buffer provided.
+ var iidBuf [header.IIDSize]byte
+ if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ if got := iidBuf[:]; !bytes.Equal(got, want) {
+ t.Errorf("got iidBuf = %x, want = %x", got, want)
+ }
+ })
+ }
+}
+
+func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ prefix := header.IPv6LinkLocalPrefix.Subnet()
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ addrBytes := [header.IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ test.nicName,
+ test.dadCounter,
+ test.secretKey,
+ ))
+
+ if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
+ t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index da8482509..42cacb8a6 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4144a7837..f1bc33adf 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e645cf62c..4ee3d5b45 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
@@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
@@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return len(pkts), nil
}
@@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
ip := header.IPv4(pkt.Data.First())
@@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
e.HandlePacket(r, pkt.Clone())
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e13f1fabf..58c3c79b9 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
@@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
@@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return len(pkts), nil
}
@@ -161,7 +161,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
// TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 69077669a..826fca4de 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -52,23 +52,21 @@ go_test(
name = "stack_x_test",
size = "small",
srcs = [
- "ndp_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
],
deps = [
":stack",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go-cmp//cmp:go_default_library",
@@ -85,3 +83,23 @@ go_test(
"//pkg/tcpip",
],
)
+
+go_test(
+ name = "ndp_test",
+ size = "small",
+ srcs = ["ndp_test.go"],
+ deps = [
+ ":stack",
+ "//pkg/rand",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index d9ab59336..a9dd322db 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -115,6 +115,30 @@ var (
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
)
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // will return all available configuration information.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
+)
+
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
@@ -169,6 +193,15 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ // OnAutoGenAddressDeprecated will be called when an auto-generated
+ // address (as part of SLAAC) has been deprecated, but is still
+ // considered valid. Note, if an address is invalidated at the same
+ // time it is deprecated, the deprecation event MAY be omitted.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
// OnAutoGenAddressInvalidated will be called when an auto-generated
// address (as part of SLAAC) has been invalidated.
//
@@ -185,7 +218,20 @@ type NDPDispatcher interface {
// already known DNS servers. If called with known DNS servers, their
// valid lifetimes must be refreshed to lifetime (it may be increased,
// decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDHCPv6Configuration will be called with an updated configuration that is
+ // available via DHCPv6 for a specified NIC.
+ //
+ // NDPDispatcher assumes that the initial configuration available by DHCPv6 is
+ // DHCPv6NoConfiguration.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
// NDPConfigurations is the NDP configurations for the netstack.
@@ -272,6 +318,9 @@ type ndpState struct {
// The addresses generated by SLAAC.
autoGenAddresses map[tcpip.Address]autoGenAddressState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
}
// dadState holds the Duplicate Address Detection timer and channel to signal
@@ -290,71 +339,27 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- invalidationTimer *time.Timer
-
- // Used to inform the timer not to invalidate the default router (R) in
- // a race condition (T1 is a goroutine that handles an RA from R and T2
- // is the goroutine that handles R's invalidation timer firing):
- // T1: Receive a new RA from R
- // T1: Obtain the NIC's lock before processing the RA
- // T2: R's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends R's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates R immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate R, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
+ invalidationTimer tcpip.CancellableTimer
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations were configured to do so.
type onLinkPrefixState struct {
- invalidationTimer *time.Timer
-
- // Used to signal the timer not to invalidate the on-link prefix (P) in
- // a race condition (T1 is a goroutine that handles a PI for P and T2
- // is the goroutine that handles P's invalidation timer firing):
- // T1: Receive a new PI for P
- // T1: Obtain the NIC's lock before processing the PI
- // T2: P's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends P's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates P immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate P, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
+ invalidationTimer tcpip.CancellableTimer
}
// autoGenAddressState holds data associated with an address generated via
// SLAAC.
type autoGenAddressState struct {
- invalidationTimer *time.Timer
-
- // Used to signal the timer not to invalidate the SLAAC address (A) in
- // a race condition (T1 is a goroutine that handles a PI for A and T2
- // is the goroutine that handles A's invalidation timer firing):
- // T1: Receive a new PI for A
- // T1: Obtain the NIC's lock before processing the PI
- // T2: A's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends A's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates A immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate A, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
-
- // Nonzero only when the address is not valid forever (invalidationTimer
- // is not nil).
+ // A reference to the referencedNetworkEndpoint that this autoGenAddressState
+ // is holding state for.
+ ref *referencedNetworkEndpoint
+
+ deprecationTimer tcpip.CancellableTimer
+ invalidationTimer tcpip.CancellableTimer
+
+ // Nonzero only when the address is not valid forever.
validUntil time.Time
}
@@ -556,7 +561,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
// handleRA handles a Router Advertisement message that arrived on the NIC
// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Is the NIC configured to handle RAs at all?
//
@@ -568,6 +573,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
return
}
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ }
+ }
+
// Is the NIC configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
rtr, ok := ndp.defaultRouters[ip]
@@ -585,27 +612,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
// the invalidation timer.
- timer := rtr.invalidationTimer
-
- // We should ALWAYS have an invalidation timer for a
- // discovered router.
- if timer == nil {
- panic("ndphandlera: RA invalidation timer should not be nil")
- }
-
- if !timer.Stop() {
- // If we reach this point, then we know the
- // timer fired after we already took the NIC
- // lock. Inform the timer not to invalidate the
- // router when it obtains the lock as we just
- // got a new RA that refreshes its lifetime to a
- // non-zero value. See
- // defaultRouterState.doNotInvalidate for more
- // details.
- *rtr.doNotInvalidate = true
- }
-
- timer.Reset(rl)
+ rtr.invalidationTimer.StopLocked()
+ rtr.invalidationTimer.Reset(rl)
+ ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
// We know about the router but it is no longer to be
@@ -672,10 +681,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.Stop()
- rtr.invalidationTimer = nil
- *rtr.doNotInvalidate = true
- rtr.doNotInvalidate = nil
+ rtr.invalidationTimer.StopLocked()
delete(ndp.defaultRouters, ip)
@@ -704,27 +710,15 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
return
}
- // Used to signal the timer not to invalidate the default router (R) in
- // a race condition. See defaultRouterState.doNotInvalidate for more
- // details.
- var doNotInvalidate bool
-
- ndp.defaultRouters[ip] = defaultRouterState{
- invalidationTimer: time.AfterFunc(rl, func() {
- ndp.nic.stack.mu.Lock()
- defer ndp.nic.stack.mu.Unlock()
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if doNotInvalidate {
- doNotInvalidate = false
- return
- }
-
+ state := defaultRouterState{
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
- doNotInvalidate: &doNotInvalidate,
}
+
+ state.invalidationTimer.Reset(rl)
+
+ ndp.defaultRouters[ip] = state
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -746,21 +740,17 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
return
}
- // Used to signal the timer not to invalidate the on-link prefix (P) in
- // a race condition. See onLinkPrefixState.doNotInvalidate for more
- // details.
- var doNotInvalidate bool
- var timer *time.Timer
+ state := onLinkPrefixState{
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
+ }
- // Only create a timer if the lifetime is not infinite.
if l < header.NDPInfiniteLifetime {
- timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate)
+ state.invalidationTimer.Reset(l)
}
- ndp.onLinkPrefixes[prefix] = onLinkPrefixState{
- invalidationTimer: timer,
- doNotInvalidate: &doNotInvalidate,
- }
+ ndp.onLinkPrefixes[prefix] = state
}
// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
@@ -775,13 +765,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- if s.invalidationTimer != nil {
- s.invalidationTimer.Stop()
- s.invalidationTimer = nil
- *s.doNotInvalidate = true
- }
-
- s.doNotInvalidate = nil
+ s.invalidationTimer.StopLocked()
delete(ndp.onLinkPrefixes, prefix)
@@ -791,28 +775,6 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
}
}
-// prefixInvalidationCallback returns a new on-link prefix invalidation timer
-// for prefix that fires after vl.
-//
-// doNotInvalidate is used to signal the timer when it fires at the same time
-// that a prefix's valid lifetime gets refreshed. See
-// onLinkPrefixState.doNotInvalidate for more details.
-func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer {
- return time.AfterFunc(vl, func() {
- ndp.nic.stack.mu.Lock()
- defer ndp.nic.stack.mu.Unlock()
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if *doNotInvalidate {
- *doNotInvalidate = false
- return
- }
-
- ndp.invalidateOnLinkPrefix(prefix)
- })
-}
-
// handleOnLinkPrefixInformation handles a Prefix Information option with
// its on-link flag set, as per RFC 4861 section 6.3.4.
//
@@ -852,42 +814,17 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
+ //
// Update the invalidation timer.
- timer := prefixState.invalidationTimer
-
- if timer == nil && vl >= header.NDPInfiniteLifetime {
- // Had infinite valid lifetime before and
- // continues to have an invalid lifetime. Do
- // nothing further.
- return
- }
- if timer != nil && !timer.Stop() {
- // If we reach this point, then we know the timer alread fired
- // after we took the NIC lock. Inform the timer to not
- // invalidate the prefix once it obtains the lock as we just
- // got a new PI that refreshes its lifetime to a non-zero value.
- // See onLinkPrefixState.doNotInvalidate for more details.
- *prefixState.doNotInvalidate = true
- }
-
- if vl >= header.NDPInfiniteLifetime {
- // Prefix is now valid forever so we don't need
- // an invalidation timer.
- prefixState.invalidationTimer = nil
- ndp.onLinkPrefixes[prefix] = prefixState
- return
- }
+ prefixState.invalidationTimer.StopLocked()
- if timer != nil {
- // We already have a timer so just reset it to
- // expire after the new valid lifetime.
- timer.Reset(vl)
- return
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // the new valid lifetime.
+ prefixState.invalidationTimer.Reset(vl)
}
- // We do not have a timer so just create a new one.
- prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
ndp.onLinkPrefixes[prefix] = prefixState
}
@@ -897,7 +834,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the autonomous flag is set.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
vl := pi.ValidLifetime()
pl := pi.PreferredLifetime()
@@ -912,103 +849,30 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
// Check if we already have an auto-generated address for prefix.
- for _, ref := range ndp.nic.endpoints {
- if ref.protocol != header.IPv6ProtocolNumber {
- continue
- }
-
- if ref.configType != slaac {
- continue
- }
-
- addr := ref.ep.ID().LocalAddress
- refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()}
+ for addr, addrState := range ndp.autoGenAddresses {
+ refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()}
if refAddrWithPrefix.Subnet() != prefix {
continue
}
- //
- // At this point, we know we are refreshing a SLAAC generated
- // IPv6 address with the prefix, prefix. Do the work as outlined
- // by RFC 4862 section 5.5.3.e.
- //
-
- addrState, ok := ndp.autoGenAddresses[addr]
- if !ok {
- panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr))
- }
-
- // TODO(b/143713887): Handle deprecating auto-generated address
- // after the preferred lifetime.
-
- // As per RFC 4862 section 5.5.3.e, the valid lifetime of the
- // address generated by SLAAC is as follows:
- //
- // 1) If the received Valid Lifetime is greater than 2 hours or
- // greater than RemainingLifetime, set the valid lifetime of
- // the address to the advertised Valid Lifetime.
- //
- // 2) If RemainingLifetime is less than or equal to 2 hours,
- // ignore the advertised Valid Lifetime.
- //
- // 3) Otherwise, reset the valid lifetime of the address to 2
- // hours.
-
- // Handle the infinite valid lifetime separately as we do not
- // keep a timer in this case.
- if vl >= header.NDPInfiniteLifetime {
- if addrState.invalidationTimer != nil {
- // Valid lifetime was finite before, but now it
- // is valid forever.
- if !addrState.invalidationTimer.Stop() {
- *addrState.doNotInvalidate = true
- }
- addrState.invalidationTimer = nil
- addrState.validUntil = time.Time{}
- ndp.autoGenAddresses[addr] = addrState
- }
-
- return
- }
-
- var effectiveVl time.Duration
- var rl time.Duration
-
- // If the address was originally set to be valid forever,
- // assume the remaining time to be the maximum possible value.
- if addrState.invalidationTimer == nil {
- rl = header.NDPInfiniteLifetime
- } else {
- rl = time.Until(addrState.validUntil)
- }
-
- if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
- effectiveVl = vl
- } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
- ndp.autoGenAddresses[addr] = addrState
- return
- } else {
- effectiveVl = MinPrefixInformationValidLifetimeForUpdate
- }
-
- if addrState.invalidationTimer == nil {
- addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate)
- } else {
- if !addrState.invalidationTimer.Stop() {
- *addrState.doNotInvalidate = true
- }
- addrState.invalidationTimer.Reset(effectiveVl)
- }
-
- addrState.validUntil = time.Now().Add(effectiveVl)
- ndp.autoGenAddresses[addr] = addrState
+ // At this point, we know we are refreshing a SLAAC generated IPv6 address
+ // with the prefix prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.e.
+ ndp.refreshAutoGenAddressLifetimes(addr, pl, vl)
return
}
// We do not already have an address within the prefix, prefix. Do the
// work as outlined by RFC 4862 section 5.5.3.d if n is configured
// to auto-generated global addresses by SLAAC.
+ ndp.newAutoGenAddress(prefix, pl, vl)
+}
+// newAutoGenAddress generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) {
// Are we configured to auto-generate new global addresses?
if !ndp.configs.AutoGenGlobalAddresses {
return
@@ -1028,22 +892,24 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
return
}
- // Only attempt to generate an interface-specific IID if we have a valid
- // link address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- linkAddr := ndp.nic.linkEP.LinkAddress()
- if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return
- }
+ addrBytes := []byte(prefix.ID())
+ if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey)
+ } else {
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return
+ }
- // Generate an address within prefix from the modified EUI-64 of ndp's
- // NIC's Ethernet MAC address.
- addrBytes := make([]byte, header.IPv6AddressSize)
- copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address])
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ // Generate an address within prefix from the modified EUI-64 of ndp's NIC's
+ // Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ }
addr := tcpip.Address(addrBytes)
addrWithPrefix := tcpip.AddressWithPrefix{
Address: addr,
@@ -1065,29 +931,132 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
return
}
- if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{
+ protocolAddr := tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addrWithPrefix,
- }, FirstPrimaryEndpoint, permanent, slaac); err != nil {
- panic(err)
+ }
+ // If the preferred lifetime is zero, then the address should be considered
+ // deprecated.
+ deprecated := pl == 0
+ ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
+ if err != nil {
+ log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
+ }
+
+ state := autoGenAddressState{
+ ref: ref,
+ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)
+ }
+ addrState.ref.deprecated = true
+ ndp.notifyAutoGenAddressDeprecated(addr)
+ }),
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateAutoGenAddress(addr)
+ }),
}
- // Setup the timers to deprecate and invalidate this newly generated
+ // Setup the initial timers to deprecate and invalidate this newly generated
// address.
- // TODO(b/143713887): Handle deprecating auto-generated addresses
- // after the preferred lifetime.
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ state.deprecationTimer.Reset(pl)
+ }
- var doNotInvalidate bool
- var vTimer *time.Timer
if vl < header.NDPInfiniteLifetime {
- vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate)
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = time.Now().Add(vl)
}
- ndp.autoGenAddresses[addr] = autoGenAddressState{
- invalidationTimer: vTimer,
- doNotInvalidate: &doNotInvalidate,
- validUntil: time.Now().Add(vl),
+ ndp.autoGenAddresses[addr] = state
+}
+
+// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated
+// address addr.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The (possibly updated) state for addr is written back to
+// ndp.autoGenAddresses when this function returns.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) {
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr)
+ }
+ // addrState is a copy of the map entry; store it back on every return path.
+ defer func() { ndp.autoGenAddresses[addr] = addrState }()
+
+ // If the preferred lifetime is zero, then the address should be considered
+ // deprecated.
+ deprecated := pl == 0
+ wasDeprecated := addrState.ref.deprecated
+ addrState.ref.deprecated = deprecated
+
+ // Only send the deprecation event if the deprecated status for addr just
+ // changed from non-deprecated to deprecated.
+ if !wasDeprecated && deprecated {
+ ndp.notifyAutoGenAddressDeprecated(addr)
+ }
+
+ // If addr was preferred for some finite lifetime before, stop the deprecation
+ // timer so it can be reset.
+ addrState.deprecationTimer.StopLocked()
+
+ // Reset the deprecation timer if addr has a finite preferred lifetime.
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ addrState.deprecationTimer.Reset(pl)
+ }
+
+ // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address
+ // generated by SLAAC is updated as follows:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the address to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the address to 2 hours.
+
+ // Handle the infinite valid lifetime separately as we do not keep a timer in
+ // this case.
+ if vl >= header.NDPInfiniteLifetime {
+ addrState.invalidationTimer.StopLocked()
+ addrState.validUntil = time.Time{}
+ return
+ }
+
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the address was originally set to be valid forever (validUntil is the
+ // zero time), assume the remaining time to be the maximum possible value.
+ if addrState.validUntil.IsZero() {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(addrState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ return
+ } else {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ addrState.invalidationTimer.StopLocked()
+ addrState.invalidationTimer.Reset(effectiveVl)
+ addrState.validUntil = time.Now().Add(effectiveVl)
+}
+
+// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr
+// has been deprecated.
+//
+// Nothing is done if the stack has no NDP dispatcher. The address is reported
+// with a prefix length of validPrefixLenForAutoGen.
+func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ })
 }
}
@@ -1111,19 +1080,12 @@ func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
state, ok := ndp.autoGenAddresses[addr]
-
if !ok {
return false
}
- if state.invalidationTimer != nil {
- state.invalidationTimer.Stop()
- state.invalidationTimer = nil
- *state.doNotInvalidate = true
- }
-
- state.doNotInvalidate = nil
-
+ state.deprecationTimer.StopLocked()
+ state.invalidationTimer.StopLocked()
delete(ndp.autoGenAddresses, addr)
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
@@ -1136,26 +1098,6 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// autoGenAddrInvalidationTimer returns a new invalidation timer for an
-// auto-generated address that fires after vl.
-//
-// doNotInvalidate is used to inform the timer when it fires at the same time
-// that an auto-generated address's valid lifetime gets refreshed. See
-// autoGenAddrState.doNotInvalidate for more details.
-func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer {
- return time.AfterFunc(vl, func() {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if *doNotInvalidate {
- *doNotInvalidate = false
- return
- }
-
- ndp.invalidateAutoGenAddress(addr)
- })
-}
-
// cleanupHostOnlyState cleans up any state that is only useful for hosts.
//
// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 64a9a2b20..108762b6e 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack_test
+package ndp_test
import (
"encoding/binary"
@@ -21,6 +21,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -29,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
@@ -45,6 +48,10 @@ var (
llAddr1 = header.LinkLocalAddr(linkAddr1)
llAddr2 = header.LinkLocalAddr(linkAddr2)
llAddr3 = header.LinkLocalAddr(linkAddr3)
+ dstAddr = tcpip.FullAddress{
+ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ Port: 25,
+ }
)
func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
@@ -135,6 +142,7 @@ type ndpAutoGenAddrEventType int
const (
newAddr ndpAutoGenAddrEventType = iota
+ deprecatedAddr
invalidatedAddr
)
@@ -154,18 +162,24 @@ type ndpRDNSSEvent struct {
rdnss ndpRDNSS
}
+// ndpDHCPv6Event records a single OnDHCPv6Configuration callback: the NIC it
+// was reported for and the DHCPv6 configuration that was reported.
+type ndpDHCPv6Event struct {
+ nicID tcpip.NICID
+ configuration stack.DHCPv6ConfigurationFromNDPRA
+}
+
var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
type ndpDispatcher struct {
- dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
- prefixC chan ndpPrefixEvent
- rememberPrefix bool
- autoGenAddrC chan ndpAutoGenAddrEvent
- rdnssC chan ndpRDNSSEvent
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ prefixC chan ndpPrefixEvent
+ rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
+ dhcpv6ConfigurationC chan ndpDHCPv6Event
}
// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
@@ -239,6 +253,16 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
return true
}
+// Implements stack.NDPDispatcher.OnAutoGenAddressDeprecated.
+//
+// A deprecatedAddr event is sent on autoGenAddrC if that channel is non-nil.
+func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ deprecatedAddr,
+ }
+ }
+}
+
func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
if c := n.autoGenAddrC; c != nil {
c <- ndpAutoGenAddrEvent{
@@ -262,11 +286,23 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
+// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
+//
+// The event is sent on dhcpv6ConfigurationC if that channel is non-nil;
+// otherwise it is dropped.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ if c := n.dhcpv6ConfigurationC; c != nil {
+ c <- ndpDHCPv6Event{
+ nicID,
+ configuration,
+ }
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
func TestDADResolve(t *testing.T) {
+ t.Parallel()
+
tests := []struct {
name string
dupAddrDetectTransmits uint8
@@ -399,6 +435,8 @@ func TestDADResolve(t *testing.T) {
// a node doing DAD for the same address), or if another node is detected to own
// the address already (receive an NA message for the tentative address).
func TestDADFail(t *testing.T) {
+ t.Parallel()
+
tests := []struct {
name string
makeBuf func(tgt tcpip.Address) buffer.Prependable
@@ -542,6 +580,8 @@ func TestDADFail(t *testing.T) {
// TestDADStop tests to make sure that the DAD process stops when an address is
// removed.
func TestDADStop(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
@@ -614,6 +654,71 @@ func TestDADStop(t *testing.T) {
}
}
+// TestNICAutoGenAddrDoesDAD tests that an auto-generated IPv6 link-local
+// address is only assigned to the NIC after the DAD process resolves.
+func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Address should not be considered bound to the
+ // NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ // Note: e here is the DAD event and shadows the channel endpoint above.
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != linkLocalAddr {
+ // Report the address carried by the event, which is the value that
+ // was actually compared.
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, linkLocalAddr)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ // Now that DAD has resolved, the link-local address should be the NIC's
+ // main IPv6 address.
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+}
+
// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
// we attempt to update NDP configurations using an invalid NICID.
func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
@@ -631,6 +736,8 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
// configurations without affecting the default NDP configurations or other
// interfaces' configurations.
func TestSetNDPConfigurations(t *testing.T) {
+ t.Parallel()
+
tests := []struct {
name string
dupAddrDetectTransmits uint8
@@ -779,21 +886,32 @@ func TestSetNDPConfigurations(t *testing.T) {
}
}
-// raBufWithOpts returns a valid NDP Router Advertisement with options.
-//
-// Note, raBufWithOpts does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
+// and DHCPv6 configurations specified.
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- ra := header.NDPRouterAdvert(pkt.NDPPayload())
+ raPayload := pkt.NDPPayload()
+ ra := header.NDPRouterAdvert(raPayload)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(raPayload[2:], rl)
+ // Populate the Managed Address flag field.
+ if managedAddress {
+ // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
+ // of the RA payload.
+ raPayload[1] |= (1 << 7)
+ }
+ // Populate the Other Configurations flag field.
+ if otherConfigurations {
+ // The Other Configurations flag field is the 6th bit of byte #1
+ // (0-indexing) of the RA payload.
+ raPayload[1] |= (1 << 6)
+ }
opts := ra.Options()
opts.Serialize(optSer)
- // Populate the Router Lifetime.
- binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -808,6 +926,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
+// raBufWithOpts returns a valid NDP Router Advertisement with options.
+//
+// Note, raBufWithOpts does not populate any of the RA fields other than the
+// Router Lifetime. Both DHCPv6 related flags (Managed Address and Other
+// Configurations) are left unset.
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+}
+
+// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
+// fields set.
+//
+// managedAddress sets the Managed Address flag and otherConfigurations sets
+// the Other Configurations flag.
+//
+// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
+// DHCPv6 related ones: the Router Lifetime is 0 and no options are included.
+func raBufWithDHCPv6(ip tcpip.Address, managedAddress, otherConfigurations bool) tcpip.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, 0, managedAddress, otherConfigurations, header.NDPOptionsSerializer{})
+}
+
// raBuf returns a valid NDP Router Advertisement.
//
// Note, raBuf does not populate any of the RA fields other than the
@@ -857,6 +992,8 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
// TestNoRouterDiscovery tests that router discovery will not be performed if
// configured not to.
func TestNoRouterDiscovery(t *testing.T) {
+ t.Parallel()
+
// Being configured to discover routers means handle and
// discover are set to true and forwarding is set to false.
// This tests all possible combinations of the configurations,
@@ -869,8 +1006,6 @@ func TestNoRouterDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
}
@@ -1011,13 +1146,13 @@ func TestRouterDiscovery(t *testing.T) {
expectRouterEvent(llAddr2, true)
// Rx an RA from another router (lladdr3) with non-zero lifetime.
- l3Lifetime := time.Duration(6)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime)))
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
expectRouterEvent(llAddr3, true)
// Rx an RA from lladdr2 with lesser lifetime.
- l2Lifetime := time.Duration(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime)))
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
select {
case <-ndpDisp.routerC:
t.Fatal("Should not receive a router event when updating lifetimes for known routers")
@@ -1031,7 +1166,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1048,7 +1183,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1105,6 +1240,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
// configured not to.
func TestNoPrefixDiscovery(t *testing.T) {
+ t.Parallel()
+
prefix := tcpip.AddressWithPrefix{
Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
PrefixLen: 64,
@@ -1122,8 +1259,6 @@ func TestNoPrefixDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, 1),
}
@@ -1480,6 +1615,8 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
func TestNoAutoGenAddr(t *testing.T) {
+ t.Parallel()
+
prefix, _, _ := prefixSubnetAddr(0, "")
// Being configured to auto-generate addresses means handle and
@@ -1494,8 +1631,6 @@ func TestNoAutoGenAddr(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
@@ -1637,12 +1772,529 @@ func TestAutoGenAddr(t *testing.T) {
}
}
+// stackAndNdpDispatcherWithDefaultRoute creates a stack.Stack with a single
+// NIC (nicID) backed by a channel.Endpoint and returns the ndpDispatcher, the
+// channel.Endpoint and the stack.Stack.
+//
+// The stack is configured to handle RAs and auto-generate global addresses.
+// Its route table holds a single default route through the router (llAddr3),
+// and the router's link address (linkAddr3) is statically added to the link
+// address cache.
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+	t.Helper()
+
+	disp := &ndpDispatcher{autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1)}
+	linkEP := channel.New(0, 1280, linkAddr1)
+
+	opts := stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocol{ipv6.NewProtocol()},
+		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+		NDPConfigs: stack.NDPConfigurations{
+			HandleRAs:              true,
+			AutoGenGlobalAddresses: true,
+		},
+		NDPDisp: disp,
+	}
+	st := stack.New(opts)
+	if err := st.CreateNIC(nicID, linkEP); err != nil {
+		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+	}
+
+	// Everything is routed through the router at llAddr3.
+	defaultRoute := tcpip.Route{
+		Destination: header.IPv6EmptySubnet,
+		Gateway:     llAddr3,
+		NIC:         nicID,
+	}
+	st.SetRouteTable([]tcpip.Route{defaultRoute})
+	st.AddLinkAddress(nicID, llAddr3, linkAddr3)
+
+	return disp, linkEP, st
+}
+
+// addrForNewConnection returns the local address used when creating a new
+// connection.
+//
+// It creates a V6-only UDP endpoint on s, connects it to dstAddr without
+// binding first and returns the local address reported by the endpoint. Any
+// failure along the way is fatal to t.
+func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+	// Report failures at the caller's line rather than inside this helper.
+	t.Helper()
+
+	wq := waiter.Queue{}
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+	defer close(ch)
+	ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+	if err != nil {
+		t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+	}
+	defer ep.Close()
+	if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+		t.Fatalf("SetSockOptBool(tcpip.V6OnlyOption, true): %s", err)
+	}
+	if err := ep.Connect(dstAddr); err != nil {
+		t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+	}
+	got, err := ep.GetLocalAddress()
+	if err != nil {
+		t.Fatalf("ep.GetLocalAddress(): %s", err)
+	}
+	return got.Addr
+}
+
+// addrForNewConnectionWithAddr returns the local address used when creating a
+// new connection with a specific local address.
+//
+// It creates a V6-only UDP endpoint on s, binds it to addr, connects it to
+// dstAddr and returns the local address reported by the endpoint. Any failure
+// along the way is fatal to t.
+func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+	// Report failures at the caller's line rather than inside this helper.
+	t.Helper()
+
+	wq := waiter.Queue{}
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+	defer close(ch)
+	ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+	if err != nil {
+		t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+	}
+	defer ep.Close()
+	if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+		t.Fatalf("SetSockOptBool(tcpip.V6OnlyOption, true): %s", err)
+	}
+	if err := ep.Bind(addr); err != nil {
+		t.Fatalf("ep.Bind(%+v): %s", addr, err)
+	}
+	if err := ep.Connect(dstAddr); err != nil {
+		t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+	}
+	got, err := ep.GetLocalAddress()
+	if err != nil {
+		t.Fatalf("ep.GetLocalAddress(): %s", err)
+	}
+	return got.Addr
+}
+
+// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
+// receiving a PI with 0 preferred lifetime.
+func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
+	const nicID = 1
+
+	prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+	prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+	ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+	// expectAutoGenAddrEvent checks that an auto-gen address event for addr of
+	// type eventType is already pending; it does not wait for one.
+	expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+		t.Helper()
+
+		select {
+		case ev := <-ndpDisp.autoGenAddrC:
+			if diff := checkAutoGenAddrEvent(ev, addr, eventType); diff != "" {
+				t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+			}
+		default:
+			t.Fatal("expected addr auto gen event")
+		}
+	}
+
+	// expectPrimaryAddr checks that addr is both the NIC's main IPv6 address
+	// and the local address picked for a new connection.
+	expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+		t.Helper()
+
+		if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+			t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+		} else if got != addr {
+			t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+		}
+
+		if got := addrForNewConnection(t, s); got != addr.Address {
+			t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+		}
+	}
+
+	// Receive PI for prefix1.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+	expectAutoGenAddrEvent(addr1, newAddr)
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should have %s in the list of addresses", addr1)
+	}
+	expectPrimaryAddr(addr1)
+
+	// Deprecate addr for prefix1 immediately.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+	expectAutoGenAddrEvent(addr1, deprecatedAddr)
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should have %s in the list of addresses", addr1)
+	}
+	// addr should still be the primary endpoint as there are no other addresses.
+	expectPrimaryAddr(addr1)
+
+	// Refresh lifetimes of addr generated from prefix1.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+	expectPrimaryAddr(addr1)
+
+	// Receive PI for prefix2.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+	expectAutoGenAddrEvent(addr2, newAddr)
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	expectPrimaryAddr(addr2)
+
+	// Deprecate addr for prefix2 immediately.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+	expectAutoGenAddrEvent(addr2, deprecatedAddr)
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	// addr1 should be the primary endpoint now since addr2 is deprecated but
+	// addr1 is not.
+	expectPrimaryAddr(addr1)
+	// addr2 is deprecated but if explicitly requested, it should be used.
+	fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+	if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+		t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+	}
+
+	// Another PI w/ 0 preferred lifetime should not result in a deprecation
+	// event.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+	expectPrimaryAddr(addr1)
+	if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+		t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+	}
+
+	// Refresh lifetimes of addr generated from prefix2.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+	expectPrimaryAddr(addr2)
+}
+
+// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// when its preferred lifetime expires.
+//
+// The test temporarily lowers stack.MinPrefixInformationValidLifetimeForUpdate
+// so that short lifetimes can be used, and restores it on return.
+func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+	const nicID = 1
+	const newMinVL = 2
+	newMinVLDuration := newMinVL * time.Second
+	saved := stack.MinPrefixInformationValidLifetimeForUpdate
+	defer func() {
+		stack.MinPrefixInformationValidLifetimeForUpdate = saved
+	}()
+	stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+	prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+	prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+	ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+	// expectAutoGenAddrEvent checks that an auto-gen address event for addr of
+	// type eventType is already pending; it does not wait for one.
+	expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+		t.Helper()
+
+		select {
+		case ev := <-ndpDisp.autoGenAddrC:
+			if diff := checkAutoGenAddrEvent(ev, addr, eventType); diff != "" {
+				t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+			}
+		default:
+			t.Fatal("expected addr auto gen event")
+		}
+	}
+
+	// expectAutoGenAddrEventAfter waits up to timeout for an auto-gen address
+	// event for addr of type eventType.
+	expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+		t.Helper()
+
+		select {
+		case ev := <-ndpDisp.autoGenAddrC:
+			if diff := checkAutoGenAddrEvent(ev, addr, eventType); diff != "" {
+				t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+			}
+		case <-time.After(timeout):
+			t.Fatal("timed out waiting for addr auto gen event")
+		}
+	}
+
+	// expectPrimaryAddr checks that addr is both the NIC's main IPv6 address
+	// and the local address picked for a new connection.
+	expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+		t.Helper()
+
+		if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+			t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+		} else if got != addr {
+			t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+		}
+
+		if got := addrForNewConnection(t, s); got != addr.Address {
+			t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+		}
+	}
+
+	// Receive PI for prefix2.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+	expectAutoGenAddrEvent(addr2, newAddr)
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	expectPrimaryAddr(addr2)
+
+	// Receive a PI for prefix1.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+	expectAutoGenAddrEvent(addr1, newAddr)
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should have %s in the list of addresses", addr1)
+	}
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	expectPrimaryAddr(addr1)
+
+	// Refresh lifetime for addr of prefix1.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+	expectPrimaryAddr(addr1)
+
+	// Wait for addr of prefix1 to be deprecated.
+	expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+	// A deprecated address is still assigned to the NIC.
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should have %s in the list of addresses", addr1)
+	}
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	// addr2 should be the primary endpoint now since addr1 is deprecated but
+	// addr2 is not.
+	expectPrimaryAddr(addr2)
+	// addr1 is deprecated but if explicitly requested, it should be used.
+	fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+	if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+		t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+	}
+
+	// Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+	// sure we do not get a deprecation event again.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+	expectPrimaryAddr(addr2)
+	if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+		t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+	}
+
+	// Refresh lifetimes for addr of prefix1.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+	// addr1 is the primary endpoint again since it is non-deprecated now.
+	expectPrimaryAddr(addr1)
+
+	// Wait for addr of prefix1 to be deprecated.
+	expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+	// A deprecated address is still assigned to the NIC.
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should have %s in the list of addresses", addr1)
+	}
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	// addr2 should be the primary endpoint now since it is not deprecated.
+	expectPrimaryAddr(addr2)
+	if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+		t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+	}
+
+	// Wait for addr of prefix1 to be invalidated.
+	expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout)
+	if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should not have %s in the list of addresses", addr1)
+	}
+	if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should have %s in the list of addresses", addr2)
+	}
+	expectPrimaryAddr(addr2)
+
+	// Refresh both lifetimes for addr of prefix2 to the same value.
+	e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+	select {
+	case <-ndpDisp.autoGenAddrC:
+		t.Fatal("unexpectedly got an auto-generated event")
+	default:
+	}
+
+	// Wait for a deprecation then invalidation events, or just an invalidation
+	// event. We need to cover both cases but cannot deterministically hit both
+	// cases because the deprecation and invalidation handlers could be handled in
+	// either deprecation then invalidation, or invalidation then deprecation
+	// (which should be cancelled by the invalidation handler).
+	select {
+	case ev := <-ndpDisp.autoGenAddrC:
+		if diff := checkAutoGenAddrEvent(ev, addr2, deprecatedAddr); diff == "" {
+			// If we get a deprecation event first, we should get an invalidation
+			// event almost immediately after.
+			select {
+			case ev := <-ndpDisp.autoGenAddrC:
+				if diff := checkAutoGenAddrEvent(ev, addr2, invalidatedAddr); diff != "" {
+					t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+				}
+			case <-time.After(defaultTimeout):
+				t.Fatal("timed out waiting for addr auto gen event")
+			}
+		} else if diff := checkAutoGenAddrEvent(ev, addr2, invalidatedAddr); diff == "" {
+			// If we get an invalidation event first, we should not get a deprecation
+			// event after.
+			select {
+			case <-ndpDisp.autoGenAddrC:
+				t.Fatal("unexpectedly got an auto-generated event")
+			case <-time.After(defaultTimeout):
+			}
+		} else {
+			t.Fatalf("got unexpected auto-generated event")
+		}
+
+	case <-time.After(newMinVLDuration + defaultTimeout):
+		t.Fatal("timed out waiting for addr auto gen event")
+	}
+	if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+		t.Fatalf("should not have %s in the list of addresses", addr1)
+	}
+	if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+		t.Fatalf("should not have %s in the list of addresses", addr2)
+	}
+	// Should not have any primary endpoints.
+	if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+		t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+	} else if want := (tcpip.AddressWithPrefix{}); got != want {
+		t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+	}
+	wq := waiter.Queue{}
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+	defer close(ch)
+	ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+	if err != nil {
+		t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+	}
+	defer ep.Close()
+	if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+		t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+	}
+
+	// With no addresses left, connecting is expected to fail with ErrNoRoute.
+	if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+		t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+	}
+}
+
+// TestAutoGenAddrFiniteToInfiniteToFiniteVL tests transitioning a SLAAC
+// address's valid lifetime between finite and infinite values.
+func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
+ const infiniteVLSeconds = 2
+ const minVLSeconds = 1
+ savedIL := header.NDPInfiniteLifetime
+ savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ header.NDPInfiniteLifetime = savedIL
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ infiniteVL uint32
+ }{
+ {
+ name: "EqualToInfiniteVL",
+ infiniteVL: infiniteVLSeconds,
+ },
+ // Our implementation supports changing header.NDPInfiniteLifetime for tests
+ // such that a packet can be received where the lifetime field has a value
+ // greater than header.NDPInfiniteLifetime. Because of this, we test to make
+ // sure that receiving a value greater than header.NDPInfiniteLifetime is
+ // handled the same as when receiving a value equal to
+ // header.NDPInfiniteLifetime.
+ {
+ name: "MoreThanInfiniteVL",
+ infiniteVL: infiniteVLSeconds + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive a new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(minVLSeconds*time.Second + defaultTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
// auto-generated address only gets updated when required to, as specified in
// RFC 4862 section 5.5.3.e.
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
- const newMinVL = 5
+ const newMinVL = 4
saved := stack.MinPrefixInformationValidLifetimeForUpdate
defer func() {
stack.MinPrefixInformationValidLifetimeForUpdate = saved
@@ -1911,6 +2563,110 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
}
+// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
+// opaque interface identifiers when configured to do so.
+func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
+ t.Parallel()
+
+ const nicID = 1
+ const nicName = "nic1"
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
+ // addr1 and addr2 are the addresses that are expected to be generated when
+ // stack.Stack is configured to generate opaque interface identifiers as
+ // defined by RFC 7217.
+ addrBytes := []byte(subnet1.ID())
+ addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ addrBytes = []byte(subnet2.ID())
+ addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in a PI with a short valid lifetime.
+ const validLifetimeSecondPrefix1 = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in a PI with a large valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+}
+
// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
// to the integrator when an RA is received with the NDP Recursive DNS Server
// option with at least one valid address.
@@ -2312,3 +3068,94 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
default:
}
}
+
+// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly
+// informed when new information about what configurations are available via
+// DHCPv6 is learned.
+func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ t.Helper()
+ select {
+ case e := <-ndpDisp.dhcpv6ConfigurationC:
+ if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
+ t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DHCPv6 configuration event")
+ }
+ }
+
+ expectNoDHCPv6Event := func() {
+ t.Helper()
+ select {
+ case <-ndpDisp.dhcpv6ConfigurationC:
+ t.Fatal("unexpected DHCPv6 configuration event")
+ default:
+ }
+ }
+
+ // No event expected: the initial configuration is stack.DHCPv6NoConfiguration.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ // Receiving the same update again should not result in an event to the
+ // NDPDispatcher.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to none.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ //
+ // Note, when the M flag is set, the O flag is redundant.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+ // Even though the DHCPv6 flags are different, the effective configuration is
+ // the same so we should not receive a new event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index ddd014658..3810c6602 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -27,11 +27,11 @@ import (
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
- loopback bool
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ context NICContext
mu sync.RWMutex
spoofing bool
@@ -85,7 +85,7 @@ const (
)
// newNIC returns a new NIC using the default NDP configurations from stack.
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
@@ -99,7 +99,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
id: id,
name: name,
linkEP: ep,
- loopback: loopback,
+ context: ctx,
primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -174,23 +174,28 @@ func (n *NIC) enable() *tcpip.Error {
return err
}
- if !n.stack.autoGenIPv6LinkLocal {
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() {
return nil
}
- l2addr := n.linkEP.LinkAddress()
+ var addr tcpip.Address
+ if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
+ } else {
+ l2addr := n.linkEP.LinkAddress()
- // Only attempt to generate the link-local address if we have a
- // valid MAC address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
- }
+ // Only attempt to generate the link-local address if we have a valid MAC
+ // address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
- addr := header.LinkLocalAddr(l2addr)
+ addr = header.LinkLocalAddr(l2addr)
+ }
_, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
@@ -235,6 +240,10 @@ func (n *NIC) isPromiscuousMode() bool {
return rv
}
+// isLoopback reports whether n's link endpoint advertises the
+// CapabilityLoopback capability.
+func (n *NIC) isLoopback() bool {
+	caps := n.linkEP.Capabilities()
+	return caps&CapabilityLoopback != 0
+}
+
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
@@ -244,17 +253,47 @@ func (n *NIC) setSpoofing(enable bool) {
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
+//
+// primaryEndpoint will return the first non-deprecated endpoint that passes
+// isValidForOutgoing and tryIncRef if such an endpoint exists. Otherwise, the
+// first such deprecated endpoint (or nil if there is none) will be returned.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
n.mu.RLock()
defer n.mu.RUnlock()
+ var deprecatedEndpoint *referencedNetworkEndpoint
for _, r := range n.primary[protocol] {
- if r.isValidForOutgoing() && r.tryIncRef() {
- return r
+ if !r.isValidForOutgoing() {
+ continue
+ }
+
+ if !r.deprecated {
+ if r.tryIncRef() {
+ // r is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ deprecatedEndpoint.decRefLocked()
+ deprecatedEndpoint = nil
+ }
+
+ return r
+ }
+ } else if deprecatedEndpoint == nil && r.tryIncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of r in
+ // case n doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, r's reference count
+ // will be decremented when such an endpoint is found.
+ deprecatedEndpoint = r
}
}
- return nil
+ // n doesn't have any valid non-deprecated endpoints, so return
+ // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
+ // endpoints either).
+ return deprecatedEndpoint
}
// hasPermanentAddrLocked returns true if n has a permanent (including currently
@@ -362,7 +401,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, temporary, static)
+ }, peb, temporary, static, false)
n.mu.Unlock()
return ref
@@ -411,10 +450,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent, static)
+ return n.addAddressLocked(protocolAddress, peb, permanent, static, false)
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
// TODO(b/141022673): Validate IP address before adding them.
// Sanity check.
@@ -450,6 +489,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
protocol: protocolAddress.Protocol,
kind: kind,
configType: configType,
+ deprecated: deprecated,
}
// Set up cache if link address resolution exists for this protocol.
@@ -548,6 +588,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
return addrs
}
+// primaryAddress returns the primary address associated with this NIC.
+//
+// primaryAddress will return the first non-deprecated address if such an
+// address exists. If no non-deprecated address exists, the first deprecated
+// address will be returned.
+func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list, ok := n.primary[proto]
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, ref := range list {
+ // Don't include tentative, expired or temporary endpoints to avoid confusion
+ // and prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentTentative, permanentExpired, temporary:
+ continue
+ }
+
+ if !ref.deprecated {
+ return tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ }
+ }
+
+ if deprecatedEndpoint == nil {
+ deprecatedEndpoint = ref
+ }
+ }
+
+ if deprecatedEndpoint != nil {
+ return tcpip.AddressWithPrefix{
+ Address: deprecatedEndpoint.ep.ID().LocalAddress,
+ PrefixLen: deprecatedEndpoint.ep.PrefixLen(),
+ }
+ }
+
+ return tcpip.AddressWithPrefix{}
+}
+
// AddAddressRange adds a range of addresses to n, so that it starts accepting
// packets targeted at the given addresses and network protocol. The range is
// given by a subnet address, and all addresses contained in the subnet are
@@ -1104,6 +1189,11 @@ type referencedNetworkEndpoint struct {
// configType is the method that was used to configure this endpoint.
// This must never change after the endpoint is added to a NIC.
configType networkEndpointConfigType
+
+ // deprecated indicates whether or not the endpoint should be considered
+ // deprecated. That is, when deprecated is true, other endpoints that are not
+ // deprecated should be preferred.
+ deprecated bool
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 61fd46d66..2b8751d49 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -234,15 +234,15 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 34307ae07..517f4b941 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt)
+ err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop)
+ n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n))
}
@@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 7a9600679..41bf9fd9b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -352,6 +352,38 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
return atomic.AddUint64((*uint64)(u), 1)
}
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it will be passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on different NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface identifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -412,8 +444,8 @@ type Stack struct {
ndpConfigs NDPConfigurations
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
- // See the AutoGenIPv6LinkLocal field of Options for more details.
+ // to auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
autoGenIPv6LinkLocal bool
// ndpDisp is the NDP event dispatcher that is used to send the netstack
@@ -422,6 +454,10 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
+
+ // opaqueIIDOpts holds the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
}
// UniqueID is an abstract generator of unique identifiers.
@@ -460,13 +496,15 @@ type Options struct {
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
// Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address
- // Detection is enabled, an address will only be assigned if it
- // successfully resolves. If it fails, no further attempt will be made
- // to auto-generate an IPv6 link-local address.
+ // will be assigned right away, or at all. If Duplicate Address Detection
+ // is enabled, an address will only be assigned if it successfully resolves.
+ // If it fails, no further attempt will be made to auto-generate an IPv6
+ // link-local address.
//
// The generated link-local address will follow RFC 4291 Appendix A
// guidelines.
@@ -479,6 +517,10 @@ type Options struct {
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
+
+ // OpaqueIIDOpts holds the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -549,6 +591,7 @@ func New(opts Options) *Stack {
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
}
// Add specified network protocols.
@@ -753,9 +796,30 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
}
-// createNIC creates a NIC with the provided id and link-layer endpoint, and
-// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
+// NICContext is an opaque pointer used to store client-supplied NIC metadata.
+type NICContext interface{}
+
+// NICOptions specifies the configuration of a NIC as it is being created.
+// The zero value creates an enabled, unnamed NIC.
+type NICOptions struct {
+ // Name specifies the name of the NIC.
+ Name string
+
+ // Disabled specifies whether to avoid calling Attach on the passed
+ // LinkEndpoint.
+ Disabled bool
+
+ // Context specifies user-defined data that will be returned in stack.NICInfo
+ // for the NIC. Clients of this library can use it to add metadata that
+ // should be tracked alongside a NIC, to avoid having to keep a
+ // map[tcpip.NICID]metadata mirroring stack.Stack's nic map.
+ Context NICContext
+}
+
+// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
+// NICOptions. See the documentation on type NICOptions for details on how
+// NICs can be configured.
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -764,44 +828,20 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled,
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, name, ep, loopback)
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
- if enabled {
+ if !opts.Disabled {
return n.enable()
}
return nil
}
-// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
+// `LinkEndpoint.Attach` to start delivering packets to it.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, true, false)
-}
-
-// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
-// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, false)
-}
-
-// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
-// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, true)
-}
-
-// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
-// but leave it disable. Stack.EnableNIC must be called before the link-layer
-// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, false, false)
-}
-
-// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
-// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, false, false)
+ return s.CreateNICWithOptions(id, ep, NICOptions{})
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -855,6 +895,18 @@ type NICInfo struct {
MTU uint32
Stats NICStats
+
+ // Context is user-supplied data optionally supplied in CreateNICWithOptions.
+ // See type NICOptions for more details.
+ Context NICContext
+}
+
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
}
// NICInfo returns a map of NICIDs to their associated information.
@@ -868,7 +920,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Up: true, // Netstack interfaces are always up.
Running: nic.linkEP.IsAttached(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ Loopback: nic.isLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
@@ -877,6 +929,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
+ Context: nic.context,
}
}
return nics
@@ -993,9 +1046,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
return nics
}
-// GetMainNICAddress returns the first primary address and prefix for the given
-// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
-// value if the NIC doesn't have a primary address for the given protocol.
+// GetMainNICAddress returns the first non-deprecated primary address and prefix
+// for the given NIC and protocol. If no non-deprecated primary address exists,
+// a deprecated primary address and prefix will be returned. Returns an error if
+// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
+// address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1005,12 +1060,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- for _, a := range nic.PrimaryAddresses() {
- if a.Protocol == protocol {
- return a.AddressWithPrefix, nil
- }
- }
- return tcpip.AddressWithPrefix{}, nil
+ return nic.primaryAddress(protocol), nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -1032,7 +1082,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok {
if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
}
} else {
@@ -1048,7 +1098,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
remoteAddr = ref.ep.ID().LocalAddress
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback)
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
if needRoute {
r.NextHop = route.Gateway
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8fc034ca1..e8de4e87d 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -24,13 +24,14 @@ import (
"sort"
"strings"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -48,6 +49,8 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
+
+ linkAddr = "\x02\x02\x03\x04\x05\x06"
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -122,7 +125,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -133,7 +136,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -141,7 +144,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
@@ -149,11 +152,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -1894,55 +1897,67 @@ func TestNICForwarding(t *testing.T) {
}
// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
-// (or lack there-of if disabled (default)). Note, DAD will be disabled in
-// these tests.
+// using the modified EUI-64 of the NIC's MAC address (or lack thereof if
+// disabled (default)). Note, DAD will be disabled in these tests.
func TestNICAutoGenAddr(t *testing.T) {
tests := []struct {
name string
autoGen bool
linkAddr tcpip.LinkAddress
+ iidOpts stack.OpaqueInterfaceIdentifierOptions
shouldGen bool
}{
{
"Disabled",
false,
- linkAddr1,
+ linkAddr,
+ stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(nicID tcpip.NICID, _ string) string {
+ return fmt.Sprintf("nic%d", nicID)
+ },
+ },
false,
},
{
"Enabled",
true,
- linkAddr1,
+ linkAddr,
+ stack.OpaqueInterfaceIdentifierOptions{},
true,
},
{
"Nil MAC",
true,
tcpip.LinkAddress([]byte(nil)),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Empty MAC",
true,
tcpip.LinkAddress(""),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Invalid MAC",
true,
tcpip.LinkAddress("\x01\x02\x03"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Multicast MAC",
true,
tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Unspecified MAC",
true,
tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
}
@@ -1951,13 +1966,12 @@ func TestNICAutoGenAddr(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ OpaqueIIDOpts: test.iidOpts,
}
if test.autoGen {
- // Only set opts.AutoGenIPv6LinkLocal when
- // test.autoGen is true because
- // opts.AutoGenIPv6LinkLocal should be false by
- // default.
+ // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by default.
opts.AutoGenIPv6LinkLocal = true
}
@@ -1973,8 +1987,8 @@ func TestNICAutoGenAddr(t *testing.T) {
}
if test.shouldGen {
- // Should have auto-generated an address and
- // resolved immediately (DAD is disabled).
+ // Should have auto-generated an address and resolved immediately (DAD
+ // is disabled).
if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
}
@@ -1988,66 +2002,215 @@ func TestNICAutoGenAddr(t *testing.T) {
}
}
-// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
-// link-local addresses will only be assigned after the DAD process resolves.
-func TestNICAutoGenAddrDoesDAD(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+// TestNICContextPreservation tests that you can read out via stack.NICInfo the
+// Context data you pass via NICOptions.Context in stack.CreateNICWithOptions.
+func TestNICContextPreservation(t *testing.T) {
+ var ctx *int
+ tests := []struct {
+ name string
+ opts stack.NICOptions
+ want stack.NICContext
+ }{
+ {
+ "context_set",
+ stack.NICOptions{Context: ctx},
+ ctx,
+ },
+ {
+ "context_not_set",
+ stack.NICOptions{},
+ nil,
+ },
}
- ndpConfigs := stack.DefaultNDPConfigurations()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ id := tcpip.NICID(1)
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
+ t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
+ }
+ nicinfos := s.NICInfo()
+ nicinfo, ok := nicinfos[id]
+ if !ok {
+ t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
+ }
+ if got, want := nicinfo.Context == test.want, true; got != want {
+ t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ }
+ })
}
+}
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local
+// addresses with opaque interface identifiers. Link Local addresses should
+// always be generated with opaque IIDs if configured to use them, even if the
+// NIC has an invalid MAC address.
+func TestNICAutoGenAddrWithOpaque(t *testing.T) {
+ const nicID = 1
- // Address should not be considered bound to the
- // NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
+ n, err := rand.Read(secretKey[:])
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("rand.Read(_): %s", err)
}
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
}
- linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+ tests := []struct {
+ name string
+ nicName string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ secretKey []byte
+ }{
+ {
+ name: "Disabled",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr,
+ secretKey: secretKey[:],
+ },
+ {
+ name: "Enabled",
+ nicName: "nic1",
+ autoGen: true,
+ linkAddr: linkAddr,
+ secretKey: secretKey[:],
+ },
+ // These are all cases where we would not have generated a
+ // link-local address if opaque IIDs were disabled.
+ {
+ name: "Nil MAC and empty nicName",
+ nicName: "",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress([]byte(nil)),
+ secretKey: secretKey[:1],
+ },
+ {
+ name: "Empty MAC and empty nicName",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress(""),
+ secretKey: secretKey[:2],
+ },
+ {
+ name: "Invalid MAC",
+ nicName: "test",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x01\x02\x03"),
+ secretKey: secretKey[:3],
+ },
+ {
+ name: "Multicast MAC",
+ nicName: "test2",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ secretKey: secretKey[:4],
+ },
+ {
+ name: "Unspecified MAC and nil SecretKey",
+ nicName: "test3",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ },
+ }
- // Wait for DAD to resolve.
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // We should get a resolution event after 1s (default time to
- // resolve as per default NDP configurations). Waiting for that
- // resolution time + an extra 1s without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
- }
- if e.addr != linkLocalAddr {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
- }
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: test.secretKey,
+ },
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: test.nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+
+ if test.autoGen {
+ // Should have auto-generated an address and
+ // resolved immediately (DAD is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+}
+
+// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
+// not auto-generated for loopback NICs.
+func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
+ const nicID = 1
+ const nicName = "nicName"
+
+ tests := []struct {
+ name string
+ opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ }{
+ {
+ name: "IID From MAC",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ },
+ {
+ name: "Opaque IID",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ },
+ },
}
- if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ }
+
+ e := loopback.New()
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 3b28b06d0..5e9237de9 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -41,7 +41,7 @@ const (
type testContext struct {
t *testing.T
- linkEPs map[string]*channel.Endpoint
+ linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
@@ -61,35 +61,29 @@ func (c *testContext) createV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
}
-// newDualTestContextMultiNic creates the testing context and also linkEpNames
-// named NICs.
-func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+// newDualTestContextMultiNIC creates the testing context and a NIC for each ID in linkEpIDs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEPs := make(map[string]*channel.Endpoint)
- for i, linkEpName := range linkEpNames {
- channelEP := channel.New(256, mtu, "")
- nicID := tcpip.NICID(i + 1)
- if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil {
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- linkEPs[linkEpName] = channelEP
+ linkEps[linkEpID] = channelEp
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress IPv4 failed: %v", err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress IPv6 failed: %v", err)
}
}
@@ -108,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string)
return &testContext{
t: t,
s: s,
- linkEPs: linkEPs,
+ linkEps: linkEps,
}
}
@@ -125,7 +119,7 @@ func newPayload() []byte {
return b
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -156,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -186,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
func TestDistribution(t *testing.T) {
type endpointSockopts struct {
reuse int
- bindToDevice string
+ bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
@@ -194,71 +188,71 @@ func TestDistribution(t *testing.T) {
endpoints []endpointSockopts
// wantedDistribution is the wanted ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[string][]float64
+ wantedDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
- "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{0, "dev0"},
- endpointSockopts{0, "dev1"},
- endpointSockopts{0, "dev2"},
+ {0, 1},
+ {0, 2},
+ {0, 3},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
- "dev0": []float64{1, 0, 0},
+ 1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
- "dev1": []float64{0, 1, 0},
+ 2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
- "dev2": []float64{0, 0, 1},
+ 3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, ""},
+ {1, 1},
+ {1, 1},
+ {1, 2},
+ {1, 2},
+ {1, 2},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
- "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ 1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
- "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
- "dev999": []float64{0, 0, 0, 0, 0, 1},
+ 1000: {0, 0, 0, 0, 0, 1},
},
},
} {
t.Run(test.name, func(t *testing.T) {
for device, wantedDistribution := range test.wantedDistributions {
- t.Run(device, func(t *testing.T) {
- var devices []string
+ t.Run(string(device), func(t *testing.T) {
+ var devices []tcpip.NICID
for d := range test.wantedDistributions {
devices = append(devices, d)
}
- c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
defer c.cleanup()
c.createV6Endpoint(false)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 748ce4ea5..f50604a8a 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -102,13 +102,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f62fd729f..72b5ce179 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -423,17 +423,25 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptBool sets a socket option, for simple cases where a value
+ // has the bool type.
+ SetSockOptBool(opt SockOptBool, v bool) *Error
+
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
- SetSockOptInt(opt SockOpt, v int) *Error
+ SetSockOptInt(opt SockOptInt, v int) *Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
+ // GetSockOptBool gets a socket option for simple cases where a return
+ // value has the bool type.
+ GetSockOptBool(SockOptBool) (bool, *Error)
+
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
- GetSockOptInt(SockOpt) (int, *Error)
+ GetSockOptInt(SockOptInt) (int, *Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -488,13 +496,22 @@ type WriteOptions struct {
Atomic bool
}
-// SockOpt represents socket options which values have the int type.
-type SockOpt int
+// SockOptBool represents socket options whose values have the bool type.
+type SockOptBool int
+
+const (
+ // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
+ // socket is to be restricted to sending and receiving IPv6 packets only.
+ V6OnlyOption SockOptBool = iota
+)
+
+// SockOptInt represents socket options whose values have the int type.
+type SockOptInt int
const (
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the input buffer should be returned.
- ReceiveQueueSizeOption SockOpt = iota
+ ReceiveQueueSizeOption SockOptInt = iota
// SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the send buffer size option.
@@ -521,10 +538,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
-// socket is to be restricted to sending and receiving IPv6 packets only.
-type V6OnlyOption int
-
// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
// held until segments are full by the TCP transport protocol.
type CorkOption int
@@ -539,7 +552,7 @@ type ReusePortOption int
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
-type BindToDeviceOption string
+type BindToDeviceOption NICID
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
new file mode 100644
index 000000000..f5f01f32f
--- /dev/null
+++ b/pkg/tcpip/timer.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "sync"
+ "time"
+)
+
+// cancellableTimerInstance is a specific instance of CancellableTimer.
+//
+// Different instances are created each time CancellableTimer is Reset so each
+// timer has its own earlyReturn signal. This is to address a bug when a
+// CancellableTimer is stopped and reset in quick succession resulting in a
+// timer instance's earlyReturn signal being affected or seen by another timer
+// instance.
+//
+// Consider the following scenario where timer instances share a common
+// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
+// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
+// (B), third (C), and fourth (D) instance of the timer firing, respectively):
+// T1: Obtain L
+// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T2: instance A fires, blocked trying to obtain L.
+// T1: Attempt to stop instance A (set earlyReturn = true)
+// T1: Reset timer (create instance B)
+// T3: instance B fires, blocked trying to obtain L.
+// T1: Attempt to stop instance B (set earlyReturn = true)
+// T1: Reset timer (create instance C)
+// T4: instance C fires, blocked trying to obtain L.
+// T1: Attempt to stop instance C (set earlyReturn = true)
+// T1: Reset timer (create instance D)
+// T5: instance D fires, blocked trying to obtain L.
+// T1: Release L
+//
+// Now that T1 has released L, any of the 4 timer instances can take L and check
+// earlyReturn. If the timers simply check earlyReturn and then do nothing
+// further, then instance D will never early return even though it was not
+// requested to stop. If the timers reset earlyReturn before early returning,
+// then all but one of the timers will do work when only one was expected to.
+// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// will fire (again, when only one was expected to).
+//
+// To address the above concerns the simplest solution was to give each timer
+// its own earlyReturn signal.
+type cancellableTimerInstance struct {
+ timer *time.Timer
+
+ // Used to inform the timer to early return when it gets stopped while the
+ // lock the timer tries to obtain when fired is held (T1 is a goroutine that
+ // tries to cancel the timer and T2 is the goroutine that handles the timer
+ // firing):
+ // T1: Obtain the lock, then call StopLocked()
+ // T2: timer fires, and gets blocked on obtaining the lock
+ // T1: Releases lock
+ // T2: Obtains lock, does unintended work
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using earlyReturn to return early so that once T2 obtains
+ // the lock, it will see that it is set to true and do nothing further.
+ earlyReturn *bool
+}
+
+// stop stops the timer instance t from firing if it hasn't fired already. If it
+// has fired and is blocked at obtaining the lock, earlyReturn will be set to
+// true so that it will early return when it obtains the lock.
+func (t *cancellableTimerInstance) stop() {
+ if t.timer != nil {
+ t.timer.Stop()
+ *t.earlyReturn = true
+ }
+}
+
+// CancellableTimer is a timer that does some work and can be safely cancelled
+// when it fires at the same time some "related work" is being done.
+//
+// The term "related work" is defined as some work that needs to be done while
+// holding some lock that the timer must also hold while doing some work.
+type CancellableTimer struct {
+ // The active instance of a cancellable timer.
+ instance cancellableTimerInstance
+
+ // locker is the lock taken by the timer immediately after it fires and must
+ // be held when attempting to stop the timer.
+ //
+ // Must never change after being assigned.
+ locker sync.Locker
+
+ // fn is the function that will be called when a timer fires and has not been
+ // signaled to early return.
+ //
+ // fn MUST NOT attempt to lock locker.
+ //
+ // Must never change after being assigned.
+ fn func()
+}
+
+// StopLocked prevents the Timer from firing if it has not fired already.
+//
+// If the timer is blocked on obtaining the t.locker lock when StopLocked is
+// called, it will early return instead of calling t.fn.
+//
+// Note, t will be modified.
+//
+// t.locker MUST be locked.
+func (t *CancellableTimer) StopLocked() {
+ t.instance.stop()
+
+ // Nothing to do with the stopped instance anymore.
+ t.instance = cancellableTimerInstance{}
+}
+
+// Reset changes the timer to expire after duration d.
+//
+// Note, t will be modified.
+//
+// Reset should only be called on stopped or expired timers. To be safe, callers
+// should always call StopLocked before calling Reset.
+func (t *CancellableTimer) Reset(d time.Duration) {
+ // Create a new instance.
+ earlyReturn := false
+ t.instance = cancellableTimerInstance{
+ timer: time.AfterFunc(d, func() {
+ t.locker.Lock()
+ defer t.locker.Unlock()
+
+ if earlyReturn {
+ // If we reach this point, it means that the timer fired while another
+ // goroutine called StopLocked while it had the lock. Simply return
+ // here and do nothing further.
+ earlyReturn = false
+ return
+ }
+
+ t.fn()
+ }),
+ earlyReturn: &earlyReturn,
+ }
+}
+
+// MakeCancellableTimer returns an unscheduled CancellableTimer with the given
+// locker and fn.
+//
+// fn MUST NOT attempt to lock locker.
+//
+// Callers must call Reset to schedule the timer to fire.
+func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer {
+ return CancellableTimer{locker: locker, fn: fn}
+}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
new file mode 100644
index 000000000..1f735d735
--- /dev/null
+++ b/pkg/tcpip/timer_test.go
@@ -0,0 +1,236 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package timer_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ shortDuration = 1 * time.Nanosecond
+ middleDuration = 100 * time.Millisecond
+ longDuration = 1 * time.Second
+)
+
+func TestCancellableTimerFire(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.MakeCancellableTimer(&lock, func() {
+ ch <- struct{}{}
+ })
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromLongDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ timer.StopLocked()
+ lock.Unlock()
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerImmediatelyStop(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ for i := 0; i < 1000; i++ {
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ for i := 0; i < 10; i++ {
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration * 2)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for double the duration so timers that weren't correctly stopped can
+ // fire.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration * 2):
+ }
+}
+
+func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration)
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 9c40931b5..c7ce74cdd 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -350,13 +350,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return nil
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0010b5e5f..07ffa8aba 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -247,17 +247,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+ return tcpip.ErrUnknownProtocolOption
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -265,6 +265,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrNotSupported
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
ep.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 5aafe2615..85f7eb76b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -509,13 +509,38 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -544,21 +569,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
e.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index dfaa4a559..4f361b226 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) {
// Make sure we get the same error when calling the original ep and the
// new one. This validates that v4-mapped endpoints are still able to
// query the V6Only flag, whereas pure v4 endpoints are not.
- var v tcpip.V6OnlyOption
- expected := c.EP.GetSockOpt(&v)
- if err := nep.GetSockOpt(&v); err != expected {
+ _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
}
@@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) {
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
- var v tcpip.V6OnlyOption
- if err := nep.GetSockOpt(&v); err != nil {
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
t.Fatalf("GetSockOpt failed failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 8ff125855..830bc1e3e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -885,8 +885,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = rcvWnd
- e.notifyProtocolGoroutine(notifyReceiveWindowChanged)
+ availAfter := e.receiveBufferAvailableLocked()
+ mask := uint32(notifyReceiveWindowChanged)
+ if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -956,11 +962,11 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read
- // freed up enough buffer space, to either fit an aMSS or half
- // a receive buffer (whichever smaller), then notify the
- // protocol goroutine to send a window update.
- if e.windowCrossedACKThreshold(len(v)) == 1 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1134,17 +1140,20 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// windowCrossedACKThreshold checks if the receive window to be announced now would be
-// under aMSS or under half receive buffer, whichever smaller. This is useful as
-// a receive side silly window syndrome prevention mechanism. If window grows
-// to reasonable value, we should send ACK to the sender to inform the rx space is now
-// large. We also want ensure a series of small read()'s won't trigger a flood of
-// spurious tiny ACK's.
+// windowCrossedACKThreshold checks if the receive window to be announced now
+// would be under aMSS or under half receive buffer, whichever smaller. This is
+// useful as a receive side silly window syndrome prevention mechanism. If
+// window grows to reasonable value, we should send ACK to the sender to inform
+// the rx space is now large. We also want to ensure a series of small read()'s
+// won't trigger a flood of spurious tiny ACK's.
//
-// For large receive buffers, the threshold is aMSS - once reader reads more than aMSS
-// we'll send ACK. For tiny receive buffers, the threshold is half of receive buffer size.
-// This is chosen arbitrairly.
-func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) int {
+// For large receive buffers, the threshold is aMSS - once reader reads more
+// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
+// receive buffer size. This is chosen arbitrarily.
+// crossed will be true if the window size crossed the ACK threshold.
+// above will be true if the new window is >= ACK threshold and false
+// otherwise.
+func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@@ -1158,15 +1167,38 @@ func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) int {
switch {
case oldAvail < threshold && newAvail >= threshold:
- return 1
+ return true, true
case oldAvail >= threshold && newAvail < threshold:
- return -1
+ return true, false
}
- return 0
+ return false, false
+}
+
+// SetSockOptBool sets a socket option.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v
+ }
+
+ return nil
}
// SetSockOptInt sets a socket option.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
@@ -1207,11 +1239,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
e.rcvAutoParams.disabled = true
- // Immediatelly send an ACK to uncork the sender silly
- // window syndrome prevetion, when our available space
- // grows above aMSS or half receive buffer, whichever
- // smaller.
- if e.windowCrossedACKThreshold(availAfter-availBefore) == 1 {
+ // Immediately send an ACK to uncork the sender silly window
+ // syndrome prevention, when our available space grows above aMSS
+ // or half receive buffer, whichever smaller.
+ if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
@@ -1283,19 +1314,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.QuickAckOption:
if v == 0 {
@@ -1316,23 +1342,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v != 0
- return nil
-
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -1473,8 +1482,27 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1552,12 +1580,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = ""
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.QuickAckOption:
@@ -1567,22 +1591,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -2252,10 +2260,9 @@ func (e *endpoint) readyToRead(s *segment) {
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
- // Increase counter if the receive window falls down
- // below MSS or half receive buffer size, whichever
- // smaller.
- if e.windowCrossedACKThreshold(-s.data.Size()) == -1 {
+ // Increase counter if the receive window falls down below MSS
+ // or half receive buffer size, whichever smaller.
+ if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 4c2e458e3..6edfa8dce 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) {
func TestConnectBindToDevice(t *testing.T) {
for _, test := range []struct {
name string
- device string
+ device tcpip.NICID
want tcp.EndpointState
}{
- {"RightDevice", "nic1", tcp.StateEstablished},
- {"WrongDevice", "nic2", tcp.StateSynSent},
- {"AnyDevice", "", tcp.StateEstablished},
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
} {
t.Run(test.name, func(t *testing.T) {
c := context.New(t, defaultMTU)
@@ -3798,46 +3798,41 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
t.Errorf("CreateNIC failed: %v", err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -4031,12 +4026,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err)
}
case "dual":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err)
}
default:
t.Fatalf("unknown network: '%s'", network)
@@ -6614,16 +6609,14 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
t.Fatalf("expected small, non-zero window: %d", lastWnd)
}
- // We now have < 1 MSS in the buffer space. Read the data! An
+ // We now have < 1 MSS in the buffer space. Read the data! An
// ack should be sent in response to that. The window was not
// zero, but it grew to larger than MSS.
- _, _, err := c.EP.Read(nil)
- if err != nil {
+ if _, _, err := c.EP.Read(nil); err != nil {
t.Fatalf("Read failed: %v", err)
}
- _, _, err = c.EP.Read(nil)
- if err != nil {
+ if _, _, err := c.EP.Read(nil); err != nil {
t.Fatalf("Read failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index b0a376eba..822907998 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -158,15 +158,17 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts2 := stack.NICOptions{Name: "nic2"}
+ if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
@@ -473,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.EP.SetSockOpt(v); err != nil {
+ if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1ac4705af..864dc8733 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -456,14 +456,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
- return nil
-}
-
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -478,8 +473,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.v6only = v != 0
+ e.v6only = v
+ }
+
+ return nil
+}
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -624,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.BroadcastOption:
e.mu.Lock()
@@ -660,8 +662,27 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -695,22 +716,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -757,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 7051a7a9c..0a82bc4fa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
} else if flow.isBroadcast() {
@@ -508,46 +508,42 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "my_device"}
+ if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index dd4926bb9..6a8765ec8 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -126,7 +126,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
linkEP := loopback.New()
log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
}
@@ -173,7 +173,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
}
@@ -218,15 +218,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error {
- if loopback {
- if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil {
- return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, ep, err)
- }
- } else {
- if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil {
- return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, ep, err)
- }
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error {
+ opts := stack.NICOptions{Name: name}
+ if err := n.Stack.CreateNICWithOptions(id, sniffer.New(ep), opts); err != nil {
+ return fmt.Errorf("CreateNICWithOptions(%d, _, %+v) failed: %v", id, opts, err)
}
// Always start with an arp address for the NIC.
diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh
new file mode 100755
index 000000000..bac9b9192
--- /dev/null
+++ b/scripts/issue_reviver.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+DIR=$(dirname $0)
+source "${DIR}"/common.sh
+
+# Provide a credential file if available.
+export OAUTH_TOKEN_FILE=""
+if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}"
+fi
+
+REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd)
+run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}"
diff --git a/test/iptables/README.md b/test/iptables/README.md
index b37cb2a96..9f8e34420 100644
--- a/test/iptables/README.md
+++ b/test/iptables/README.md
@@ -1,6 +1,6 @@
# iptables Tests
-iptables tests are run via `scripts/iptables\_test.sh`.
+iptables tests are run via `scripts/iptables_test.sh`.
## Test Structure
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 064ce8429..ce8abe217 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2888,7 +2888,6 @@ cc_library(
":unix_domain_socket_test_util",
"//test/util:test_util",
"//test/util:thread_util",
- "//test/util:timer_util",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index 7384c27dc..59ec9940a 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -1591,6 +1591,34 @@ TEST(Inotify, EpollNoDeadlock) {
}
}
+TEST(Inotify, SpliceEvent) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int watcher = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr,
+ sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK),
+ SyscallSucceedsWithValue(sizeof(struct inotify_event)));
+
+ const FileDescriptor read_fd(pipes[0]);
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)}));
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index 8398fc95f..6b472eb2f 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -187,24 +187,24 @@ PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) {
return InterfaceIndex(name);
}
-std::string GetAddr4Str(in_addr* a) {
+std::string GetAddr4Str(const in_addr* a) {
char str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, a, str, sizeof(str));
return std::string(str);
}
-std::string GetAddr6Str(in6_addr* a) {
+std::string GetAddr6Str(const in6_addr* a) {
char str[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, a, str, sizeof(str));
return std::string(str);
}
-std::string GetAddrStr(sockaddr* a) {
+std::string GetAddrStr(const sockaddr* a) {
if (a->sa_family == AF_INET) {
- auto src = &(reinterpret_cast<sockaddr_in*>(a)->sin_addr);
+ auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr);
return GetAddr4Str(src);
} else if (a->sa_family == AF_INET6) {
- auto src = &(reinterpret_cast<sockaddr_in6*>(a)->sin6_addr);
+ auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr);
return GetAddr6Str(src);
}
return std::string("<invalid>");
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 9cb4566db..0f58e0f77 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -105,14 +105,14 @@ class IfAddrHelper {
};
// GetAddr4Str returns the given IPv4 network address structure as a string.
-std::string GetAddr4Str(in_addr* a);
+std::string GetAddr4Str(const in_addr* a);
// GetAddr6Str returns the given IPv6 network address structure as a string.
-std::string GetAddr6Str(in6_addr* a);
+std::string GetAddr6Str(const in6_addr* a);
// GetAddrStr returns the given IPv4 or IPv6 network address structure as a
// string.
-std::string GetAddrStr(sockaddr* a);
+std::string GetAddrStr(const sockaddr* a);
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
index 33822ee57..df7129acc 100644
--- a/test/syscalls/linux/partial_bad_buffer.cc
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -18,7 +18,9 @@
#include <netinet/tcp.h>
#include <sys/mman.h>
#include <sys/socket.h>
+#include <sys/stat.h>
#include <sys/syscall.h>
+#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>
@@ -62,9 +64,9 @@ class PartialBadBufferTest : public ::testing::Test {
// Write some initial data.
size_t size = sizeof(kMessage) - 1;
EXPECT_THAT(WriteFd(fd_, &kMessage, size), SyscallSucceedsWithValue(size));
-
ASSERT_THAT(lseek(fd_, 0, SEEK_SET), SyscallSucceeds());
+ // Map a useable buffer.
addr_ = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
ASSERT_NE(addr_, MAP_FAILED);
@@ -79,6 +81,15 @@ class PartialBadBufferTest : public ::testing::Test {
bad_buffer_ = buf + kPageSize - 1;
}
+ off_t Size() {
+ struct stat st;
+ int rc = fstat(fd_, &st);
+ if (rc < 0) {
+ return static_cast<off_t>(rc);
+ }
+ return st.st_size;
+ }
+
void TearDown() override {
EXPECT_THAT(munmap(addr_, 2 * kPageSize), SyscallSucceeds()) << addr_;
EXPECT_THAT(close(fd_), SyscallSucceeds());
@@ -165,97 +176,99 @@ TEST_F(PartialBadBufferTest, PreadvSmall) {
}
TEST_F(PartialBadBufferTest, WriteBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, kPageSize),
- SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, kPageSize)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WriteSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10),
- SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, 10)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwriteBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwriteSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, 10, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, 10, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WritevBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WritevSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwritevBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwritevSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
// getdents returns EFAULT when the you claim the buffer is large enough, but
@@ -283,29 +296,6 @@ TEST_F(PartialBadBufferTest, GetdentsOneEntry) {
SyscallSucceedsWithValue(Gt(0)));
}
-// Verify that when write returns EFAULT the kernel hasn't silently written
-// the initial valid bytes.
-TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
- bad_buffer_[0] = 'A';
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10),
- SyscallFailsWithErrno(EFAULT));
-
- size_t size = 255;
- char buf[255];
- memset(buf, 0, size);
-
- EXPECT_THAT(RetryEINTR(pread)(fd_, buf, size, 0),
- SyscallSucceedsWithValue(sizeof(kMessage) - 1));
-
- // 'A' has not been written.
- EXPECT_STREQ(buf, kMessage);
-}
-
PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
struct sockaddr_storage addr;
memset(&addr, 0, sizeof(addr));
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index 5767181a1..5ed57625c 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -183,7 +183,14 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -198,15 +205,29 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
for (int i = 0; i < kConnectAttempts; i++) {
- FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack where read is incorrectly assuming a whole
+ // segment is read when endpoint.Read() is called which is technically
+ // incorrect as the syscall that invoked endpoint.Read() may only
+ // consume it partially. This results in a case where a close() of
+ // such a socket does not trigger a RST in netstack due to the
+ // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
// Join threads to be sure that all connections have been counted.
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 619d41901..138024d9e 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -714,7 +714,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
sockaddr_storage listen_addr = listener.addr;
sockaddr_storage conn_addr = connector.addr;
constexpr int kThreadCount = 3;
- constexpr int kConnectAttempts = 4096;
+ constexpr int kConnectAttempts = 10000;
// Create the listening socket.
FileDescriptor listener_fds[kThreadCount];
@@ -729,7 +729,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
ASSERT_THAT(
bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(fd, kConnectAttempts / 3), SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
// On the first bind we need to determine which port was bound.
if (i != 0) {
@@ -772,7 +772,14 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -795,8 +802,22 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack where read is incorrectly assuming a whole
+ // segment is read when endpoint.Read() is called which is technically
+ // incorrect as the syscall that invoked endpoint.Read() may only
+ // consume it partially. This results in a case where a close() of
+ // such a socket does not trigger a RST in netstack due to the
+ // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
});
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
index b6754111f..ca597e267 100644
--- a/test/syscalls/linux/socket_ip_unbound.cc
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -129,6 +129,7 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) {
struct TOSOption {
int level;
int option;
+ int cmsg_level;
};
constexpr int INET_ECN_MASK = 3;
@@ -139,10 +140,12 @@ static TOSOption GetTOSOption(int domain) {
case AF_INET:
opt.level = IPPROTO_IP;
opt.option = IP_TOS;
+ opt.cmsg_level = SOL_IP;
break;
case AF_INET6:
opt.level = IPPROTO_IPV6;
opt.option = IPV6_TCLASS;
+ opt.cmsg_level = SOL_IPV6;
break;
}
return opt;
@@ -386,6 +389,36 @@ TEST_P(IPUnboundSocketTest, NullTOS) {
SyscallFailsWithErrno(EFAULT));
}
+TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) {
+ SKIP_IF(GetParam().protocol == IPPROTO_TCP);
+
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+
+ in_addr addr4;
+ in6_addr addr6;
+ ASSERT_THAT(inet_pton(AF_INET, "127.0.0.1", &addr4), ::testing::Eq(1));
+ ASSERT_THAT(inet_pton(AF_INET6, "fe80::", &addr6), ::testing::Eq(1));
+
+ cmsghdr cmsg = {};
+ cmsg.cmsg_len = sizeof(cmsg);
+ cmsg.cmsg_level = t.cmsg_level;
+ cmsg.cmsg_type = t.option;
+
+ msghdr msg = {};
+ msg.msg_control = &cmsg;
+ msg.msg_controllen = sizeof(cmsg);
+ if (GetParam().domain == AF_INET) {
+ msg.msg_name = &addr4;
+ msg.msg_namelen = sizeof(addr4);
+ } else {
+ msg.msg_name = &addr6;
+ msg.msg_namelen = sizeof(addr6);
+ }
+
+ EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL));
+}
+
INSTANTIATE_TEST_SUITE_P(
IPUnboundSockets, IPUnboundSocketTest,
::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc
index d91c5ed39..c61817f14 100644
--- a/test/syscalls/linux/socket_non_stream.cc
+++ b/test/syscalls/linux/socket_non_stream.cc
@@ -113,7 +113,7 @@ TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were updated.
- EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
}
// Stream sockets allow data sent with multiple sends to be peeked at in a
@@ -193,7 +193,7 @@ TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were updated.
- EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
}
TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) {
@@ -224,5 +224,114 @@ TEST_P(NonStreamSocketPairTest, MsgTruncNotFull) {
EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
}
+// This test tests reading from a socket with MSG_TRUNC and a zero length
+// receive buffer. The user should be able to get the message length.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags.
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+}
+
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer. The user should be able to get the message length
+// without reading data off the socket.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncMsgPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags because the receive buffer is
+ // smaller than the message size.
+ EXPECT_EQ(peek_msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+
+ char received_data[sizeof(sent_data)] = {};
+
+ struct iovec received_iov;
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = sizeof(received_data);
+ struct msghdr received_msg = {};
+ received_msg.msg_flags = -1;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+
+ // Next we can read the actual data.
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &received_msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is not set on msghdr flags because we read the whole
+ // message.
+ EXPECT_EQ(received_msg.msg_flags & MSG_TRUNC, 0);
+}
+
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer and MSG_DONTWAIT. The user should be able to get an
+// EAGAIN or EWOULDBLOCK error response.
+TEST_P(NonStreamSocketPairTest, RecvmsgTruncPeekDontwaitZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't send any data on the socket.
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // recvmsg fails with EAGAIN because no data is available on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc
index 62d87c1af..b052f6e61 100644
--- a/test/syscalls/linux/socket_non_stream_blocking.cc
+++ b/test/syscalls/linux/socket_non_stream_blocking.cc
@@ -25,6 +25,7 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -44,5 +45,41 @@ TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
SyscallSucceedsWithValue(sizeof(sent_data)));
}
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer, without MSG_DONTWAIT. The recvmsg call should block
+// until data is sent on the socket.
+TEST_P(BlockingNonStreamSocketPairTest,
+ RecvmsgTruncPeekDontwaitZeroLenBlocking) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't initially send any data on the socket.
+ const int data_size = 10;
+ char sent_data[data_size];
+ RandomizeBuffer(sent_data, data_size);
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ ScopedThread t([&]() {
+ // The syscall succeeds returning the full size of the message on the
+ // socket. This should block until there is data on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(data_size));
+ });
+
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, data_size, 0),
+ SyscallSucceedsWithValue(data_size));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc
index 346443f96..6522b2e01 100644
--- a/test/syscalls/linux/socket_stream.cc
+++ b/test/syscalls/linux/socket_stream.cc
@@ -104,7 +104,60 @@ TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were cleared (MSG_TRUNC was not set).
- EXPECT_EQ(msg.msg_flags, 0);
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
}
TEST_P(StreamSocketPairTest, MsgTrunc) {
diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD
new file mode 100644
index 000000000..ee7ea11fd
--- /dev/null
+++ b/tools/issue_reviver/BUILD
@@ -0,0 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "issue_reviver",
+ srcs = ["main.go"],
+ deps = [
+ "//tools/issue_reviver/github",
+ "//tools/issue_reviver/reviver",
+ ],
+)
diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD
new file mode 100644
index 000000000..6da22ba1c
--- /dev/null
+++ b/tools/issue_reviver/github/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "github",
+ srcs = ["github.go"],
+ importpath = "gvisor.dev/gvisor/tools/issue_reviver/github",
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+ deps = [
+ "//tools/issue_reviver/reviver",
+ "@com_github_google_go-github//github:go_default_library",
+ "@org_golang_x_oauth2//:go_default_library",
+ ],
+)
diff --git a/tools/issue_reviver/github/github.go b/tools/issue_reviver/github/github.go
new file mode 100644
index 000000000..e07949c8f
--- /dev/null
+++ b/tools/issue_reviver/github/github.go
@@ -0,0 +1,164 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package github implements reviver.Bugger interface on top of Github issues.
+package github
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/go-github/github"
+ "golang.org/x/oauth2"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+// Bugger implements reviver.Bugger interface for github issues.
+type Bugger struct {
+ // owner and repo identify the GitHub repository whose issues are used.
+ owner string
+ repo string
+ // dryRun, when true, makes Activate print the intended change instead of
+ // modifying the issue.
+ dryRun bool
+
+ // client is the GitHub API client created by load.
+ client *github.Client
+ // issues caches the open issues fetched by load, keyed by issue number.
+ issues map[int]*github.Issue
+}
+
+// NewBugger builds a Bugger for the given owner/repo and immediately loads the
+// repository's open issues. An empty token selects unauthenticated access.
+// When dryRun is set, Activate only reports what it would change.
+func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) {
+	bugger := new(Bugger)
+	bugger.owner = owner
+	bugger.repo = repo
+	bugger.dryRun = dryRun
+	bugger.issues = make(map[int]*github.Issue)
+
+	err := bugger.load(token)
+	if err != nil {
+		return nil, err
+	}
+	return bugger, nil
+}
+
+// load creates the GitHub client (OAuth2-authenticated when token is
+// non-empty, anonymous otherwise) and fills b.issues with every open issue of
+// the repository, following pagination through processAllPages.
+func (b *Bugger) load(token string) error {
+	ctx := context.Background()
+
+	if len(token) != 0 {
+		source := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
+		b.client = github.NewClient(oauth2.NewClient(ctx, source))
+	} else {
+		fmt.Print("No OAUTH token provided, using unauthenticated account.\n")
+		b.client = github.NewClient(nil)
+	}
+
+	// fetchPage retrieves one page of open issues and records them in the
+	// cache.
+	fetchPage := func(page github.ListOptions) (*github.Response, error) {
+		query := &github.IssueListByRepoOptions{State: "open", ListOptions: page}
+		found, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, query)
+		if err != nil {
+			return resp, err
+		}
+		for _, issue := range found {
+			b.issues[issue.GetNumber()] = issue
+		}
+		return resp, nil
+	}
+	if err := processAllPages(fetchPage); err != nil {
+		return err
+	}
+
+	fmt.Printf("Loaded %d issues from github.com/%s/%s\n", len(b.issues), b.owner, b.repo)
+	return nil
+}
+
+// Activate implements reviver.Bugger.
+//
+// It handles only TODOs whose issue reference starts with
+// "gvisor.dev/issue/"; for anything else it returns (false, nil) so that
+// another Bugger may take it. For a handled TODO it returns true, together
+// with an error if the issue number cannot be parsed or a GitHub call fails.
+//
+// If the issue is not in the cache of open issues, it is reopened and a
+// comment listing every TODO location is added. In dry-run mode the comment
+// is printed instead and nothing is changed.
+func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) {
+	const prefix = "gvisor.dev/issue/"
+
+	// First check if I can handle the TODO. TrimPrefix returns its input
+	// unchanged when the prefix is absent, so equal lengths mean "not ours".
+	idStr := strings.TrimPrefix(todo.Issue, prefix)
+	if len(todo.Issue) == len(idStr) {
+		return false, nil
+	}
+
+	id, err := strconv.Atoi(idStr)
+	if err != nil {
+		return true, err
+	}
+
+	// Check against active issues cache.
+	if _, ok := b.issues[id]; ok {
+		fmt.Printf("%q is active: OK\n", todo.Issue)
+		return true, nil
+	}
+
+	fmt.Printf("%q is not active: reopening issue %d\n", todo.Issue, id)
+
+	// Format comment with TODO locations and search link.
+	var comment strings.Builder
+	fmt.Fprintln(&comment, "There are TODOs still referencing this issue:")
+	for _, l := range todo.Locations {
+		// GitHub line anchors have the form "#L<line>"; a bare "#<line>"
+		// does not jump to the line.
+		fmt.Fprintf(&comment,
+			"1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#L%d): %s\n",
+			l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment)
+	}
+	fmt.Fprintf(&comment,
+		"\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%d%%22)", b.owner, b.repo, prefix, id)
+
+	if b.dryRun {
+		fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String())
+		return true, nil
+	}
+
+	ctx := context.Background()
+	req := &github.IssueRequest{State: github.String("open")}
+	_, _, err = b.client.Issues.Edit(ctx, b.owner, b.repo, id, req)
+	if err != nil {
+		return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err)
+	}
+
+	cmt := &github.IssueComment{
+		Body:      github.String(comment.String()),
+		Reactions: &github.Reactions{Confused: github.Int(1)},
+	}
+	if _, _, err := b.client.Issues.CreateComment(ctx, b.owner, b.repo, id, cmt); err != nil {
+		return true, fmt.Errorf("failed to add comment to issue %d: %v", id, err)
+	}
+
+	return true, nil
+}
+
+// processAllPages invokes fn once per result page, moving on to the page
+// named by the previous response until no next page is reported. When fn
+// fails with a *github.RateLimitError the same page is retried after sleeping
+// until the limit resets, unless that wait exceeds five minutes. Any other
+// error stops the iteration and is returned.
+func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error {
+	const maxWait = 5 * time.Minute
+
+	page := github.ListOptions{PerPage: 1000}
+	for {
+		resp, err := fn(page)
+		switch e := err.(type) {
+		case nil:
+			if resp.NextPage == 0 {
+				return nil
+			}
+			page.Page = resp.NextPage
+
+		case *github.RateLimitError:
+			wait := e.Rate.Reset.Sub(time.Now())
+			if wait > maxWait {
+				return fmt.Errorf("Rate limited for too long: %v", wait)
+			}
+			fmt.Printf("Rate limited, sleeping for: %v\n", wait)
+			time.Sleep(wait)
+
+		default:
+			return err
+		}
+	}
+}
diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go
new file mode 100644
index 000000000..4256f5a6c
--- /dev/null
+++ b/tools/issue_reviver/main.go
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package main is the entry point for issue_reviver.
+package main
+
+import (
+	"flag"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"strings"
+
+	"gvisor.dev/gvisor/tools/issue_reviver/github"
+	"gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+// Command-line options. They are bound to flags in init and read by main.
+var (
+ owner string
+ repo string
+ tokenFile string
+ path string
+ dryRun bool
+)
+
+// init registers the command-line flags and their defaults.
+//
+// Keep the options simple for now. Supports only a single path and repo.
+func init() {
+ flag.StringVar(&owner, "owner", "google", "Github project org/owner to look for issues")
+ flag.StringVar(&repo, "repo", "gvisor", "Github repo to look for issues")
+ flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github")
+ flag.StringVar(&path, "path", "", "Path to scan for TODOs")
+ flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues")
+}
+
+// main parses the flags, reads the optional OAuth token, loads the open
+// GitHub issues and runs the reviver over --path. It exits with status 1 on a
+// missing mandatory option or when any error is encountered. All diagnostics
+// are written to stderr.
+func main() {
+	flag.Parse()
+
+	// Check for mandatory parameters.
+	mandatory := []struct{ name, value string }{
+		{"owner", owner},
+		{"repo", repo},
+		{"path", path},
+	}
+	for _, opt := range mandatory {
+		if len(opt.value) == 0 {
+			fmt.Fprintf(os.Stderr, "missing --%s option.\n", opt.name)
+			flag.Usage()
+			os.Exit(1)
+		}
+	}
+
+	// Token is passed as a file so it doesn't show up in command line arguments.
+	var token string
+	if len(tokenFile) != 0 {
+		data, err := ioutil.ReadFile(tokenFile)
+		if err != nil {
+			fmt.Fprintln(os.Stderr, "Error reading OAUTH token file:", err)
+			os.Exit(1)
+		}
+		// Token files normally end with a newline. Surrounding whitespace is
+		// not part of the credential and would produce an invalid
+		// Authorization header, so strip it.
+		token = strings.TrimSpace(string(data))
+	}
+
+	bugger, err := github.NewBugger(token, owner, repo, dryRun)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "Error getting github issues:", err)
+		os.Exit(1)
+	}
+	rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
+	if errs := rev.Run(); len(errs) > 0 {
+		fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
+		for _, err := range errs {
+			fmt.Fprintf(os.Stderr, "\t%v\n", err)
+		}
+		os.Exit(1)
+	}
+}
diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD
new file mode 100644
index 000000000..2c3675977
--- /dev/null
+++ b/tools/issue_reviver/reviver/BUILD
@@ -0,0 +1,19 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "reviver",
+ srcs = ["reviver.go"],
+ importpath = "gvisor.dev/gvisor/tools/issue_reviver/reviver",
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+)
+
+go_test(
+ name = "reviver_test",
+ size = "small",
+ srcs = ["reviver_test.go"],
+ embed = [":reviver"],
+)
diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/issue_reviver/reviver/reviver.go
new file mode 100644
index 000000000..682db0c01
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver.go
@@ -0,0 +1,192 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package reviver scans the code looking for TODOs and passes them to
+// registered Buggers to ensure TODOs point to active issues.
+package reviver
+
+import (
+ "bufio"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "regexp"
+ "sync"
+)
+
+// regexTodo describes what a TODO looks like: a "//" or "#" comment marker,
+// optional whitespace, TODO or FIXME, an issue reference in parentheses, a
+// colon and a non-empty comment. Submatch 1 is the comment marker, 2 the
+// TODO/FIXME keyword, 3 the issue reference and 4 the comment text.
+var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`)
+
+// Bugger interface is called for every TODO found in the code. If it can handle
+// the TODO, it must return true. If it returns false, the next Bugger is
+// called. If no Bugger handles the TODO, it's dropped on the floor.
+type Bugger interface {
+ // Activate is given one unique Todo (all locations for one issue) and
+ // reports whether this Bugger handled it.
+ Activate(todo *Todo) (bool, error)
+}
+
+// Location saves the location where the TODO was found.
+type Location struct {
+ // Comment is the text that follows "TODO(issue):" on the line.
+ Comment string
+ // File is the path of the file containing the TODO.
+ File string
+ // Line is the line number of the TODO within File.
+ Line uint
+}
+
+// Todo represents a unique TODO. There can be several TODOs pointing to the
+// same issue in the code. They are all grouped together.
+type Todo struct {
+ // Issue is the issue reference found between the parentheses.
+ Issue string
+ // Locations lists every place in the code that references Issue.
+ Locations []Location
+}
+
+// Reviver scans the given paths for TODOs and calls Buggers to handle them.
+type Reviver struct {
+ // paths are the root directories to scan.
+ paths []string
+ // buggers are tried in order for every TODO found.
+ buggers []Bugger
+
+ // mu protects todos and errs, which are updated from the scanning
+ // goroutines.
+ mu sync.Mutex
+ // todos groups the TODOs found so far, keyed by issue reference.
+ todos map[string]*Todo
+ // errs accumulates the errors encountered while processing.
+ errs []error
+}
+
+// New creates a Reviver that scans paths and offers every TODO it finds to
+// buggers.
+func New(paths []string, buggers []Bugger) *Reviver {
+	r := new(Reviver)
+	r.paths = paths
+	r.buggers = buggers
+	r.todos = make(map[string]*Todo)
+	return r
+}
+
+// Run scans all configured paths for TODOs and then hands each unique TODO to
+// the Buggers. It does not stop on errors; every error found along the way is
+// collected and returned.
+func (r *Reviver) Run() []error {
+	// Scan each root concurrently. processPath registers the goroutines it
+	// spawns for sub-directories on the same WaitGroup.
+	var wg sync.WaitGroup
+	for _, root := range r.paths {
+		wg.Add(1)
+		go func(dir string) {
+			defer wg.Done()
+			r.processPath(dir, &wg)
+		}(root)
+	}
+	wg.Wait()
+
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	total := len(r.todos)
+	fmt.Printf("Processing %d TODOs (%d errors)...\n", total, len(r.errs))
+	skipped := 0
+	for _, todo := range r.todos {
+		handled, err := r.processTodo(todo)
+		if err != nil {
+			r.errs = append(r.errs, err)
+		}
+		if !handled {
+			skipped++
+		}
+	}
+	fmt.Printf("Processed %d TODOs, %d were skipped (%d errors)\n", total-skipped, skipped, len(r.errs))
+
+	return r.errs
+}
+
+// processPath scans the directory at path. Each sub-directory is processed in
+// its own goroutine registered on wg; each regular file is scanned for TODOs
+// in the calling goroutine. Errors are recorded with addErr and do not stop
+// the scan of the remaining entries.
+func (r *Reviver) processPath(path string, wg *sync.WaitGroup) {
+	fmt.Printf("Processing dir %q\n", path)
+	fis, err := ioutil.ReadDir(path)
+	if err != nil {
+		r.addErr(fmt.Errorf("error processing dir %q: %v", path, err))
+		return
+	}
+
+	for _, fi := range fis {
+		childPath := filepath.Join(path, fi.Name())
+		switch {
+		case fi.Mode().IsDir():
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				r.processPath(childPath, wg)
+			}()
+
+		case fi.Mode().IsRegular():
+			if err := r.processFile(childPath); err != nil {
+				r.addErr(err)
+			}
+		}
+	}
+}
+
+// processFile scans one regular file line by line and records every TODO it
+// finds. It returns the error from opening the file, if any. The file is
+// always closed before returning so that scanning a large tree does not
+// exhaust file descriptors.
+func (r *Reviver) processFile(path string) error {
+	file, err := os.Open(path)
+	if err != nil {
+		return err
+	}
+	defer file.Close()
+
+	// NOTE(review): scanner.Err() is not checked, so a file containing a line
+	// longer than the Scanner's token limit is silently cut short at that
+	// line. Confirm whether such files should be reported.
+	scanner := bufio.NewScanner(file)
+	lineno := uint(0)
+	for scanner.Scan() {
+		lineno++
+		if todo := r.processLine(scanner.Text(), path, lineno); todo != nil {
+			r.addTodo(todo)
+		}
+	}
+	return nil
+}
+
+// processLine returns a Todo with a single Location when line contains a
+// TODO/FIXME matching regexTodo, and nil otherwise. path and lineno are
+// stored in the Location as given.
+func (r *Reviver) processLine(line, path string, lineno uint) *Todo {
+	groups := regexTodo.FindStringSubmatch(line)
+	switch {
+	case groups == nil:
+		return nil
+	case len(groups) != 5:
+		// The whole match plus the four capture groups are expected.
+		panic(fmt.Sprintf("regex returned wrong matches for %q: %v", line, groups))
+	}
+
+	loc := Location{
+		Comment: groups[4],
+		File:    path,
+		Line:    lineno,
+	}
+	return &Todo{Issue: groups[3], Locations: []Location{loc}}
+}
+
+// addTodo records newTodo under its issue. If a Todo for the same issue is
+// already known, the new locations are appended to it instead.
+func (r *Reviver) addTodo(newTodo *Todo) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	existing := r.todos[newTodo.Issue]
+	if existing != nil {
+		existing.Locations = append(existing.Locations, newTodo.Locations...)
+		return
+	}
+	r.todos[newTodo.Issue] = newTodo
+}
+
+// addErr appends err to the list of errors returned by Run, holding r.mu
+// while doing so.
+func (r *Reviver) addErr(err error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.errs = append(r.errs, err)
+}
+
+// processTodo offers todo to each Bugger in order. It stops at the first
+// Bugger that fails (returning false and that error) or that handles the
+// TODO (returning true). It returns (false, nil) when no Bugger handles it.
+func (r *Reviver) processTodo(todo *Todo) (bool, error) {
+	for i := range r.buggers {
+		handled, err := r.buggers[i].Activate(todo)
+		switch {
+		case err != nil:
+			return false, err
+		case handled:
+			return true, nil
+		}
+	}
+	return false, nil
+}
diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/issue_reviver/reviver/reviver_test.go
new file mode 100644
index 000000000..a9fb1f9f1
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver_test.go
@@ -0,0 +1,88 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reviver
+
+import (
+ "testing"
+)
+
+ // TestProcessLine checks which lines processLine recognizes as TODOs and that
+ // the issue and comment are extracted from the ones it does. Cases that leave
+ // want unset must not produce a Todo.
+ func TestProcessLine(t *testing.T) {
+ 	for _, tc := range []struct {
+ 		line string
+ 		want *Todo
+ 	}{
+ 		{
+ 			line: "// TODO(foobar.com/issue/123): comment, bla. blabla.",
+ 			want: &Todo{
+ 				Issue: "foobar.com/issue/123",
+ 				Locations: []Location{
+ 					{Comment: "comment, bla. blabla."},
+ 				},
+ 			},
+ 		},
+ 		{
+ 			line: "// FIXME(b/123): internal bug",
+ 			want: &Todo{
+ 				Issue: "b/123",
+ 				Locations: []Location{
+ 					{Comment: "internal bug"},
+ 				},
+ 			},
+ 		},
+ 		{
+ 			// No leading comment marker.
+ 			line: "TODO(issue): not todo",
+ 		},
+ 		{
+ 			// No leading comment marker.
+ 			line: "FIXME(issue): not todo",
+ 		},
+ 		{
+ 			// Space between the keyword and the parenthesis.
+ 			line: "// TODO (issue): not todo",
+ 		},
+ 		{
+ 			// No colon after the issue.
+ 			line: "// TODO(issue) not todo",
+ 		},
+ 		{
+ 			// Lowercase keyword.
+ 			line: "// todo(issue): not todo",
+ 		},
+ 		{
+ 			// Nothing after the colon.
+ 			line: "// TODO(issue):",
+ 		},
+ 	} {
+ 		// The subtest runs synchronously, so capturing tc is safe.
+ 		t.Run(tc.line, func(t *testing.T) {
+ 			r := Reviver{}
+ 			got := r.processLine(tc.line, "test", 0)
+ 			if tc.want == nil {
+ 				if got != nil {
+ 					t.Errorf("expected no todo, got: %+v", got)
+ 				}
+ 				return
+ 			}
+ 			if got == nil {
+ 				t.Fatalf("failed to process line, want: %+v", tc.want)
+ 			}
+ 			if got.Issue != tc.want.Issue {
+ 				t.Errorf("wrong issue, got: %v, want: %v", got.Issue, tc.want.Issue)
+ 			}
+ 			// Stop here on a count mismatch: the loop below indexes
+ 			// got.Locations by position and would panic otherwise.
+ 			if len(got.Locations) != len(tc.want.Locations) {
+ 				t.Fatalf("wrong number of locations, got: %v, want: %v, locations: %+v", len(got.Locations), len(tc.want.Locations), got.Locations)
+ 			}
+ 			for i, wantLoc := range tc.want.Locations {
+ 				if got.Locations[i].Comment != wantLoc.Comment {
+ 					t.Errorf("wrong comment, got: %v, want: %v", got.Locations[i].Comment, wantLoc.Comment)
+ 				}
+ 			}
+ 		})
+ 	}
+ }