diff --git a/website/_sass/front.scss b/website/_sass/front.scss
index 44a7e3473..0e4208f3c 100644
--- a/website/_sass/front.scss
+++ b/website/_sass/front.scss
@@ -4,12 +4,14 @@
background-repeat: no-repeat;
background-size: cover;
background-blend-mode: darken;
- background-color: rgba(0, 0, 0, 0.1);
+ background-color: rgba(0, 0, 0, 0.3);
p {
color: #fff;
margin-top: 0;
margin-bottom: 0;
font-weight: 300;
+ font-size: 24px;
+ line-height: 30px;
}
}
diff --git a/website/_sass/style.scss b/website/_sass/style.scss
index 520ea469a..4deb945d4 100644
--- a/website/_sass/style.scss
+++ b/website/_sass/style.scss
@@ -142,3 +142,13 @@ table th {
margin-top: 10px;
margin-bottom: 20px;
}
+
+.docs-content * img {
+ display: block;
+ margin: 20px auto;
+}
+
+.blog-content * img {
+ display: block;
+ margin: 20px auto;
+}
diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md
index ed6d97ffe..fbdd511dd 100644
--- a/website/blog/2019-11-18-security-basics.md
+++ b/website/blog/2019-11-18-security-basics.md
@@ -56,15 +56,9 @@ in combination: redundant walls, scattered draw bridges, small bottle-neck
entrances, moats, etc.
A simplified version of the design is below
-([more detailed version](/docs/architecture_guide/))[^2]:
+([more detailed version](/docs/))[^2]:
---------------------------------------------------------------------------------
-
-![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png)
-
-Figure 1: Simplified design of gVisor.
-
---------------------------------------------------------------------------------
+![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png "Simplified design of gVisor.")
In order to discuss design principles, the following components are important to
know:
@@ -134,13 +128,7 @@ minimum level of permission is required for it to perform its function.
Specifically, the closer you are to the untrusted application, the less
privilege you have.
---------------------------------------------------------------------------------
-
-![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png)
-
-Figure 2: runsc components and their privileges.
-
---------------------------------------------------------------------------------
+![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png "runsc components and their privileges.")
This is evident in how runsc (the drop in gVisor binary for Docker/Kubernetes)
constructs the sandbox. The Sentry has the least privilege possible (it can't
@@ -222,15 +210,7 @@ the host Linux syscalls. In other words, with gVisor, applications get the vast
majority (and growing) functionality of Linux containers for only 68 possible
syscalls to the Host OS. 350 syscalls to 68 is attack surface reduction.
---------------------------------------------------------------------------------
-
-![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png)
-
-Figure 3: Reduction of Attack Surface of the Syscall Table. Note that the
-Senty's Syscall Emulation Layer keeps the Containerized Process from ever
-calling the Host OS.
-
---------------------------------------------------------------------------------
+![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png "Reduction of Attack Surface of the Syscall Table. Note that the Senty's Syscall Emulation Layer keeps the Containerized Process from ever calling the Host OS.")
## Secure-by-default
diff --git a/website/blog/2020-04-02-networking-security.md b/website/blog/2020-04-02-networking-security.md
index 78f0a6714..5a5e38fd7 100644
--- a/website/blog/2020-04-02-networking-security.md
+++ b/website/blog/2020-04-02-networking-security.md
@@ -69,13 +69,7 @@ a similar syscall). Moreover, because packets typically come from off-host (e.g.
the internet), the Host OS's packet processing code has received a lot of
scrutiny, hopefully resulting in a high degree of hardening.
---------------------------------------------------------------------------------
-
-![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png)
-
-Figure 1: Netstack and gVisor
-
---------------------------------------------------------------------------------
+![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png "Network and gVisor.")
## Writing a network stack
diff --git a/website/index.md b/website/index.md
index 95d5d16f0..84f877d49 100644
--- a/website/index.md
+++ b/website/index.md
@@ -3,10 +3,10 @@
-
gVisor is an application kernel and container runtime providing defense-in-depth for containers anywhere.
+
gVisor is an application kernel for containers that provides efficient defense-in-depth anywhere.
By providing each container with its own userspace kernel, gVisor limits
- the attack surface of the host. This protection does not limit
+
By providing each container with its own application kernel, gVisor
+ limits the attack surface of the host. This protection does not limit
functionality: gVisor runs unmodified binaries and integrates with container
orchestration systems, such as Docker and Kubernetes, and supports features
such as volumes and sidecars.
@@ -43,7 +43,7 @@
The pluggable platform architecture of gVisor allows it to run anywhere,
enabling consistent security policies across multiple environments without
having to rearchitect your infrastructure.
- Get Started »
+ Read More »
--
cgit v1.2.3
From 526df4f52a07a02687dd43ceb752621a41883f95 Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Fri, 5 Jun 2020 13:41:19 -0700
Subject: Fix error code returned due to Port exhaustion.
For TCP sockets gVisor incorrectly returns EAGAIN when no ephemeral ports are
available to bind during a connect. Linux returns EADDRNOTAVAIL. This change
fixes gVisor to return the correct code and adds a test for the same.
This change also fixes a minor bug for ping sockets where connect() would fail
with EINVAL unless the socket was bound first.
Also added tests for testing UDP Port exhaustion and Ping socket port
exhaustion.
PiperOrigin-RevId: 314988525
---
pkg/sentry/socket/netstack/BUILD | 1 +
pkg/sentry/socket/netstack/netstack.go | 9 ++
pkg/tcpip/transport/icmp/endpoint.go | 1 +
test/syscalls/BUILD | 16 ++
test/syscalls/linux/BUILD | 35 +++++
test/syscalls/linux/ping_socket.cc | 91 +++++++++++
test/syscalls/linux/poll.cc | 6 +-
.../linux/socket_inet_loopback_nogotsan.cc | 171 +++++++++++++++++++++
test/syscalls/linux/socket_ipv4_udp_unbound.cc | 33 ++++
9 files changed, 360 insertions(+), 3 deletions(-)
create mode 100644 test/syscalls/linux/ping_socket.cc
create mode 100644 test/syscalls/linux/socket_inet_loopback_nogotsan.cc
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 333e0042e..8f0f5466e 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -50,5 +50,6 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 60df51dae..e1e0c5931 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -33,6 +33,7 @@ import (
"syscall"
"time"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
@@ -719,6 +720,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
defer s.EventUnregister(&e)
if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // TCP unlike UDP returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
return syserr.TranslateNetstackError(err)
}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 3bc72bc19..57e0a069b 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -506,6 +506,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 3406a2de8..d68afbe44 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -400,6 +400,14 @@ syscall_test(
vfs2 = "True",
)
+syscall_test(
+ size = "medium",
+ # Takes too long under gotsan to run.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:ping_socket_test",
+ vfs2 = "True",
+)
+
syscall_test(
size = "large",
add_overlay = True,
@@ -697,6 +705,14 @@ syscall_test(
test = "//test/syscalls/linux:socket_inet_loopback_test",
)
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ # Takes too long for TSAN. Creates a lot of TCP sockets.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test",
+)
+
syscall_test(
size = "large",
shard_count = 50,
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index f4b5de18d..ae2aa44dc 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1411,6 +1411,21 @@ cc_binary(
],
)
+cc_binary(
+ name = "ping_socket_test",
+ testonly = 1,
+ srcs = ["ping_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_binary(
name = "pipe_test",
testonly = 1,
@@ -2780,6 +2795,26 @@ cc_binary(
],
)
+cc_binary(
+ name = "socket_inet_loopback_nogotsan_test",
+ testonly = 1,
+ srcs = ["socket_inet_loopback_nogotsan.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
cc_binary(
name = "socket_netlink_test",
testonly = 1,
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
new file mode 100644
index 000000000..a9bfdb37b
--- /dev/null
+++ b/test/syscalls/linux/ping_socket.cc
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class PingSocket : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+};
+
+void PingSocket::SetUp() {
+ // On some hosts ping sockets are restricted to specific groups using the
+ // sysctl "ping_group_range".
+ int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
+ if (s < 0 && errno == EPERM) {
+ GTEST_SKIP();
+ }
+ close(s);
+
+ addr_ = {};
+ // Just a random port as the destination port number is irrelevant for ping
+ // sockets.
+ addr_.sin_port = 12345;
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void PingSocket::TearDown() {}
+
+// Test ICMP port exhaustion returns EAGAIN.
+//
+// We disable both random/cooperative S/R for this test as it makes way too many
+// syscalls.
+TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) {
+ DisableSave ds;
+ std::vector sockets;
+ constexpr int kSockets = 65536;
+ addr_.sin_port = 0;
+ for (int i = 0; i < kSockets; i++) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+ int ret = connect(s.get(), reinterpret_cast(&addr_),
+ sizeof(addr_));
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index 1e35a4a8b..7a316427d 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -259,9 +259,9 @@ TEST_F(PollTest, Nfds) {
TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0);
// gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE.
- constexpr rlim_t gVisorMax = 1048576;
- if (rlim.rlim_cur > gVisorMax) {
- rlim.rlim_cur = gVisorMax;
+ constexpr rlim_t maxFD = 4096;
+ if (rlim.rlim_cur > maxFD) {
+ rlim.rlim_cur = maxFD;
TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0);
}
diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
new file mode 100644
index 000000000..2324c7f6a
--- /dev/null
+++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::Gt;
+
+PosixErrorOr AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast(
+ reinterpret_cast(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast(
+ reinterpret_cast(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+struct TestParam {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+std::string DescribeTestParam(::testing::TestParamInfo const& info) {
+ return absl::StrCat("Listen", info.param.listener.description, "_Connect",
+ info.param.connector.description);
+}
+
+using SocketInetLoopbackTest = ::testing::TestWithParam;
+
+// This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral
+// ports are already in use for a given destination ip/port.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(SocketInetLoopbackTest, TestTCPPortExhaustion_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 10;
+ constexpr int kClients = 65536;
+
+ // Create the listening socket.
+ auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+
+ // Now we keep opening connections till we run out of local ephemeral ports.
+ // and assert the error we get back.
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector clients;
+ std::vector servers;
+
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast(&conn_addr),
+ connector.addr_len);
+ if (ret == 0) {
+ clients.push_back(std::move(client));
+ FileDescriptor server =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ servers.push_back(std::move(server));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRNOTAVAIL));
+ break;
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, SocketInetLoopbackTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Any(), V4MappedAny()},
+ TestParam{V4Any(), V4MappedLoopback()},
+ TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+ TestParam{V4MappedAny(), V4Any()},
+ TestParam{V4MappedAny(), V4Loopback()},
+ TestParam{V4MappedAny(), V4MappedAny()},
+ TestParam{V4MappedAny(), V4MappedLoopback()},
+ TestParam{V4MappedLoopback(), V4Any()},
+ TestParam{V4MappedLoopback(), V4Loopback()},
+ TestParam{V4MappedLoopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()},
+ TestParam{V6Any(), V4MappedAny()},
+ TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()},
+ TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Any()},
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index bc4b07a62..1294d9050 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -2129,6 +2129,39 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
SyscallSucceedsWithValue(kMessageSize));
}
+// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) {
+ auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ constexpr int kClients = 65536;
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(receiver1->get(), reinterpret_cast(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(receiver1->get(),
+ reinterpret_cast(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+ std::vector> sockets;
+ for (int i = 0; i < kClients; i++) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int ret = connect(s->get(), reinterpret_cast(&addr.addr),
+ addr.addr_len);
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
// Test that socket will receive packet info control message.
TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
// TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
--
cgit v1.2.3
From 4b9652d63b319414e764696f1b77ee39cd36d96d Mon Sep 17 00:00:00 2001
From: Nayana Bidari
Date: Wed, 10 Jun 2020 13:36:02 -0700
Subject: {S,G}etsockopt for TCP_KEEPCNT option.
TCP_KEEPCNT is used to set the maximum keepalive probes to be
sent before dropping the connection.
WANT_LGTM=jchacon
PiperOrigin-RevId: 315758094
---
pkg/abi/linux/tcp.go | 1 +
pkg/sentry/socket/netstack/netstack.go | 35 ++++++++-----
test/syscalls/linux/socket_ip_tcp_generic.cc | 73 ++++++++++++++++++++++++++++
3 files changed, 98 insertions(+), 11 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
index 174d470e2..2a8d4708b 100644
--- a/pkg/abi/linux/tcp.go
+++ b/pkg/abi/linux/tcp.go
@@ -57,4 +57,5 @@ const (
const (
MAX_TCP_KEEPIDLE = 32767
MAX_TCP_KEEPINTVL = 32767
+ MAX_TCP_KEEPCNT = 127
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e1e0c5931..738277391 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1246,6 +1246,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_KEEPCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1786,6 +1798,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_KEEPCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPCNT {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v)))
+
case linux.TCP_USER_TIMEOUT:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2115,30 +2138,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
switch name {
case linux.TCP_CONGESTION,
linux.TCP_CORK,
- linux.TCP_DEFER_ACCEPT,
linux.TCP_FASTOPEN,
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_KEEPCNT,
- linux.TCP_KEEPIDLE,
- linux.TCP_KEEPINTVL,
- linux.TCP_LINGER2,
- linux.TCP_MAXSEG,
linux.TCP_QUEUE_SEQ,
- linux.TCP_QUICKACK,
linux.TCP_REPAIR,
linux.TCP_REPAIR_QUEUE,
linux.TCP_REPAIR_WINDOW,
linux.TCP_SAVED_SYN,
linux.TCP_SAVE_SYN,
- linux.TCP_SYNCNT,
linux.TCP_THIN_DUPACK,
linux.TCP_THIN_LINEAR_TIMEOUTS,
linux.TCP_TIMESTAMP,
- linux.TCP_ULP,
- linux.TCP_USER_TIMEOUT,
- linux.TCP_WINDOW_CLAMP:
+ linux.TCP_ULP:
t.Kernel().EmitUnimplementedEvent(t)
}
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index fa81845fd..15adc8d0e 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -524,6 +524,7 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) {
// Copied from include/net/tcp.h.
constexpr int MAX_TCP_KEEPIDLE = 32767;
constexpr int MAX_TCP_KEEPINTVL = 32767;
+constexpr int MAX_TCP_KEEPCNT = 127;
TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -575,6 +576,78 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) {
EXPECT_EQ(get, MAX_TCP_KEEPINTVL);
}
+TEST_P(TCPSocketPairTest, TCPKeepcountDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 9); // 9 keepalive probes.
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &kZero,
+ sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = MAX_TCP_KEEPCNT + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &MAX_TCP_KEEPCNT, sizeof(MAX_TCP_KEEPCNT)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, MAX_TCP_KEEPCNT);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToOne) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, keepaliveCount);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToNegative) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = -5;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST_P(TCPSocketPairTest, SetOOBInline) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
--
cgit v1.2.3
From d962f9f3842c5c352bc61411cf27e38ba2219317 Mon Sep 17 00:00:00 2001
From: gVisor bot
Date: Fri, 19 Jun 2020 11:41:37 -0700
Subject: Implement UDP cheksum verification.
Test:
- TestIncrementChecksumErrors
Fixes #2943
PiperOrigin-RevId: 317348158
---
pkg/sentry/socket/netstack/netstack.go | 1 +
pkg/sentry/socket/netstack/stack.go | 2 +-
pkg/tcpip/tcpip.go | 6 +
pkg/tcpip/transport/udp/endpoint.go | 18 ++
pkg/tcpip/transport/udp/udp_test.go | 320 +++++++++++++++++++++++++--------
5 files changed, 276 insertions(+), 71 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 738277391..c0b63a803 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -191,6 +191,7 @@ var Metrics = tcpip.Stats{
MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
+ ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
},
}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index ee11742a6..d2fb655ea 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -313,7 +313,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
0, // Udp/SndbufErrors.
- 0, // Udp/InCsumErrors.
+ udp.ChecksumErrors.Value(), // Udp/InCsumErrors.
0, // Udp/IgnoredMulti.
}
default:
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 3ad130b23..956232a44 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1224,6 +1224,9 @@ type UDPStats struct {
// PacketSendErrors is the number of datagrams failed to be sent.
PacketSendErrors *StatCounter
+
+ // ChecksumErrors is the number of datagrams dropped due to bad checksums.
+ ChecksumErrors *StatCounter
}
// Stats holds statistics about the networking stack.
@@ -1267,6 +1270,9 @@ type ReceiveErrors struct {
// ClosedReceiver is the number of received packets dropped because
// of receiving endpoint state being closed.
ClosedReceiver StatCounter
+
+ // ChecksumErrors is the number of packets dropped due to bad checksums.
+ ChecksumErrors StatCounter
}
// SendErrors collects packet send errors within the transport layer for
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f51988047..40d66ef09 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1350,6 +1350,24 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value means
+ // the transmitter omitted the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
+
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 313a3f117..ff9f60cf9 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -292,15 +292,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -391,17 +391,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -420,16 +424,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -439,17 +437,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -483,11 +476,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ return buf
}
func newPayload() []byte {
@@ -509,7 +498,7 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
@@ -643,7 +632,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -654,19 +643,19 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
@@ -677,7 +666,7 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
@@ -687,7 +676,7 @@ func TestBindReservedPort(t *testing.T) {
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -697,11 +686,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -714,7 +703,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -729,7 +718,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -744,7 +733,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -759,7 +748,7 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -796,7 +785,10 @@ func TestV4ReadSelfSource(t *testing.T) {
h := unicastV4.header4Tuple(incoming)
h.srcAddr = h.dstAddr
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
@@ -817,7 +809,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -955,7 +947,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -1005,7 +997,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -1022,7 +1014,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -1043,7 +1035,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1070,7 +1062,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1085,7 +1077,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1099,7 +1091,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1234,7 +1226,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1506,12 +1498,12 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOpt failed: %s", err)
}
// Verify multicast interface addr and NIC were set correctly.
@@ -1519,7 +1511,7 @@ func TestMulticastInterfaceOption(t *testing.T) {
ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
+ c.t.Fatalf("GetSockOpt failed: %s", err)
}
if ifoptGot != ifoptWant {
c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
@@ -1691,7 +1683,7 @@ func TestV6UnknownDestination(t *testing.T) {
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1699,20 +1691,25 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
+ // Invalidate the packet length field in the UDP header by adding one.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- var want uint64 = 1
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -1728,7 +1725,6 @@ func TestShortHeader(t *testing.T) {
c.t.Fatalf("Bind failed: %s", err)
}
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
// Allocate a buffer for an IPv6 and too-short UDP header.
@@ -1768,6 +1764,190 @@ func TestShortHeader(t *testing.T) {
}
}
+// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Invalidate the checksum field in the UDP header by adding one.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Invalidate the checksum field in the UDP header by adding one.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies if the checksum value is zero, global and
+// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies if the checksum value is zero, global and
+// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
// TestShutdownRead verifies endpoint read shutdown and error
// stats increment on packet receive.
func TestShutdownRead(t *testing.T) {
@@ -1778,15 +1958,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1809,11 +1989,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
--
cgit v1.2.3
From 8dbeac53ce1b3c1cf4a5f2f0ccdd7196f4656fd8 Mon Sep 17 00:00:00 2001
From: gVisor bot
Date: Fri, 26 Jun 2020 17:49:32 -0700
Subject: Implement SO_NO_CHECK socket option.
SO_NO_CHECK is used to skip the UDP checksum generation on a TX socket
(UDP checksum is optional on IPv4).
Test:
- TestNoChecksum
- SoNoCheckOffByDefault (UdpSocketTest)
- SoNoCheck (UdpSocketTest)
Fixes #3055
PiperOrigin-RevId: 318575215
---
pkg/sentry/socket/netstack/netstack.go | 19 +++++
pkg/sentry/socket/socket.go | 1 -
pkg/tcpip/checker/checker.go | 16 +++++
pkg/tcpip/tcpip.go | 100 +++++++++++++++------------
pkg/tcpip/transport/udp/endpoint.go | 24 +++++--
pkg/tcpip/transport/udp/udp_test.go | 24 +++++++
test/syscalls/linux/udp_socket_test_cases.cc | 38 ++++++++++
7 files changed, 173 insertions(+), 49 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index c0b63a803..e7d2c83d7 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1169,6 +1169,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return int32(v), nil
+ case linux.SO_NO_CHECK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
@@ -1720,6 +1731,14 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+ case linux.SO_NO_CHECK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
return syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 6580bd6e9..fcd7f9d7f 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -407,7 +407,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) {
linux.SO_MARK,
linux.SO_MAX_PACING_RATE,
linux.SO_NOFCS,
- linux.SO_NO_CHECK,
linux.SO_OOBINLINE,
linux.SO_PASSCRED,
linux.SO_PASSSEC,
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index c1745ba6a..ee264b726 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -320,6 +320,22 @@ func DstPort(port uint16) TransportChecker {
}
}
+// NoChecksum creates a checker that checks if the checksum is zero.
+func NoChecksum(noChecksum bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ udp, ok := h.(header.UDP)
+ if !ok {
+ return
+ }
+
+ if b := udp.Checksum() == 0; b != noChecksum {
+ t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
+ }
+ }
+}
+
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4d45dcc42..2be1c107a 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -585,59 +585,68 @@ type WriteOptions struct {
type SockOptBool int
const (
- // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
- // datagram sockets are allowed to send packets to a broadcast address.
+ // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether datagram sockets are allowed to send packets to a broadcast
+ // address.
BroadcastOption SockOptBool = iota
- // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
- // held until segments are full by the TCP transport protocol.
+ // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be held until segments are full by the TCP transport
+ // protocol.
CorkOption
- // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
- // should be sent out immediately by the transport protocol. For TCP,
- // it determines if the Nagle algorithm is on or off.
+ // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be sent out immediately by the transport protocol. For
+ // TCP, it determines if the Nagle algorithm is on or off.
DelayOption
- // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
- // TCP keepalive is enabled for this socket.
+ // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether TCP keepalive is enabled for this socket.
KeepaliveEnabledOption
- // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
- // multicast packets sent over a non-loopback interface will be looped back.
+ // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether multicast packets sent over a non-loopback interface
+ // will be looped back.
MulticastLoopOption
- // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
- // SCM_CREDENTIALS socket control messages are enabled.
+ // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether UDP checksum is disabled for this socket.
+ NoChecksumOption
+
+ // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether SCM_CREDENTIALS socket control messages are enabled.
//
// Only supported on Unix sockets.
PasscredOption
- // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
+ // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
QuickAckOption
- // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
- // IPV6_TCLASS ancillary message is passed with incoming packets.
+ // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if the IPV6_TCLASS ancillary message is passed with incoming
+ // packets.
ReceiveTClassOption
- // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
- // ancillary message is passed with incoming packets.
+ // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
+ // if the TOS ancillary message is passed with incoming packets.
ReceiveTOSOption
- // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
- // if more inforamtion is provided with incoming packets such
- // as interface index and address.
+ // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if more inforamtion is provided with incoming packets such as
+ // interface index and address.
ReceiveIPPacketInfoOption
- // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
- // should allow reuse of local address.
+ // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether Bind() should allow reuse of local address.
ReuseAddressOption
- // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
- // to be bound to an identical socket address.
+ // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
+ // multiple sockets to be bound to an identical socket address.
ReusePortOption
- // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
- // socket is to be restricted to sending and receiving IPv6 packets only.
+ // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether an IPv6 socket is to be restricted to sending and receiving
+ // IPv6 packets only.
V6OnlyOption
)
@@ -645,25 +654,27 @@ const (
type SockOptInt int
const (
- // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
- // of un-ACKed TCP keepalives that will be sent before the connection is
- // closed.
+ // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the number of un-ACKed TCP keepalives that will be sent
+ // before the connection is closed.
KeepaliveCountOption SockOptInt = iota
- // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
+ // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS
// for all subsequent outgoing IPv4 packets from the endpoint.
IPv4TOSOption
- // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
- // for all subsequent outgoing IPv6 packets from the endpoint.
+ // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to
+ // specify TOS for all subsequent outgoing IPv6 packets from the
+ // endpoint.
IPv6TrafficClassOption
- // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
- // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
+ // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the
+ // current Maximum Segment Size(MSS) value as specified using the
+ // TCP_MAXSEG option.
MaxSegOption
- // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
- // TTL value for multicast messages. The default is 1.
+ // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control
+ // the default TTL value for multicast messages. The default is 1.
MulticastTTLOption
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
@@ -682,21 +693,22 @@ const (
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
- // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
- // limit value for unicast messages. The default is protocol specific.
+ // TTLOption is used by SetSockOptInt/GetSockOptInt to control the
+ // default TTL/hop limit value for unicast messages. The default is
+ // protocol specific.
//
// A zero value indicates the default.
TTLOption
- // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of
- // SYN retransmits that TCP should send before aborting the attempt to
- // connect. It cannot exceed 255.
+ // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify
+ // the number of SYN retransmits that TCP should send before aborting
+ // the attempt to connect. It cannot exceed 255.
//
// NOTE: This option is currently only stubbed out and is no-op.
TCPSynCountOption
- // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size
- // of the advertised window to this value.
+ // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound
+ // the size of the advertised window to this value.
//
// NOTE: This option is currently only stubed out and is a no-op
TCPWindowClampOption
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 8bdc1ee1f..cae29fbff 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -109,6 +109,7 @@ type endpoint struct {
portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
+ noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -529,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -553,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
@@ -825,6 +831,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -959,7 +971,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -973,8 +985,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
Length: length,
})
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index ff9f60cf9..db59eb5a0 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1251,6 +1251,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index cc1db3de8..1d13432ca 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -1239,6 +1239,44 @@ TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) {
EXPECT_EQ(n, 0);
}
+TEST_P(UdpSocketTest, SoNoCheckOffByDefault) {
+ // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = -1;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoNoCheck) {
+ // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = kSockOptOn;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, optlen),
+ SyscallSucceeds());
+ v = -1;
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOn);
+ ASSERT_EQ(optlen, sizeof(v));
+
+ v = kSockOptOff;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, optlen),
+ SyscallSucceeds());
+ v = -1;
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
TEST_P(UdpSocketTest, SoTimestampOffByDefault) {
// TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by
// hostinet.
--
cgit v1.2.3
From 5946f111827fa4e342a2e6e9c043c198d2e5cb03 Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Thu, 9 Jul 2020 16:24:43 -0700
Subject: Add support for IP_HDRINCL IP option for raw sockets.
Updates #2746
Fixes #3158
PiperOrigin-RevId: 320497190
---
pkg/sentry/socket/netstack/netstack.go | 11 ++++-
pkg/tcpip/tcpip.go | 5 +++
pkg/tcpip/transport/raw/endpoint.go | 34 +++++++++++----
test/syscalls/linux/raw_socket_hdrincl.cc | 70 ++++++++++++++++++++++++++++++-
4 files changed, 110 insertions(+), 10 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e7d2c83d7..3b248a953 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2112,13 +2112,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ case linux.IP_HDRINCL:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
linux.IP_CHECKSUM,
linux.IP_DROP_SOURCE_MEMBERSHIP,
linux.IP_FREEBIND,
- linux.IP_HDRINCL,
linux.IP_IPSEC_POLICY,
linux.IP_MINTTL,
linux.IP_MSFILTER,
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2f9872dc6..25534a10d 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -648,6 +648,11 @@ const (
// whether an IPv6 socket is to be restricted to sending and receiving
// IPv6 packets only.
V6OnlyOption
+
+ // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
+ // endpoint that all packets being written have an IP header and the
+ // endpoint should not attach an IP header.
+ IPHdrIncludedOption
)
// SockOptInt represents socket options which values have the int type.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 766c7648e..5b6e7d102 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -63,6 +63,7 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -108,6 +109,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
}
// Override with stack defaults.
@@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -263,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -353,7 +351,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if !e.associated {
+ if e.hdrIncluded {
if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
Data: buffer.View(payloadBytes).ToVectorisedView(),
}); err != nil {
@@ -513,6 +511,13 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
@@ -577,6 +582,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.KeepaliveEnabledOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -616,8 +627,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
index 16cfc1d75..5bb14d57c 100644
--- a/test/syscalls/linux/raw_socket_hdrincl.cc
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -167,7 +167,7 @@ TEST_F(RawHDRINCL, NotReadable) {
// nothing to be read.
char buf[117];
ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EINVAL));
+ SyscallFailsWithErrno(EAGAIN));
}
// Test that we can connect() to a valid IP (loopback).
@@ -332,6 +332,74 @@ TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK);
}
+// Send and receive a packet w/ the IP_HDRINCL option set.
+TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) {
+ int port = 40000;
+ if (!IsRunningOnGvisor()) {
+ port = static_cast(ASSERT_NO_ERRNO_AND_VALUE(
+ PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
+ }
+
+ FileDescriptor recv_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ FileDescriptor send_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ // Enable IP_HDRINCL option so that we can build and send w/ an IP
+ // header.
+ constexpr int kSockOptOn = 1;
+ ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ // This is not strictly required but we do it to make sure that setting
+ // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving
+ // packets.
+ ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Construct a packet with an IP header, UDP header, and payload.
+ constexpr char kPayload[] = "toto";
+ char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
+ ASSERT_TRUE(
+ FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
+
+ socklen_t addrlen = sizeof(addr_);
+ ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0,
+ reinterpret_cast(&addr_),
+ addrlen));
+
+ // Receive the payload.
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0,
+ reinterpret_cast(&src), &src_size),
+ SyscallSucceedsWithValue(sizeof(packet)));
+ EXPECT_EQ(
+ memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
+ sizeof(kPayload)),
+ 0);
+ // The network stack should have set the source address.
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
+ struct iphdr iphdr = {};
+ memcpy(&iphdr, recv_buf, sizeof(iphdr));
+ EXPECT_NE(iphdr.id, 0);
+
+ // Also verify that the packet we just sent was not delivered to the
+ // IPPROTO_RAW socket.
+ {
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+ reinterpret_cast(&src), &src_size),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
} // namespace
} // namespace testing
--
cgit v1.2.3
From 5df3a8fedef7e54550d4c6b7172e25216600ee9f Mon Sep 17 00:00:00 2001
From: gVisor bot
Date: Thu, 9 Jul 2020 22:33:53 -0700
Subject: Discard multicast UDP source address.
RFC-1122 (and others) specify that UDP should not receive
datagrams that have a source address that is a multicast address.
Packets should never be received FROM a multicast address.
See also, RFC 768: 'User Datagram Protocol'
J. Postel, ISI, 28 August 1980
A UDP datagram received with an invalid IP source address
(e.g., a broadcast or multicast address) must be discarded
by UDP or by the IP layer (see rfc 1122 Section 3.2.1.3).
This CL does not address TCP or broadcast which is more complicated.
Also adds a test for both ipv6 and ipv4 UDP.
Fixes #3154
PiperOrigin-RevId: 320547674
---
pkg/sentry/socket/netstack/netstack.go | 1 +
pkg/tcpip/tcpip.go | 5 +-
pkg/tcpip/transport/udp/endpoint.go | 11 +++-
pkg/tcpip/transport/udp/udp_test.go | 104 ++++++++++++++++++++++++++++-----
4 files changed, 103 insertions(+), 18 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 3b248a953..5a3cedd7c 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -192,6 +192,7 @@ var Metrics = tcpip.Stats{
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
+ InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."),
},
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 25534a10d..cf7291d09 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -782,7 +782,7 @@ type CongestionControlOption string
// control algorithms.
type AvailableCongestionControlOption string
-// buffer moderation.
+// ModerateReceiveBufferOption is used by buffer moderation.
type ModerateReceiveBufferOption bool
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
@@ -1244,6 +1244,9 @@ type UDPStats struct {
// ChecksumErrors is the number of datagrams dropped due to bad checksums.
ChecksumErrors *StatCounter
+
+ // InvalidSourceAddress is the number of invalid sourced datagrams dropped.
+ InvalidSourceAddress *StatCounter
}
// Stats holds statistics about the networking stack.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 0584ec8dc..4e9e114a9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1377,6 +1377,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
+ // Never receive from a multicast address.
+ if header.IsV4MulticastAddress(id.RemoteAddress) ||
+ header.IsV6MulticastAddress(id.RemoteAddress) {
+ e.stack.Stats().UDP.InvalidSourceAddress.Increment()
+ e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
// Verify checksum unless RX checksum offload is enabled.
// On IPv4, UDP checksum is optional, and a zero value means
// the transmitter omitted the checksum generation (RFC768).
@@ -1395,10 +1404,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
}
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 91ba031fa..90781cf49 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticaststats checks that a discarded packet
+// that that was sent with multicast SOURCE address increments
+// the correct counters and that a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ t.Helper()
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
--
cgit v1.2.3
From 216dcebc066c82907b0de790a77a3deb6a734805 Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Sat, 11 Jul 2020 06:21:34 -0700
Subject: Stub out SO_DETACH_FILTER.
Updates #2746
PiperOrigin-RevId: 320757963
---
pkg/sentry/socket/netstack/netstack.go | 5 +++
pkg/tcpip/tcpip.go | 5 ++-
pkg/tcpip/transport/icmp/endpoint.go | 4 ++
pkg/tcpip/transport/packet/endpoint.go | 8 +++-
pkg/tcpip/transport/raw/endpoint.go | 8 +++-
pkg/tcpip/transport/tcp/endpoint.go | 3 ++
pkg/tcpip/transport/udp/endpoint.go | 3 ++
test/syscalls/linux/BUILD | 3 ++
test/syscalls/linux/packet_socket_raw.cc | 34 ++++++++++++++++
test/syscalls/linux/raw_socket.cc | 37 +++++++++++++++--
test/syscalls/linux/tcp_socket.cc | 60 ++++++++++++++++++++++++++++
test/syscalls/linux/udp_socket_test_cases.cc | 54 +++++++++++++++++++++++++
12 files changed, 217 insertions(+), 7 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 5a3cedd7c..78a842973 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1754,6 +1754,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return nil
+ case linux.SO_DETACH_FILTER:
+ // optval is ignored.
+ var v tcpip.SocketDetachFilterOption
+ return syserr.TranslateNetstackError(ep.SetSockOpt(v))
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index cf7291d09..71bcee785 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -855,7 +855,10 @@ type OutOfBandInlineOption int
// a default TTL.
type DefaultTTLOption uint8
-//
+// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
+// classic BPF filter on a given endpoint.
+type SocketDetachFilterOption int
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 62d1acad4..678f4e016 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+ }
return nil
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index a8f8454dd..57b7f5c19 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -278,7 +278,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 5b6e7d102..c2e9fd29f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -506,7 +506,13 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index caac6ef57..83dc10ed0 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1792,6 +1792,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.deferAccept = time.Duration(v)
e.UnlockUser()
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
default:
return nil
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4e9e114a9..a14643ae8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -816,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case tcpip.SocketDetachFilterOption:
+ return nil
}
return nil
}
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 9e097c888..662d780d8 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1330,6 +1330,7 @@ cc_binary(
name = "packet_socket_raw_test",
testonly = 1,
srcs = ["packet_socket_raw.cc"],
+ defines = select_system(),
linkstatic = 1,
deps = [
":socket_test_util",
@@ -1809,6 +1810,7 @@ cc_binary(
name = "raw_socket_test",
testonly = 1,
srcs = ["raw_socket.cc"],
+ defines = select_system(),
linkstatic = 1,
deps = [
":socket_test_util",
@@ -3407,6 +3409,7 @@ cc_binary(
name = "tcp_socket_test",
testonly = 1,
srcs = ["tcp_socket.cc"],
+ defines = select_system(),
linkstatic = 1,
deps = [
":socket_test_util",
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 4093ac813..6a963b12c 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -14,6 +14,9 @@
#include
#include
+#ifndef __fuchsia__
+#include
+#endif // __fuchsia__
#include
#include
#include
@@ -556,6 +559,37 @@ TEST_P(RawPacketTest, SetSocketSendBuf) {
ASSERT_EQ(quarter_sz, val);
}
+#ifndef __fuchsia__
+
+TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ //
+ // gVisor returns no error on SO_DETACH_FILTER even if there is no filter
+ // attached unlike linux which does return ENOENT in such cases. This is
+ // because gVisor doesn't support SO_ATTACH_FILTER and just silently returns
+ // success.
+ if (IsRunningOnGvisor()) {
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+ return;
+ }
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(RawPacketTest, GetSocketDetachFilter) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 05c4ed03f..ce54dc064 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include
+#ifndef __fuchsia__
+#include
+#endif // __fuchsia__
#include
#include
#include
@@ -21,6 +24,7 @@
#include
#include
#include
+
#include
#include "gtest/gtest.h"
@@ -790,10 +794,30 @@ void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf,
ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len));
}
-INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest,
- ::testing::Combine(
- ::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
- ::testing::Values(AF_INET, AF_INET6)));
+#ifndef __fuchsia__
+
+TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ if (IsRunningOnGvisor()) {
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+ return;
+ }
+
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(RawSocketTest, GetSocketDetachFilter) {
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to.
TEST(RawSocketTest, IPv6ProtoRaw) {
@@ -813,6 +837,11 @@ TEST(RawSocketTest, IPv6ProtoRaw) {
SyscallFailsWithErrno(EINVAL));
}
+INSTANTIATE_TEST_SUITE_P(
+ AllInetTests, RawSocketTest,
+ ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
+ ::testing::Values(AF_INET, AF_INET6)));
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index a4d2953e1..0cea7d11f 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include
+#ifndef __fuchsia__
+#include
+#endif // __fuchsia__
#include
#include
#include
@@ -1559,6 +1562,63 @@ TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) {
}
}
+#ifndef __fuchsia__
+
+// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+// gVisor currently silently ignores attaching a filter.
+TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // Program generated using sudo tcpdump -i lo tcp and port 1234 -dd
+ struct sock_filter code[] = {
+ {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd},
+ {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000006},
+ {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2},
+ {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2},
+ {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017},
+ {0x15, 0, 8, 0x00000006}, {0x28, 0, 0, 0x00000014},
+ {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e},
+ {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2},
+ {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2},
+ {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000},
+ };
+ struct sock_fprog bpf = {
+ .len = ABSL_ARRAYSIZE(code),
+ .filter = code,
+ };
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)),
+ SyscallSucceeds());
+
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+}
+
+TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ SKIP_IF(IsRunningOnGvisor());
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index 9cc6be4fb..60c48ed6e 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -16,6 +16,9 @@
#include
#include
+#ifndef __fuchsia__
+#include
+#endif // __fuchsia__
#include
#include
#include
@@ -1723,5 +1726,56 @@ TEST_P(UdpSocketTest, RecvBufLimits) {
}
}
+#ifndef __fuchsia__
+
+// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+// gVisor currently silently ignores attaching a filter.
+TEST_P(UdpSocketTest, SetSocketDetachFilter) {
+ // Program generated using sudo tcpdump -i lo udp and port 1234 -dd
+ struct sock_filter code[] = {
+ {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd},
+ {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011},
+ {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2},
+ {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2},
+ {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017},
+ {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014},
+ {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e},
+ {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2},
+ {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2},
+ {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000},
+ };
+ struct sock_fprog bpf = {
+ .len = ABSL_ARRAYSIZE(code),
+ .filter = code,
+ };
+ ASSERT_THAT(
+ setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)),
+ SyscallSucceeds());
+
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ SKIP_IF(IsRunningOnGvisor());
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(UdpSocketTest, GetSocketDetachFilter) {
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(
+ getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
} // namespace testing
} // namespace gvisor
--
cgit v1.2.3
From fef90c61c6186c113cfdb0bbcf53f4ca70f9741a Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Wed, 15 Jul 2020 14:13:42 -0700
Subject: Fix minor bugs in a couple of interface IOCTLs.
gVisor incorrectly returns the wrong ARP type for SIOGIFHWADDR. This breaks
tcpdump as it tries to interpret the packets incorrectly.
Similarly, SIOCETHTOOL is used by tcpdump to query interface properties which
fails with an EINVAL since we don't implement it. For now change it to return
EOPNOTSUPP to indicate that we don't support the query rather than return
EINVAL.
NOTE: ARPHRD types for link endpoints are distinct from NIC capabilities
and NIC flags. In Linux all 3 exist eg. ARPHRD types are stored in dev->type
field while NIC capabilities are more like the device features which can be
queried using SIOCETHTOOL but not modified and NIC Flags are fields that can
be modified from user space. eg. NIC status (UP/DOWN/MULTICAST/BROADCAST) etc.
Updates #2746
PiperOrigin-RevId: 321436525
---
pkg/abi/linux/ioctl.go | 27 +++++++++--
pkg/abi/linux/netlink_route.go | 2 +
pkg/sentry/socket/netstack/netstack.go | 74 ++++++++++++++++++------------
pkg/sentry/socket/netstack/stack.go | 22 ++++++---
pkg/tcpip/header/arp.go | 77 +++++++++++++++++++++-----------
pkg/tcpip/link/channel/BUILD | 1 +
pkg/tcpip/link/channel/channel.go | 6 +++
pkg/tcpip/link/fdbased/endpoint.go | 8 ++++
pkg/tcpip/link/loopback/loopback.go | 5 +++
pkg/tcpip/link/muxed/BUILD | 1 +
pkg/tcpip/link/muxed/injectable.go | 6 +++
pkg/tcpip/link/nested/BUILD | 1 +
pkg/tcpip/link/nested/nested.go | 6 +++
pkg/tcpip/link/qdisc/fifo/BUILD | 1 +
pkg/tcpip/link/qdisc/fifo/endpoint.go | 6 +++
pkg/tcpip/link/sharedmem/sharedmem.go | 5 +++
pkg/tcpip/link/tun/device.go | 10 +++++
pkg/tcpip/link/waitable/BUILD | 2 +
pkg/tcpip/link/waitable/waitable.go | 6 +++
pkg/tcpip/link/waitable/waitable_test.go | 6 +++
pkg/tcpip/network/ip_test.go | 9 +++-
pkg/tcpip/stack/forwarder_test.go | 6 +++
pkg/tcpip/stack/nic_test.go | 5 +++
pkg/tcpip/stack/registration.go | 7 +++
pkg/tcpip/stack/stack.go | 6 +++
test/syscalls/linux/socket_netdevice.cc | 23 ++++++++++
26 files changed, 263 insertions(+), 65 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 2062e6a4b..2c5e56ae5 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -67,10 +67,29 @@ const (
// ioctl(2) requests provided by uapi/linux/sockios.h
const (
- SIOCGIFMEM = 0x891f
- SIOCGIFPFLAGS = 0x8935
- SIOCGMIIPHY = 0x8947
- SIOCGMIIREG = 0x8948
+ SIOCGIFNAME = 0x8910
+ SIOCGIFCONF = 0x8912
+ SIOCGIFFLAGS = 0x8913
+ SIOCGIFADDR = 0x8915
+ SIOCGIFDSTADDR = 0x8917
+ SIOCGIFBRDADDR = 0x8919
+ SIOCGIFNETMASK = 0x891b
+ SIOCGIFMETRIC = 0x891d
+ SIOCGIFMTU = 0x8921
+ SIOCGIFMEM = 0x891f
+ SIOCGIFHWADDR = 0x8927
+ SIOCGIFINDEX = 0x8933
+ SIOCGIFPFLAGS = 0x8935
+ SIOCGIFTXQLEN = 0x8942
+ SIOCETHTOOL = 0x8946
+ SIOCGMIIPHY = 0x8947
+ SIOCGMIIREG = 0x8948
+ SIOCGIFMAP = 0x8970
+)
+
+// ioctl(2) requests provided by uapi/asm-generic/sockios.h
+const (
+ SIOCGSTAMP = 0x8906
)
// ioctl(2) directions. Used to calculate requests number.
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index 40bec566c..ceda0a8d3 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -187,6 +187,8 @@ const (
// Device types, from uapi/linux/if_arp.h.
const (
+ ARPHRD_NONE = 65534
+ ARPHRD_ETHER = 1
ARPHRD_LOOPBACK = 772
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 78a842973..0b1be1bd2 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2747,7 +2747,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
switch args[1].Int() {
- case syscall.SIOCGSTAMP:
+ case linux.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2788,18 +2788,19 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// Ioctl performs a socket ioctl.
func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch arg := int(args[1].Int()); arg {
- case syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFFLAGS,
+ linux.SIOCGIFADDR,
+ linux.SIOCGIFBRDADDR,
+ linux.SIOCGIFDSTADDR,
+ linux.SIOCGIFHWADDR,
+ linux.SIOCGIFINDEX,
+ linux.SIOCGIFMAP,
+ linux.SIOCGIFMETRIC,
+ linux.SIOCGIFMTU,
+ linux.SIOCGIFNAME,
+ linux.SIOCGIFNETMASK,
+ linux.SIOCGIFTXQLEN,
+ linux.SIOCETHTOOL:
var ifr linux.IFReq
if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
@@ -2815,7 +2816,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
})
return 0, err
- case syscall.SIOCGIFCONF:
+ case linux.SIOCGIFCONF:
// Return a list of interface addresses or the buffer size
// necessary to hold the list.
var ifc linux.IFConf
@@ -2889,7 +2890,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
// identify a device.
- if arg == syscall.SIOCGIFNAME {
+ if arg == linux.SIOCGIFNAME {
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
@@ -2912,21 +2913,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
switch arg {
- case syscall.SIOCGIFINDEX:
+ case linux.SIOCGIFINDEX:
// Copy out the index to the data.
usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
- case syscall.SIOCGIFHWADDR:
+ case linux.SIOCGIFHWADDR:
// Copy the hardware address out.
- ifr.Data[0] = 6 // IEEE802.2 arp type.
- ifr.Data[1] = 0
+ //
+ // Refer: https://linux.die.net/man/7/netdevice
+ // SIOCGIFHWADDR, SIOCSIFHWADDR
+ //
+ // Get or set the hardware address of a device using
+ // ifr_hwaddr. The hardware address is specified in a struct
+ // sockaddr. sa_family contains the ARPHRD_* device type,
+ // sa_data the L2 hardware address starting from byte 0. Setting
+ // the hardware address is a privileged operation.
+ usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType)
n := copy(ifr.Data[2:], iface.Addr)
for i := 2 + n; i < len(ifr.Data); i++ {
ifr.Data[i] = 0 // Clear padding.
}
- usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
- case syscall.SIOCGIFFLAGS:
+ case linux.SIOCGIFFLAGS:
f, err := interfaceStatusFlags(stack, iface.Name)
if err != nil {
return err
@@ -2935,7 +2943,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// matches Linux behavior.
usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
- case syscall.SIOCGIFADDR:
+ case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2946,32 +2954,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
- case syscall.SIOCGIFMETRIC:
+ case linux.SIOCGIFMETRIC:
// Gets the metric of the device. As per netdevice(7), this
// always just sets ifr_metric to 0.
usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
- case syscall.SIOCGIFMTU:
+ case linux.SIOCGIFMTU:
// Gets the MTU of the device.
usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
- case syscall.SIOCGIFMAP:
+ case linux.SIOCGIFMAP:
// Gets the hardware parameters of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFTXQLEN:
// Gets the transmit queue length of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFDSTADDR:
+ case linux.SIOCGIFDSTADDR:
// Gets the destination address of a point-to-point device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFBRDADDR:
+ case linux.SIOCGIFBRDADDR:
// Gets the broadcast address of a device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFNETMASK:
+ case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2988,6 +2996,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
+ case linux.SIOCETHTOOL:
+ // Stubbed out for now, Ideally we should implement the required
+ // sub-commands for ETHTOOL
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c
+ return syserr.ErrEndpointOperation
+
default:
// Not a valid call.
return syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 548442b96..67737ae87 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -15,6 +15,8 @@
package netstack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool {
return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
}
+// Converts Netstack's ARPHardwareType to equivalent linux constants.
+func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 {
+ switch t {
+ case header.ARPHardwareNone:
+ return linux.ARPHRD_NONE
+ case header.ARPHardwareLoopback:
+ return linux.ARPHRD_LOOPBACK
+ case header.ARPHardwareEther:
+ return linux.ARPHRD_ETHER
+ default:
+ panic(fmt.Sprintf("unknown ARPHRD type: %d", t))
+ }
+}
+
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
is := make(map[int32]inet.Interface)
for id, ni := range s.Stack.NICInfo() {
- var devType uint16
- if ni.Flags.Loopback {
- devType = linux.ARPHRD_LOOPBACK
- }
is[int32(id)] = inet.Interface{
Name: ni.Name,
Addr: []byte(ni.LinkAddress),
Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
- DeviceType: devType,
+ DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType),
MTU: ni.MTU,
}
}
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
index 718a4720a..83189676e 100644
--- a/pkg/tcpip/header/arp.go
+++ b/pkg/tcpip/header/arp.go
@@ -14,14 +14,33 @@
package header
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// ARPProtocolNumber is the ARP network protocol number.
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize is the size of an IPv4-over-Ethernet ARP packet.
- ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+ ARPSize = 28
+)
+
+// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header.
+type ARPHardwareType uint16
+
+// Typical ARP HardwareType values. Some of the constants have to be specific
+// values as they are egressed on the wire in the HTYPE field of an ARP header.
+const (
+ ARPHardwareNone ARPHardwareType = 0
+ // ARPHardwareEther specifically is the HTYPE for Ethernet as specified
+ // in the IANA list here:
+ //
+ // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2
+ ARPHardwareEther ARPHardwareType = 1
+ ARPHardwareLoopback ARPHardwareType = 2
)
// ARPOp is an ARP opcode.
@@ -36,54 +55,64 @@ const (
// ARP is an ARP packet stored in a byte array as described in RFC 826.
type ARP []byte
-func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
-func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
-func (a ARP) hardwareAddressSize() int { return int(a[4]) }
-func (a ARP) protocolAddressSize() int { return int(a[5]) }
+const (
+ hTypeOffset = 0
+ protocolOffset = 2
+ haAddressSizeOffset = 4
+ protoAddressSizeOffset = 5
+ opCodeOffset = 6
+ senderHAAddressOffset = 8
+ senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize
+ targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize
+ targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize
+)
+
+func (a ARP) hardwareAddressType() ARPHardwareType {
+ return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:]))
+}
+
+func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) }
+func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) }
+func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) }
// Op is the ARP opcode.
-func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) }
// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
- a[6] = uint8(op >> 8)
- a[7] = uint8(op)
+ binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op))
}
// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
func (a ARP) SetIPv4OverEthernet() {
- a[0], a[1] = 0, 1 // htypeEthernet
- a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
- a[4] = 6 // macSize
- a[5] = uint8(IPv4AddressSize)
+ binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther))
+ binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber))
+ a[haAddressSizeOffset] = EthernetAddressSize
+ a[protoAddressSizeOffset] = uint8(IPv4AddressSize)
}
// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
- const s = 8
- return a[s : s+6]
+ return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
- const s = 8 + 6
- return a[s : s+4]
+ return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize]
}
// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
- const s = 8 + 6 + 4
- return a[s : s+6]
+ return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
- const s = 8 + 6 + 4 + 6
- return a[s : s+4]
+ return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize]
}
// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
@@ -91,10 +120,8 @@ func (a ARP) IsValid() bool {
if len(a) < ARPSize {
return false
}
- const htypeEthernet = 1
- const macSize = 6
- return a.hardwareAddressSpace() == htypeEthernet &&
+ return a.hardwareAddressType() == ARPHardwareEther &&
a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
- a.hardwareAddressSize() == macSize &&
+ a.hardwareAddressSize() == EthernetAddressSize &&
a.protocolAddressSize() == IPv4AddressSize
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index b8b93e78e..39ca774ef 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 20b183da0..a2bb773d4 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -296,3 +297,8 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
e.q.RemoveNotify(handle)
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f34082e1a..32abe2a13 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -626,6 +626,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
// to the FD, but does not read from it. All reads come from injected packets.
type InjectableEndpoint struct {
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 568c6874f..3b17d8c28 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -113,3 +113,8 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
return nil
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareLoopback
+}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 82b441b79..e7493e5c5 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c69d6b7e9..c305d9e86 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -18,6 +18,7 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -129,6 +130,11 @@ func (m *InjectableEndpoint) Wait() {
}
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unsupported operation")
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index bdd5276ad..2cdb23475 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 2998f9c4f..328bd048e 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -129,3 +130,8 @@ func (e *Endpoint) GSOMaxSize() uint32 {
}
return 0
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.child.ARPHardwareType()
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 054c213bc..1d0079bd6 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b5dfb7850..c84fe1bb9 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -207,3 +208,8 @@ func (e *endpoint) Wait() {
e.wg.Wait()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 0374a2441..a36862c67 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -287,3 +287,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
e.completed.Done()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 6bc9033d0..47446efec 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
stack: s,
nicID: id,
name: name,
+ isTap: prefix == "tap",
}
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
@@ -348,6 +349,7 @@ type tunEndpoint struct {
stack *stack.Stack
nicID tcpip.NICID
name string
+ isTap bool
}
// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
@@ -356,3 +358,11 @@ func (e *tunEndpoint) DecRef() {
e.stack.RemoveNIC(e.nicID)
})
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.isTap {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0956d2c65..ee84c3d96 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/gate",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -25,6 +26,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 949b3f2b2..24a8dc2eb 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -147,3 +148,8 @@ func (e *Endpoint) WaitDispatch() {
// Wait implements stack.LinkEndpoint.Wait.
func (e *Endpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 63bf40562..ffb2354be 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -81,6 +82,11 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return nil
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unimplemented")
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 7c8fb3e0a..a5b780ca2 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -172,14 +172,19 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testObject) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index a6546cef0..eefb4b07f 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -301,6 +302,11 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 31f865260..3bc9fd831 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -84,6 +84,11 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5cbc946b6..f260eeb7f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -436,6 +437,12 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0aa815447..2b7ece851 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1095,6 +1095,11 @@ type NICInfo struct {
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
+
+ // ARPHardwareType holds the ARP Hardware type of the NIC. This is the
+ // value sent in haType field of an ARP Request sent by this NIC and the
+ // value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -1126,6 +1131,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc
index 15d4b85a7..5f8d7f981 100644
--- a/test/syscalls/linux/socket_netdevice.cc
+++ b/test/syscalls/linux/socket_netdevice.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include
#include
#include
#include
@@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) {
// Check that the loopback is zero hardware address.
ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds());
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK);
EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0);
EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0);
EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0);
@@ -178,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) {
EXPECT_GT(ifr.ifr_mtu, 0);
}
+TEST(NetdeviceTest, EthtoolGetTSInfo) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ethtool_ts_info tsi = {};
+ tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities.
+
+ // Prepare the request.
+ struct ifreq ifr = {};
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ ifr.ifr_data = (void*)&tsi;
+
+ // Check that SIOCGIFMTU returns a nonzero MTU.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ return;
+ }
+ ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
--
cgit v1.2.3
From dcf6ddc2772b8fcf824f1f48e0281e1cc80b93ea Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Thu, 16 Jul 2020 18:38:28 -0700
Subject: Add support to return protocol in recvmsg for AF_PACKET.
Updates #173
PiperOrigin-RevId: 321690756
---
pkg/sentry/socket/netstack/netstack.go | 24 +++++++++++++++++++++---
pkg/tcpip/tcpip.go | 19 +++++++++++++++++++
pkg/tcpip/transport/packet/endpoint.go | 18 ++++++++++++++++--
test/syscalls/linux/packet_socket.cc | 1 +
test/syscalls/linux/packet_socket_raw.cc | 1 +
5 files changed, 58 insertions(+), 5 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0b1be1bd2..49a04e613 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -297,8 +297,9 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
- sender tcpip.FullAddress
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+ linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@@ -447,8 +448,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = nil
s.sender = tcpip.FullAddress{}
+ s.linkPacketInfo = tcpip.LinkPacketInfo{}
- v, cms, err := s.Endpoint.Read(&s.sender)
+ var v buffer.View
+ var cms tcpip.ControlMessages
+ var err *tcpip.Error
+
+ switch e := s.Endpoint.(type) {
+ // The ordering of these interfaces matters. The most specific
+ // interfaces must be specified before the more generic Endpoint
+ // interface.
+ case tcpip.PacketEndpoint:
+ v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
+ case tcpip.Endpoint:
+ v, cms, err = e.Read(&s.sender)
+ }
if err != nil {
atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
@@ -2509,6 +2523,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
+ switch v := addr.(type) {
+ case *linux.SockAddrLink:
+ v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ }
}
if peek {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 71bcee785..48ad56d4d 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -549,6 +549,25 @@ type Endpoint interface {
SetOwner(owner PacketOwner)
}
+// LinkPacketInfo holds Link layer information for a received packet.
+//
+// +stateify savable
+type LinkPacketInfo struct {
+ // Protocol is the NetworkProtocolNumber for the packet.
+ Protocol NetworkProtocolNumber
+}
+
+// PacketEndpoint are additional methods that are only implemented by Packet
+// endpoints.
+type PacketEndpoint interface {
+ // ReadPacket reads a datagram/packet from the endpoint and optionally
+ // returns the sender and additional LinkPacketInfo.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
+}
+
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 92b487381..7b2083a09 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -45,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -151,8 +154,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -177,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// TODO(b/129292371): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -428,12 +440,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
}
if ep.cooked {
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index ce63adb23..e94ddcb77 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -193,6 +193,7 @@ void ReceiveMessage(int sock, int ifindex) {
EXPECT_EQ(src.sll_family, AF_PACKET);
EXPECT_EQ(src.sll_ifindex, ifindex);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
EXPECT_EQ(src.sll_addr[i], 0);
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index e44062475..2fca9fe4d 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -200,6 +200,7 @@ TEST_P(RawPacketTest, Receive) {
EXPECT_EQ(src.sll_family, AF_PACKET);
EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
EXPECT_EQ(src.sll_addr[i], 0);
--
cgit v1.2.3
From 71bf90c55bd888f9b9c493533ca5e4b2b4b3d21d Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Wed, 22 Jul 2020 15:12:56 -0700
Subject: Support for receiving outbound packets in AF_PACKET.
Updates #173
PiperOrigin-RevId: 322665518
---
pkg/abi/linux/socket.go | 9 +++
pkg/sentry/socket/netstack/netstack.go | 19 +++++
pkg/tcpip/link/channel/channel.go | 4 +
pkg/tcpip/link/fdbased/endpoint.go | 36 ++++-----
pkg/tcpip/link/fdbased/endpoint_test.go | 8 ++
pkg/tcpip/link/loopback/loopback.go | 3 +
pkg/tcpip/link/muxed/injectable.go | 4 +
pkg/tcpip/link/nested/nested.go | 15 ++++
pkg/tcpip/link/nested/nested_test.go | 4 +
pkg/tcpip/link/packetsocket/BUILD | 14 ++++
pkg/tcpip/link/packetsocket/endpoint.go | 50 +++++++++++++
pkg/tcpip/link/qdisc/fifo/endpoint.go | 12 +++
pkg/tcpip/link/sharedmem/sharedmem.go | 21 ++++--
pkg/tcpip/link/sharedmem/sharedmem_test.go | 4 +
pkg/tcpip/link/sniffer/sniffer.go | 5 ++
pkg/tcpip/link/tun/device.go | 43 +++++++----
pkg/tcpip/link/waitable/waitable.go | 14 ++++
pkg/tcpip/link/waitable/waitable_test.go | 9 +++
pkg/tcpip/network/ip_test.go | 5 ++
pkg/tcpip/stack/forwarder_test.go | 5 ++
pkg/tcpip/stack/nic.go | 30 ++++++--
pkg/tcpip/stack/nic_test.go | 5 ++
pkg/tcpip/stack/packet_buffer.go | 4 +
pkg/tcpip/stack/registration.go | 16 +++-
pkg/tcpip/tcpip.go | 25 +++++++
pkg/tcpip/transport/packet/endpoint.go | 52 +++++++++----
runsc/boot/BUILD | 1 +
runsc/boot/network.go | 4 +
test/syscalls/linux/packet_socket.cc | 116 +++++++++++++++++++++++++++++
29 files changed, 471 insertions(+), 66 deletions(-)
create mode 100644 pkg/tcpip/link/packetsocket/BUILD
create mode 100644 pkg/tcpip/link/packetsocket/endpoint.go
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 4a14ef691..95337c168 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -134,6 +134,15 @@ const (
SHUT_RDWR = 2
)
+// Packet types from
+const (
+ PACKET_HOST = 0 // To us
+ PACKET_BROADCAST = 1 // To all
+ PACKET_MULTICAST = 2 // To group
+ PACKET_OTHERHOST = 3 // To someone else
+ PACKET_OUTGOING = 4 // Outgoing of any type
+)
+
// Socket options from socket.h.
const (
SO_DEBUG = 1
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 49a04e613..964ec8414 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
@@ -2468,6 +2469,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -2526,6 +2544,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
switch v := addr.(type) {
case *linux.SockAddrLink:
v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a2bb773d4..e12a5929b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -302,3 +302,7 @@ func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 6aa1badc7..c18bb91fb 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -386,26 +386,33 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
var builder iovec.Builder
@@ -448,22 +455,8 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var eth header.Ethernet
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- eth = make(header.Ethernet, header.EthernetMinimumSize)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -493,7 +486,6 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
var builder iovec.Builder
builder.Add(vnetHdrBuf)
- builder.Add(eth)
builder.Add(pkt.Header.View())
for _, v := range pkt.Data.Views() {
builder.Add(v)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 4bad930c7..7b995b85a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -510,6 +514,10 @@ func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAdd
d.pkts = append(d.pkts, pkt)
}
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestDispatchPacketFormat(t *testing.T) {
for _, test := range []struct {
name string
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 3b17d8c28..781cdd317 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -118,3 +118,6 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c305d9e86..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -135,6 +135,10 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unsupported operation")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 328bd048e..d40de54df 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -61,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -135,3 +145,8 @@ func (e *Endpoint) GSOMaxSize() uint32 {
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
index c1a219f02..7d9249c1c 100644
--- a/pkg/tcpip/link/nested/nested_test.go
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd
d.count++
}
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNestedLinkEndpoint(t *testing.T) {
const emptyAddress = tcpip.LinkAddress("")
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that provides the ability
+// to loop outbound packets to any AF_PACKET sockets that may be interested in
+// the outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index c84fe1bb9..467083239 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -107,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
@@ -194,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267/): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -213,3 +220,8 @@ func (e *endpoint) Wait() {
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index a36862c67..507c76b76 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
v := pkt.Data.ToView()
// Transmit the packet.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 28a2e88ba..8f3cd9449 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index d9cd4e83a..509076643 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 47446efec..04ae58e59 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -272,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader)
}
// Append upper headers.
@@ -366,3 +354,30 @@ func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
}
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 24a8dc2eb..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -60,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -153,3 +162,8 @@ func (e *Endpoint) Wait() {}
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ffb2354be..c448a888f 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -40,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress,
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -90,6 +94,11 @@ func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a5b780ca2..615bae648 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -185,6 +185,11 @@ func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index eefb4b07f..bca1d940b 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -307,6 +307,11 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7b80534e6..fea0ce7e8 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 3bc9fd831..c477e31d8 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -89,6 +89,11 @@ func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index e3556d5d2..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -79,6 +79,10 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f260eeb7f..cd4b7a449 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -330,8 +330,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet(if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -342,6 +341,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -443,6 +452,9 @@ type LinkEndpoint interface {
// See:
// https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 48ad56d4d..ff14a3b3c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -316,6 +316,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -555,6 +577,9 @@ type Endpoint interface {
type LinkPacketInfo struct {
// Protocol is the NetworkProtocolNumber for the packet.
Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
}
// PacketEndpoint are additional methods that are only implemented by Packet
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 7b2083a09..8f167391f 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -441,6 +441,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(hdr.SourceAddress()),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
@@ -448,30 +449,53 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(localAddr),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header from the Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(ðFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(ðFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets the Link, Network and Transport
+ // headers are in the pkt.Header fields normally unless
+ // a Raw socket is in use in which case pkt.Header could
+ // be nil.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 55d45aaa6..9f52438c2 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -90,6 +90,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/packetsocket",
"//pkg/tcpip/link/qdisc/fifo",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 14d2f56a5..4e1fa7665 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket"
"gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
@@ -252,6 +253,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000)
}
+ // Enable support for AF_PACKET sockets to receive outgoing packets.
+ linkEP = packetsocket.New(linkEP)
+
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index e94ddcb77..40aa9326d 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -417,6 +417,122 @@ TEST_P(CookedPacketTest, BindDrop) {
EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
}
+// Verify that we receive outbound packets. This test requires at least one
+// non loopback interface so that we can actually capture an outgoing packet.
+TEST_P(CookedPacketTest, ReceiveOutbound) {
+ // Only ETH_P_ALL sockets can receive outbound packets on linux.
+ SKIP_IF(GetParam() != ETH_P_ALL);
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index and name.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+ int ifindex = ifr.ifr_ifindex;
+
+ constexpr int kMACSize = 6;
+ char hwaddr[kMACSize];
+ // Get interface address.
+ ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds());
+ ASSERT_THAT(ifr.ifr_hwaddr.sa_family,
+ AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER)));
+ memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize);
+
+ // Just send it to the google dns server 8.8.8.8. It's UDP we don't care
+ // if it actually gets to the DNS Server we just want to see that we receive
+ // it on our AF_PACKET socket.
+ //
+ // NOTE: We just want to pick an IP that is non-local to avoid having to
+ // handle ARP as this should cause the UDP packet to be sent to the default
+ // gateway configured for the system under test. Otherwise the only packet we
+ // will see is the ARP query unless we picked an IP which will actually
+ // resolve. The test is a bit brittle but this was the best compromise for
+ // now.
+ struct sockaddr_in dest = {};
+ ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket receives the data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1));
+
+ // Now read and check that the packet is the one we just sent.
+ // Read and verify the data.
+ constexpr size_t packet_size =
+ sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
+ char buf[64];
+ struct sockaddr_ll src = {};
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ reinterpret_cast(&src), &src_len),
+ SyscallSucceedsWithValue(packet_size));
+
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+
+ // Verify the source address.
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, ifindex);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING);
+ // Verify the link address of the interface matches that of the non
+ // non loopback interface address we stored above.
+ for (int i = 0; i < src.sll_halen; i++) {
+ EXPECT_EQ(src.sll_addr[i], hwaddr[i]);
+ }
+
+ // Verify the IP header.
+ struct iphdr ip = {};
+ memcpy(&ip, buf, sizeof(ip));
+ EXPECT_EQ(ip.ihl, 5);
+ EXPECT_EQ(ip.version, 4);
+ EXPECT_EQ(ip.tot_len, htons(packet_size));
+ EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr);
+ EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK));
+
+ // Verify the UDP header.
+ struct udphdr udp = {};
+ memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
+ EXPECT_EQ(udp.dest, kPort);
+ EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+ // Verify the payload.
+ char* payload = reinterpret_cast(buf + sizeof(iphdr) + sizeof(udphdr));
+ EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
// Bind with invalid address.
TEST_P(CookedPacketTest, BindFail) {
// Null address.
--
cgit v1.2.3
From 6f7f73996791bbab6b63c248000df0d3ce652f2b Mon Sep 17 00:00:00 2001
From: Ayush Ranjan
Date: Thu, 23 Jul 2020 11:43:40 -0700
Subject: Marshallable socket opitons.
Socket option values are now required to implement marshal.Marshallable.
Co-authored-by: Rahat Mahmood
PiperOrigin-RevId: 322831612
---
pkg/abi/linux/BUILD | 3 +
pkg/abi/linux/netfilter.go | 146 ++++++++++++++++++--
pkg/abi/linux/socket.go | 8 ++
pkg/sentry/socket/BUILD | 1 +
pkg/sentry/socket/hostinet/BUILD | 2 +
pkg/sentry/socket/hostinet/socket.go | 7 +-
pkg/sentry/socket/netfilter/netfilter.go | 28 ++--
pkg/sentry/socket/netlink/BUILD | 2 +
pkg/sentry/socket/netlink/socket.go | 14 +-
pkg/sentry/socket/netstack/BUILD | 2 +
pkg/sentry/socket/netstack/netstack.go | 205 ++++++++++++++++++----------
pkg/sentry/socket/netstack/netstack_vfs2.go | 16 ++-
pkg/sentry/socket/socket.go | 3 +-
pkg/sentry/socket/unix/BUILD | 1 +
pkg/sentry/socket/unix/unix.go | 3 +-
pkg/sentry/socket/unix/unix_vfs2.go | 3 +-
pkg/sentry/syscalls/linux/BUILD | 2 +
pkg/sentry/syscalls/linux/sys_socket.go | 17 ++-
pkg/sentry/syscalls/linux/vfs2/BUILD | 2 +
pkg/sentry/syscalls/linux/vfs2/socket.go | 17 ++-
tools/go_marshal/README.md | 8 +-
tools/go_marshal/primitive/primitive.go | 72 ++++++++++
22 files changed, 427 insertions(+), 135 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 2b789c4ec..a4bb62013 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -72,6 +72,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 46d8b0b42..a91f9f018 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -14,6 +14,14 @@
package linux
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
// This file contains structures required to support netfilter, specifically
// the iptables tool.
@@ -76,6 +84,8 @@ const (
// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTEntry struct {
// IP is used to filter packets based on the IP header.
IP IPTIP
@@ -112,21 +122,41 @@ type IPTEntry struct {
// SizeOfIPTEntry is the size of an IPTEntry.
const SizeOfIPTEntry = 112
-// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
-// struct marshaled via the binary package to write an IPTEntry to userspace.
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
+// KernelIPTEntry itself is not Marshallable but it implements some methods of
+// marshal.Marshallable that help in other implementations of Marshallable.
type KernelIPTEntry struct {
- IPTEntry
+ Entry IPTEntry
// Elems holds the data for all this rule's matches followed by the
// target. It is variable length -- users have to iterate over any
// matches and use TargetOffset and NextOffset to make sense of the
// data.
- Elems []byte
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTIP struct {
// Src is the source IP address.
Src InetAddr
@@ -189,6 +219,8 @@ const SizeOfIPTIP = 84
// XTCounters holds packet and byte counts for a rule. It corresponds to struct
// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTCounters struct {
// Pcnt is the packet count.
Pcnt uint64
@@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetinfo struct {
Name TableName
ValidHooks uint32
@@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetEntries struct {
Name TableName
Size uint32
@@ -350,13 +386,103 @@ type IPTGetEntries struct {
const SizeOfIPTGetEntries = 40
// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
-// Entrytable field. This struct marshaled via the binary package to write an
-// KernelIPTGetEntries to userspace.
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
type KernelIPTGetEntries struct {
IPTGetEntries
Entrytable []KernelIPTEntry
}
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIPTGetEntries) Packed() bool {
+ // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
+ // indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIPTGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type in not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type in not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
+
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
@@ -374,12 +500,6 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
-// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
-type KernelIPTReplace struct {
- IPTReplace
- Entries [0]IPTEntry
-}
-
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
@@ -392,6 +512,8 @@ func (en ExtensionName) String() string {
}
// TableName holds the name of a netfilter table.
+//
+// +marshal
type TableName [XT_TABLE_MAXNAMELEN]byte
// String implements fmt.Stringer.
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 95337c168..c24a8216e 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -234,6 +234,8 @@ const (
const SockAddrMax = 128
// InetAddr is struct in_addr, from uapi/linux/in.h.
+//
+// +marshal
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
@@ -303,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
// Linger is struct linger, from include/linux/socket.h.
+//
+// +marshal
type Linger struct {
OnOff int32
Linger int32
@@ -317,6 +321,8 @@ const SizeOfLinger = 8
// the end of this struct or within existing unusued space, so its size grows
// over time. The current iteration is based on linux v4.17. New versions are
// always backwards compatible.
+//
+// +marshal
type TCPInfo struct {
State uint8
CaState uint8
@@ -414,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
// ControlMessageCredentials represents struct ucred from linux/socket.h.
+//
+// +marshal
type ControlMessageCredentials struct {
PID int32
UID uint32
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c40c6d673..c0fd3425b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,5 +20,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index ff81ea6e6..e76e498de 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -40,6 +40,8 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index a92aed2c9..ec5506efc 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -36,6 +36,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
@@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
if err != nil {
return nil, syserr.FromError(err)
}
- return opt, nil
+ optP := primitive.ByteSlice(opt)
+ return &optP, nil
}
// SetSockOpt implements socket.Socket.SetSockOpt.
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 1243143ea..d9394055d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
// Each rule corresponds to an entry.
entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
+ Entry: linux.IPTEntry{
IP: linux.IPTIP{
Protocol: uint16(rule.Filter.Protocol),
},
@@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
TargetOffset: linux.SizeOfIPTEntry,
},
}
- copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
- copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
- copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src)
- copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask)
- copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
- copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
if rule.Filter.DstInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
}
if rule.Filter.SrcInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
}
if rule.Filter.OutputInterfaceInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
}
for _, matcher := range rule.Matchers {
@@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
}
// Serialize and append the target.
@@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
nflog("convert to binary: adding entry: %+v", entry)
- entries.Size += uint32(entry.NextOffset)
+ entries.Size += uint32(entry.Entry.NextOffset)
entries.Entrytable = append(entries.Entrytable, entry)
info.NumEntries++
}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index d5ca3ac56..0546801bf 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -36,6 +36,8 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 81f34c5a2..98ca7add0 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
@@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- return int32(s.sendBufferSize), nil
+ sendBufferSizeP := primitive.Int32(s.sendBufferSize)
+ return &sendBufferSizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have limit on receiving size.
- return int32(math.MaxInt32), nil
+ recvBufferSizeP := primitive.Int32(math.MaxInt32)
+ return &recvBufferSizeP, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var passcred int32
+ var passcred primitive.Int32
if s.Passcred() {
passcred = 1
}
- return passcred, nil
+ return &passcred, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ea6ebd0e2..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 964ec8414..9856ab8c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -62,6 +62,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -910,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -920,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -956,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -971,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -981,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -1014,7 +1016,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1025,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -1035,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -1050,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1066,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -1082,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
@@ -1093,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
@@ -1104,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1112,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
if v == 0 {
- return []byte{}, nil
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
@@ -1127,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// interface was removed.
return nil, syserr.ErrUnknownDevice
}
- return append([]byte(nic.Name), 0), nil
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1138,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
@@ -1149,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1163,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1171,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1183,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1194,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1203,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
@@ -1214,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(!v), nil
+
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
@@ -1225,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
@@ -1236,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1247,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1259,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1271,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
case linux.TCP_KEEPCNT:
if outLen < sizeOfInt32 {
@@ -1283,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
@@ -1295,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Millisecond), nil
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1309,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
@@ -1344,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
case linux.TCP_LINGER2:
if outLen < sizeOfInt32 {
@@ -1356,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
if outLen < sizeOfInt32 {
@@ -1368,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
case linux.TCP_SYNCNT:
if outLen < sizeOfInt32 {
@@ -1379,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_WINDOW_CLAMP:
if outLen < sizeOfInt32 {
@@ -1391,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1400,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1411,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1419,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is lesser than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
case linux.IPV6_RECVTCLASS:
if outLen < sizeOfInt32 {
@@ -1444,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1453,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1466,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
@@ -1482,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1496,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1507,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_RECVTOS:
if outLen < sizeOfInt32 {
@@ -1532,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1543,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIP(t, name)
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index d65a89316..a9025b0ec 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fcd7f9d7f..d112757fb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
@@ -86,7 +87,7 @@ type SocketOps interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cca5e70f1..061a689a9 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -35,5 +35,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4bb2b6ff4..0482d33cf 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -40,6 +40,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index ff2149250..05c16fcfe 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
@@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 217fcfef2..4a9b04fd0 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -99,5 +99,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0760af77b..414fce8e3 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 0c740335b..64696b438 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -72,5 +72,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 10b668477..8096a8f9c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
index 4886efddf..68d759083 100644
--- a/tools/go_marshal/README.md
+++ b/tools/go_marshal/README.md
@@ -9,11 +9,9 @@ automatically generating code to marshal go data structures to memory.
`binary.Marshal` by moving the go runtime reflection necessary to marshal a
struct to compile-time.
-`go_marshal` automatically generates implementations for `abi.Marshallable` and
-`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
-implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
-`safemem.Writer.WriteFromBlocks`. Data structures that require custom
-serialization will have manual implementations for these interfaces.
+`go_marshal` automatically generates implementations for `marshal.Marshallable`
+and `safemem.{Reader,Writer}`. Data structures that require custom serialization
+will have manual implementations for these interfaces.
Data structures can be flagged for code generation by adding a struct-level
comment `// +marshal`.
diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go
index ebcf130ae..d93edda8b 100644
--- a/tools/go_marshal/primitive/primitive.go
+++ b/tools/go_marshal/primitive/primitive.go
@@ -17,10 +17,22 @@
package primitive
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/tools/go_marshal/marshal"
)
+// Int8 is a marshal.Marshallable implementation for int8.
+//
+// +marshal slice:Int8Slice:inner
+type Int8 int8
+
+// Uint8 is a marshal.Marshallable implementation for uint8.
+//
+// +marshal slice:Uint8Slice:inner
+type Uint8 uint8
+
// Int16 is a marshal.Marshallable implementation for int16.
//
// +marshal slice:Int16Slice:inner
@@ -51,6 +63,66 @@ type Int64 int64
// +marshal slice:Uint64Slice:inner
type Uint64 uint64
+// ByteSlice is a marshal.Marshallable implementation for []byte.
+// This is a convenience wrapper around a dynamically sized type, and can't be
+// embedded in other marshallable types because it breaks assumptions made by
+// go-marshal internals. It violates the "no dynamically-sized types"
+// constraint of the go-marshal library.
+type ByteSlice []byte
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (b *ByteSlice) SizeBytes() int {
+ return len(*b)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (b *ByteSlice) MarshalBytes(dst []byte) {
+ copy(dst, *b)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (b *ByteSlice) UnmarshalBytes(src []byte) {
+ copy(*b, src)
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (b *ByteSlice) Packed() bool {
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (b *ByteSlice) MarshalUnsafe(dst []byte) {
+ b.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (b *ByteSlice) UnmarshalUnsafe(src []byte) {
+ b.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ return task.CopyInBytes(addr, *b)
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ return task.CopyOutBytes(addr, *b)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ return task.CopyOutBytes(addr, (*b)[:limit])
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) {
+ n, err := w.Write(*b)
+ return int64(n), err
+}
+
+var _ marshal.Marshallable = (*ByteSlice)(nil)
+
// Below, we define some convenience functions for marshalling primitive types
// using the newtypes above, without requiring superfluous casts.
--
cgit v1.2.3
From e2c70ee9814f0f76ab5c30478748e4c697e91f33 Mon Sep 17 00:00:00 2001
From: Ayush Ranjan
Date: Fri, 24 Jul 2020 01:24:16 -0700
Subject: Enable automated marshalling for netstack.
PiperOrigin-RevId: 322954792
---
pkg/abi/linux/netdevice.go | 4 +++
pkg/sentry/fs/file_operations.go | 1 +
pkg/sentry/socket/netfilter/netfilter.go | 4 +--
pkg/sentry/socket/netstack/netstack.go | 54 ++++++++++++++------------------
4 files changed, 31 insertions(+), 32 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
index 7866352b4..0faf015c7 100644
--- a/pkg/abi/linux/netdevice.go
+++ b/pkg/abi/linux/netdevice.go
@@ -22,6 +22,8 @@ const (
)
// IFReq is an interface request.
+//
+// +marshal
type IFReq struct {
// IFName is an encoded name, normally null-terminated. This should be
// accessed via the Name and SetName functions.
@@ -79,6 +81,8 @@ type IFMap struct {
// IFConf is used to return a list of interfaces and their addresses. See
// netdevice(7) and struct ifconf for more detail on its use.
+//
+// +marshal
type IFConf struct {
Len int32
_ [4]byte // Pad to sizeof(struct ifconf).
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index beba0f771..f5537411e 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -160,6 +160,7 @@ type FileOperations interface {
// refer.
//
// Preconditions: The AddressSpace (if any) that io refers to is activated.
+ // Must only be called from a task goroutine.
Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error)
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index d9394055d..a9f0604ae 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -66,7 +66,7 @@ func nflog(format string, args ...interface{}) {
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
- if _, err := t.CopyIn(outPtr, &info); err != nil {
+ if _, err := info.CopyIn(t, outPtr); err != nil {
return linux.IPTGetinfo{}, syserr.FromError(err)
}
@@ -84,7 +84,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
- if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ if _, err := userEntries.CopyIn(t, outPtr); err != nil {
nflog("couldn't copy in entries %q", userEntries.Name)
return linux.KernelIPTGetEntries{}, syserr.FromError(err)
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9856ab8c5..44b3fff46 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2835,6 +2835,11 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
// SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
@@ -2847,9 +2852,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
}
tv := linux.NsecToTimeval(s.timestampNS)
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2868,9 +2871,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
}
@@ -2879,6 +2881,11 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// Ioctl performs a socket ioctl.
func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
switch arg := int(args[1].Int()); arg {
case linux.SIOCGIFFLAGS,
linux.SIOCGIFADDR,
@@ -2895,37 +2902,28 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
linux.SIOCETHTOOL:
var ifr linux.IFReq
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
return 0, err.ToError()
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := ifr.CopyOut(t, args[2].Pointer())
return 0, err
case linux.SIOCGIFCONF:
// Return a list of interface addresses or the buffer size
// necessary to hold the list.
var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
- if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ if err := ifconfIoctl(ctx, t, io, &ifc); err != nil {
return 0, err
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
+ _, err := ifc.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2938,9 +2936,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
v = math.MaxInt32
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCOUTQ:
@@ -2954,9 +2951,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
@@ -3105,7 +3101,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
@@ -3142,9 +3138,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
// Copy the ifr to userspace.
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil {
return err
}
}
--
cgit v1.2.3
From f82dd8ddb477b8923d1db12654a62d55d613019c Mon Sep 17 00:00:00 2001
From: Fabricio Voznika
Date: Tue, 28 Jul 2020 21:22:52 -0700
Subject: Redirect TODO to GitHub issues
PiperOrigin-RevId: 323715260
---
pkg/sentry/socket/netstack/netstack.go | 4 ++--
pkg/tcpip/stack/nic.go | 2 +-
pkg/tcpip/stack/stack_test.go | 6 +++---
pkg/tcpip/transport/packet/endpoint.go | 4 ++--
test/syscalls/linux/packet_socket.cc | 4 ++--
test/syscalls/linux/packet_socket_raw.cc | 4 ++--
6 files changed, 12 insertions(+), 12 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 44b3fff46..f86e6cd7a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -423,7 +423,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- // TODO(b/129292371): Return protocol too.
+ // TODO(gvisor.dev/issue/173): Return protocol too.
return tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
@@ -2418,7 +2418,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(sockAddrInet6Size)
case linux.AF_PACKET:
- // TODO(b/129292371): Return protocol too.
+ // TODO(gvisor.dev/issue/173): Return protocol too.
var out linux.SockAddrLink
out.Family = linux.AF_PACKET
out.InterfaceIndex = int32(addr.NIC)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index fea0ce7e8..9256d4d43 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -181,7 +181,7 @@ func (n *NIC) disableLocked() *tcpip.Error {
return nil
}
- // TODO(b/147015577): Should Routes that are currently bound to n be
+ // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
// invalidated? Currently, Routes will continue to work when a NIC is enabled
// again, and applications may not know that the underlying NIC was ever
// disabled.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7657a4101..101ca2206 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -867,9 +867,9 @@ func TestRouteWithDownNIC(t *testing.T) {
// Writes with Routes that use NIC1 after being brought up should
// succeed.
//
- // TODO(b/147015577): Should we instead completely invalidate all
- // Routes that were bound to a NIC that was brought down at some
- // point?
+ // TODO(gvisor.dev/issue/1491): Should we instead completely
+ // invalidate all Routes that were bound to a NIC that was brought
+ // down at some point?
if err := upFn(s, nicID1); err != nil {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0e46e6355..df478115d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -193,7 +193,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- // TODO(b/129292371): Implement.
+ // TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -432,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(b/129292371): Return network protocol.
+ // TODO(gvisor.dev/issue/173): Return network protocol.
if len(pkt.LinkHeader) > 0 {
// Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader)
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 40aa9326d..861617ff7 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -188,7 +188,7 @@ void ReceiveMessage(int sock, int ifindex) {
// sizeof(sockaddr_ll).
ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
- // TODO(b/129292371): Verify protocol once we return it.
+ // TODO(gvisor.dev/issue/173): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
EXPECT_EQ(src.sll_ifindex, ifindex);
@@ -234,7 +234,7 @@ TEST_P(CookedPacketTest, Receive) {
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
- // TODO(b/129292371): Remove once we support packet socket writing.
+ // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing.
SKIP_IF(IsRunningOnGvisor());
// Let's send a UDP packet and receive it using a regular UDP socket.
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 2fca9fe4d..a11a03415 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -195,7 +195,7 @@ TEST_P(RawPacketTest, Receive) {
// sizeof(sockaddr_ll).
ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
- // TODO(b/129292371): Verify protocol once we return it.
+ // TODO(gvisor.dev/issue/173): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
@@ -240,7 +240,7 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
- // TODO(b/129292371): Remove once we support packet socket writing.
+ // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing.
SKIP_IF(IsRunningOnGvisor());
// Let's send a UDP packet and receive it using a regular UDP socket.
--
cgit v1.2.3
From 2a7b2a61e3ea32129c26eeaa6fab3d81a5d8ad6e Mon Sep 17 00:00:00 2001
From: Kevin Krakauer
Date: Thu, 11 Jun 2020 20:33:56 -0700
Subject: iptables: support SO_ORIGINAL_DST
Envoy (#170) uses this to get the original destination of redirected
packets.
---
pkg/abi/linux/netfilter.go | 8 +-
pkg/abi/linux/socket.go | 4 +-
pkg/sentry/socket/netstack/netstack.go | 17 ++++
pkg/sentry/strace/socket.go | 1 +
pkg/tcpip/stack/conntrack.go | 26 ++++++
pkg/tcpip/stack/iptables.go | 11 ++-
pkg/tcpip/tcpip.go | 4 +
pkg/tcpip/transport/tcp/endpoint.go | 11 +++
test/iptables/BUILD | 1 +
test/iptables/iptables_test.go | 8 ++
test/iptables/iptables_unsafe.go | 63 ++++++++++++++
test/iptables/iptables_util.go | 51 ++++++++++-
test/iptables/nat.go | 152 ++++++++++++++++++++++++++++++++-
13 files changed, 344 insertions(+), 13 deletions(-)
create mode 100644 test/iptables/iptables_unsafe.go
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index a91f9f018..9c27f7bb2 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -59,7 +59,7 @@ var VerdictStrings = map[int32]string{
NF_RETURN: "RETURN",
}
-// Socket options. These correspond to values in
+// Socket options for SOL_SOCKET. These correspond to values in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
const (
IPT_BASE_CTL = 64
@@ -74,6 +74,12 @@ const (
IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET
)
+// Socket option for SOL_IP. This corresponds to the value in
+// include/uapi/linux/netfilter_ipv4.h.
+const (
+ SO_ORIGINAL_DST = 80
+)
+
// Name lengths. These correspond to values in
// include/uapi/linux/netfilter/x_tables.h.
const (
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index c24a8216e..d6946bb82 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -239,11 +239,13 @@ const SockAddrMax = 128
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
+//
+// +marshal
type SockAddrInet struct {
Family uint16
Port uint16
Addr InetAddr
- Zero [8]uint8 // pad to sizeof(struct sockaddr).
+ _ [8]uint8 // pad to sizeof(struct sockaddr).
}
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f86e6cd7a..31a168f7e 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1490,6 +1490,10 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
+ case linux.SO_ORIGINAL_DST:
+ // TODO(gvisor.dev/issue/170): ip6tables.
+ return nil, syserr.ErrInvalidArgument
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1600,6 +1604,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
+ case linux.SO_ORIGINAL_DST:
+ if outLen < int(binary.Size(linux.SockAddrInet{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet), nil
+
default:
emitUnimplementedEventIP(t, name)
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index c0512de89..b51c4c941 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -521,6 +521,7 @@ var sockOptNames = map[uint64]abi.ValueSet{
linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT",
linux.IP_PKTOPTIONS: "IP_PKTOPTIONS",
linux.IP_MTU: "IP_MTU",
+ linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST",
},
linux.SOL_SOCKET: {
linux.SO_ERROR: "SO_ERROR",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 559a1c4dd..470c265aa 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -240,7 +240,10 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
if err != nil {
return nil, dirOriginal
}
+ return ct.connForTID(tid)
+}
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
bucket := ct.bucket(tid)
now := time.Now()
@@ -604,3 +607,26 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
+
+func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ // Lookup the connection. The reply's original destination
+ // describes the original address.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: header.IPv4ProtocolNumber,
+ }
+ conn, _ := ct.connForTID(tid)
+ if conn == nil {
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
+ }
+
+ return conn.original.dstAddr, conn.original.dstPort, nil
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index cbbae4224..110ba073d 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -218,19 +218,16 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
it.mu.RLock()
+ defer it.mu.RUnlock()
if !it.modified {
- it.mu.RUnlock()
return true
}
- it.mu.RUnlock()
// Packets are manipulated only if connection and matching
// NAT rule exists.
shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- it.mu.RLock()
- defer it.mu.RUnlock()
priorities := it.priorities[hook]
for _, tableID := range priorities {
// If handlePacket already NATed the packet, we don't need to
@@ -418,3 +415,9 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// All the matchers matched, so run the target.
return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
}
+
+// OriginalDst returns the original destination of redirected connections. It
+// returns an error if the connection doesn't exist or isn't redirected.
+func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ return it.connections.originalDst(epID)
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index a634b9b60..45f59b60f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -954,6 +954,10 @@ type DefaultTTLOption uint8
// classic BPF filter on a given endpoint.
type SocketDetachFilterOption int
+// OriginalDestinationOption is used to get the original destination address
+// and port of a redirected packet.
+type OriginalDestinationOption FullAddress
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 0f7487963..682687ebe 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2017,6 +2017,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
e.UnlockUser()
+ case *tcpip.OriginalDestinationOption:
+ ipt := e.stack.IPTables()
+ addr, port, err := ipt.OriginalDst(e.ID)
+ if err != nil {
+ return err
+ }
+ *o = tcpip.OriginalDestinationOption{
+ Addr: addr,
+ Port: port,
+ }
+
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
index 40b63ebbe..66453772a 100644
--- a/test/iptables/BUILD
+++ b/test/iptables/BUILD
@@ -9,6 +9,7 @@ go_library(
"filter_input.go",
"filter_output.go",
"iptables.go",
+ "iptables_unsafe.go",
"iptables_util.go",
"nat.go",
],
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 550b6198a..fda5f694f 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -371,3 +371,11 @@ func TestFilterAddrs(t *testing.T) {
}
}
}
+
+func TestNATPreOriginalDst(t *testing.T) {
+ singleTest(t, NATPreOriginalDst{})
+}
+
+func TestNATOutOriginalDst(t *testing.T) {
+ singleTest(t, NATOutOriginalDst{})
+}
diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go
new file mode 100644
index 000000000..bd85a8fea
--- /dev/null
+++ b/test/iptables/iptables_unsafe.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+)
+
+type originalDstError struct {
+ errno syscall.Errno
+}
+
+func (e originalDstError) Error() string {
+ return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error())
+}
+
+// SO_ORIGINAL_DST gets the original destination of a redirected packet via
+// getsockopt.
+const SO_ORIGINAL_DST = 80
+
+func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) {
+ var addr syscall.RawSockaddrInet4
+ var addrLen uint32 = syscall.SizeofSockaddrInet4
+ if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 {
+ return syscall.RawSockaddrInet4{}, originalDstError{errno}
+ }
+ return addr, nil
+}
+
+func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) {
+ var addr syscall.RawSockaddrInet6
+ var addrLen uint32 = syscall.SizeofSockaddrInet6
+ if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 {
+ return syscall.RawSockaddrInet6{}, originalDstError{errno}
+ }
+ return addr, nil
+}
+
+func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_GETSOCKOPT,
+ uintptr(connfd),
+ level,
+ SO_ORIGINAL_DST,
+ uintptr(optval),
+ uintptr(unsafe.Pointer(optlen)),
+ 0)
+ return errno
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index ca80a4b5f..5125fe47b 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -15,6 +15,8 @@
package iptables
import (
+ "encoding/binary"
+ "errors"
"fmt"
"net"
"os/exec"
@@ -218,17 +220,58 @@ func filterAddrs(addrs []string, ipv6 bool) []string {
// getInterfaceName returns the name of the interface other than loopback.
func getInterfaceName() (string, bool) {
- var ifname string
+ iface, ok := getNonLoopbackInterface()
+ if !ok {
+ return "", false
+ }
+ return iface.Name, true
+}
+
+func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) {
+ iface, ok := getNonLoopbackInterface()
+ if !ok {
+ return nil, errors.New("no non-loopback interface found")
+ }
+ addrs, err := iface.Addrs()
+ if err != nil {
+ return nil, err
+ }
+
+ // Get only IPv4 or IPv6 addresses.
+ ips := make([]net.IP, 0, len(addrs))
+ for _, addr := range addrs {
+ parts := strings.Split(addr.String(), "/")
+ var ip net.IP
+ // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses.
+ // So we check whether To4() returns nil to test whether the
+ // address is v4 or v6.
+ if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil {
+ ip = net.ParseIP(parts[0]).To16()
+ } else {
+ ip = v4
+ }
+ if ip != nil {
+ ips = append(ips, ip)
+ }
+ }
+ return ips, nil
+}
+
+func getNonLoopbackInterface() (net.Interface, bool) {
if interfaces, err := net.Interfaces(); err == nil {
for _, intf := range interfaces {
if intf.Name != "lo" {
- ifname = intf.Name
- break
+ return intf, true
}
}
}
+ return net.Interface{}, false
+}
- return ifname, ifname != ""
+func htons(x uint16) uint16 {
+ buf := make([]byte, 2)
+ binary.BigEndian.PutUint16(buf, x)
+ return binary.LittleEndian.Uint16(buf)
}
func localIP(ipv6 bool) string {
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index ac0d91bb2..b7fea2527 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -18,12 +18,11 @@ import (
"errors"
"fmt"
"net"
+ "syscall"
"time"
)
-const (
- redirectPort = 42
-)
+const redirectPort = 42
func init() {
RegisterTestCase(NATPreRedirectUDPPort{})
@@ -42,6 +41,8 @@ func init() {
RegisterTestCase(NATOutRedirectInvert{})
RegisterTestCase(NATRedirectRequiresProtocol{})
RegisterTestCase(NATLoopbackSkipsPrerouting{})
+ RegisterTestCase(NATPreOriginalDst{})
+ RegisterTestCase(NATOutOriginalDst{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -471,6 +472,151 @@ func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error {
return nil
}
+// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination
+// of PREROUTING NATted packets.
+type NATPreOriginalDst struct{}
+
+// Name implements TestCase.Name.
+func (NATPreOriginalDst) Name() string {
+ return "NATPreOriginalDst"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error {
+ // Redirect incoming TCP connections to acceptPort.
+ if err := natTable(ipv6, "-A", "PREROUTING",
+ "-p", "tcp",
+ "--destination-port", fmt.Sprintf("%d", dropPort),
+ "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ addrs, err := getInterfaceAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+ return listenForRedirectedConn(ipv6, addrs)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error {
+ return connectTCP(ip, dropPort, sendloopDuration)
+}
+
+// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination
+// of OUTBOUND NATted packets.
+type NATOutOriginalDst struct{}
+
+// Name implements TestCase.Name.
+func (NATOutOriginalDst) Name() string {
+ return "NATOutOriginalDst"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error {
+ // Redirect incoming TCP connections to acceptPort.
+ if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ connCh := make(chan error)
+ go func() {
+ connCh <- connectTCP(ip, dropPort, sendloopDuration)
+ }()
+
+ if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil {
+ return err
+ }
+ return <-connCh
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error {
+ // The net package doesn't give guarantee access to the connection's
+ // underlying FD, and thus we cannot call getsockopt. We have to use
+ // traditional syscalls for SO_ORIGINAL_DST.
+
+ // Create the listening socket, bind, listen, and accept.
+ family := syscall.AF_INET
+ if ipv6 {
+ family = syscall.AF_INET6
+ }
+ sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ return err
+ }
+ defer syscall.Close(sockfd)
+
+ var bindAddr syscall.Sockaddr
+ if ipv6 {
+ bindAddr = &syscall.SockaddrInet6{
+ Port: acceptPort,
+ Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any
+ }
+ } else {
+ bindAddr = &syscall.SockaddrInet4{
+ Port: acceptPort,
+ Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY
+ }
+ }
+ if err := syscall.Bind(sockfd, bindAddr); err != nil {
+ return err
+ }
+
+ if err := syscall.Listen(sockfd, 1); err != nil {
+ return err
+ }
+
+ connfd, _, err := syscall.Accept(sockfd)
+ if err != nil {
+ return err
+ }
+ defer syscall.Close(connfd)
+
+ // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST
+ // indicates the packet was sent to originalDst:dropPort.
+ if ipv6 {
+ got, err := originalDestination6(connfd)
+ if err != nil {
+ return err
+ }
+ // The original destination could be any of our IPs.
+ for _, dst := range originalDsts {
+ want := syscall.RawSockaddrInet6{
+ Family: syscall.AF_INET6,
+ Port: htons(dropPort),
+ }
+ copy(want.Addr[:], dst.To16())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ } else {
+ got, err := originalDestination4(connfd)
+ if err != nil {
+ return err
+ }
+ // The original destination could be any of our IPs.
+ for _, dst := range originalDsts {
+ want := syscall.RawSockaddrInet4{
+ Family: syscall.AF_INET,
+ Port: htons(dropPort),
+ }
+ copy(want.Addr[:], dst.To4())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ }
+}
+
// loopbackTests runs an iptables rule and ensures that packets sent to
// dest:dropPort are received by localhost:acceptPort.
func loopbackTest(ipv6 bool, dest net.IP, args ...string) error {
--
cgit v1.2.3
From b2ae7ea1bb207eddadd7962080e7bd0b8634db96 Mon Sep 17 00:00:00 2001
From: Nayana Bidari
Date: Mon, 3 Aug 2020 13:33:47 -0700
Subject: Plumbing context.Context to DecRef() and Release().
context is passed to DecRef() and Release() which is
needed for SO_LINGER implementation.
PiperOrigin-RevId: 324672584
---
pkg/refs/BUILD | 6 +-
pkg/refs/refcounter.go | 19 +--
pkg/refs/refcounter_test.go | 38 ++---
pkg/sentry/control/proc.go | 2 +-
pkg/sentry/devices/memdev/full.go | 2 +-
pkg/sentry/devices/memdev/null.go | 2 +-
pkg/sentry/devices/memdev/random.go | 2 +-
pkg/sentry/devices/memdev/zero.go | 2 +-
pkg/sentry/devices/ttydev/ttydev.go | 2 +-
pkg/sentry/devices/tundev/tundev.go | 4 +-
pkg/sentry/fdimport/fdimport.go | 8 +-
pkg/sentry/fs/copy_up.go | 12 +-
pkg/sentry/fs/copy_up_test.go | 4 +-
pkg/sentry/fs/dev/net_tun.go | 4 +-
pkg/sentry/fs/dirent.go | 110 +++++++-------
pkg/sentry/fs/dirent_cache.go | 3 +-
pkg/sentry/fs/dirent_refs_test.go | 16 +--
pkg/sentry/fs/dirent_state.go | 3 +-
pkg/sentry/fs/fdpipe/pipe.go | 2 +-
pkg/sentry/fs/fdpipe/pipe_opener_test.go | 16 +--
pkg/sentry/fs/fdpipe/pipe_test.go | 18 +--
pkg/sentry/fs/file.go | 10 +-
pkg/sentry/fs/file_operations.go | 2 +-
pkg/sentry/fs/file_overlay.go | 22 +--
pkg/sentry/fs/fsutil/file.go | 4 +-
pkg/sentry/fs/gofer/file.go | 4 +-
pkg/sentry/fs/gofer/gofer_test.go | 8 +-
pkg/sentry/fs/gofer/handles.go | 5 +-
pkg/sentry/fs/gofer/inode.go | 5 +-
pkg/sentry/fs/gofer/path.go | 6 +-
pkg/sentry/fs/gofer/session.go | 16 +--
pkg/sentry/fs/gofer/session_state.go | 3 +-
pkg/sentry/fs/gofer/socket.go | 6 +-
pkg/sentry/fs/host/control.go | 2 +-
pkg/sentry/fs/host/file.go | 4 +-
pkg/sentry/fs/host/inode_test.go | 2 +-
pkg/sentry/fs/host/socket.go | 10 +-
pkg/sentry/fs/host/socket_test.go | 38 ++---
pkg/sentry/fs/host/tty.go | 4 +-
pkg/sentry/fs/host/wait_test.go | 2 +-
pkg/sentry/fs/inode.go | 11 +-
pkg/sentry/fs/inode_inotify.go | 5 +-
pkg/sentry/fs/inode_overlay.go | 30 ++--
pkg/sentry/fs/inode_overlay_test.go | 8 +-
pkg/sentry/fs/inotify.go | 8 +-
pkg/sentry/fs/inotify_watch.go | 9 +-
pkg/sentry/fs/mount.go | 12 +-
pkg/sentry/fs/mount_overlay.go | 6 +-
pkg/sentry/fs/mount_test.go | 29 ++--
pkg/sentry/fs/mounts.go | 30 ++--
pkg/sentry/fs/mounts_test.go | 2 +-
pkg/sentry/fs/overlay.go | 10 +-
pkg/sentry/fs/proc/fds.go | 18 +--
pkg/sentry/fs/proc/mounts.go | 8 +-
pkg/sentry/fs/proc/net.go | 12 +-
pkg/sentry/fs/proc/proc.go | 2 +-
pkg/sentry/fs/proc/task.go | 4 +-
pkg/sentry/fs/ramfs/dir.go | 18 +--
pkg/sentry/fs/ramfs/tree_test.go | 2 +-
pkg/sentry/fs/timerfd/timerfd.go | 4 +-
pkg/sentry/fs/tmpfs/file_test.go | 2 +-
pkg/sentry/fs/tty/dir.go | 8 +-
pkg/sentry/fs/tty/fs.go | 2 +-
pkg/sentry/fs/tty/master.go | 8 +-
pkg/sentry/fs/tty/slave.go | 4 +-
pkg/sentry/fs/user/path.go | 8 +-
pkg/sentry/fs/user/user.go | 8 +-
pkg/sentry/fs/user/user_test.go | 8 +-
pkg/sentry/fsbridge/bridge.go | 2 +-
pkg/sentry/fsbridge/fs.go | 8 +-
pkg/sentry/fsbridge/vfs.go | 6 +-
pkg/sentry/fsimpl/devpts/devpts.go | 4 +-
pkg/sentry/fsimpl/devpts/master.go | 6 +-
pkg/sentry/fsimpl/devpts/slave.go | 6 +-
pkg/sentry/fsimpl/devtmpfs/devtmpfs.go | 6 +-
pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go | 8 +-
pkg/sentry/fsimpl/eventfd/eventfd.go | 6 +-
pkg/sentry/fsimpl/eventfd/eventfd_test.go | 12 +-
pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go | 6 +-
pkg/sentry/fsimpl/ext/dentry.go | 7 +-
pkg/sentry/fsimpl/ext/directory.go | 2 +-
pkg/sentry/fsimpl/ext/ext.go | 10 +-
pkg/sentry/fsimpl/ext/ext_test.go | 4 +-
pkg/sentry/fsimpl/ext/filesystem.go | 68 ++++-----
pkg/sentry/fsimpl/ext/regular_file.go | 2 +-
pkg/sentry/fsimpl/ext/symlink.go | 2 +-
pkg/sentry/fsimpl/fuse/dev.go | 2 +-
pkg/sentry/fsimpl/fuse/dev_test.go | 4 +-
pkg/sentry/fsimpl/fuse/fusefs.go | 4 +-
pkg/sentry/fsimpl/gofer/directory.go | 4 +-
pkg/sentry/fsimpl/gofer/filesystem.go | 90 ++++++------
pkg/sentry/fsimpl/gofer/gofer.go | 40 +++---
pkg/sentry/fsimpl/gofer/gofer_test.go | 6 +-
pkg/sentry/fsimpl/gofer/regular_file.go | 2 +-
pkg/sentry/fsimpl/gofer/socket.go | 6 +-
pkg/sentry/fsimpl/gofer/special_file.go | 4 +-
pkg/sentry/fsimpl/host/control.go | 2 +-
pkg/sentry/fsimpl/host/host.go | 14 +-
pkg/sentry/fsimpl/host/socket.go | 12 +-
pkg/sentry/fsimpl/host/tty.go | 4 +-
pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 2 +-
pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 2 +-
pkg/sentry/fsimpl/kernfs/filesystem.go | 68 ++++-----
pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 10 +-
pkg/sentry/fsimpl/kernfs/kernfs.go | 26 ++--
pkg/sentry/fsimpl/kernfs/kernfs_test.go | 18 +--
pkg/sentry/fsimpl/overlay/copy_up.go | 8 +-
pkg/sentry/fsimpl/overlay/directory.go | 8 +-
pkg/sentry/fsimpl/overlay/filesystem.go | 64 ++++-----
pkg/sentry/fsimpl/overlay/non_directory.go | 18 +--
pkg/sentry/fsimpl/overlay/overlay.go | 48 +++----
pkg/sentry/fsimpl/pipefs/pipefs.go | 6 +-
pkg/sentry/fsimpl/proc/filesystem.go | 4 +-
pkg/sentry/fsimpl/proc/task_fds.go | 18 +--
pkg/sentry/fsimpl/proc/task_files.go | 14 +-
pkg/sentry/fsimpl/proc/task_net.go | 12 +-
pkg/sentry/fsimpl/proc/tasks_test.go | 10 +-
pkg/sentry/fsimpl/signalfd/signalfd.go | 4 +-
pkg/sentry/fsimpl/sockfs/sockfs.go | 4 +-
pkg/sentry/fsimpl/sys/sys.go | 4 +-
pkg/sentry/fsimpl/sys/sys_test.go | 2 +-
pkg/sentry/fsimpl/testutil/kernel.go | 2 +-
pkg/sentry/fsimpl/testutil/testutil.go | 6 +-
pkg/sentry/fsimpl/timerfd/timerfd.go | 6 +-
pkg/sentry/fsimpl/tmpfs/benchmark_test.go | 50 +++----
pkg/sentry/fsimpl/tmpfs/directory.go | 4 +-
pkg/sentry/fsimpl/tmpfs/filesystem.go | 106 +++++++-------
pkg/sentry/fsimpl/tmpfs/pipe_test.go | 20 +--
pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 +-
pkg/sentry/fsimpl/tmpfs/tmpfs.go | 34 ++---
pkg/sentry/fsimpl/tmpfs/tmpfs_test.go | 6 +-
pkg/sentry/kernel/abstract_socket_namespace.go | 13 +-
pkg/sentry/kernel/epoll/epoll.go | 14 +-
pkg/sentry/kernel/epoll/epoll_test.go | 5 +-
pkg/sentry/kernel/eventfd/eventfd.go | 4 +-
pkg/sentry/kernel/fd_table.go | 55 +++----
pkg/sentry/kernel/fd_table_test.go | 6 +-
pkg/sentry/kernel/fs_context.go | 31 ++--
pkg/sentry/kernel/futex/BUILD | 1 +
pkg/sentry/kernel/futex/futex.go | 35 ++---
pkg/sentry/kernel/futex/futex_test.go | 66 +++++----
pkg/sentry/kernel/kernel.go | 57 ++++----
pkg/sentry/kernel/pipe/node.go | 6 +-
pkg/sentry/kernel/pipe/node_test.go | 2 +-
pkg/sentry/kernel/pipe/pipe.go | 2 +-
pkg/sentry/kernel/pipe/pipe_test.go | 16 +--
pkg/sentry/kernel/pipe/pipe_util.go | 2 +-
pkg/sentry/kernel/pipe/reader.go | 3 +-
pkg/sentry/kernel/pipe/vfs.go | 8 +-
pkg/sentry/kernel/pipe/writer.go | 3 +-
pkg/sentry/kernel/sessions.go | 5 +-
pkg/sentry/kernel/shm/shm.go | 10 +-
pkg/sentry/kernel/signalfd/signalfd.go | 2 +-
pkg/sentry/kernel/task.go | 8 +-
pkg/sentry/kernel/task_clone.go | 10 +-
pkg/sentry/kernel/task_exec.go | 6 +-
pkg/sentry/kernel/task_exit.go | 8 +-
pkg/sentry/kernel/task_log.go | 2 +-
pkg/sentry/kernel/task_start.go | 6 +-
pkg/sentry/kernel/thread_group.go | 4 +-
pkg/sentry/loader/elf.go | 4 +-
pkg/sentry/loader/loader.go | 6 +-
pkg/sentry/memmap/memmap.go | 2 +-
pkg/sentry/mm/aio_context.go | 6 +-
pkg/sentry/mm/lifecycle.go | 2 +-
pkg/sentry/mm/metadata.go | 5 +-
pkg/sentry/mm/special_mappable.go | 4 +-
pkg/sentry/mm/syscalls.go | 4 +-
pkg/sentry/mm/vma.go | 4 +-
pkg/sentry/socket/control/control.go | 8 +-
pkg/sentry/socket/control/control_vfs2.go | 8 +-
pkg/sentry/socket/hostinet/socket.go | 8 +-
pkg/sentry/socket/netlink/provider.go | 2 +-
pkg/sentry/socket/netlink/socket.go | 14 +-
pkg/sentry/socket/netlink/socket_vfs2.go | 4 +-
pkg/sentry/socket/netstack/netstack.go | 6 +-
pkg/sentry/socket/netstack/netstack_vfs2.go | 2 +-
pkg/sentry/socket/socket.go | 4 +-
pkg/sentry/socket/unix/transport/connectioned.go | 10 +-
pkg/sentry/socket/unix/transport/connectionless.go | 12 +-
pkg/sentry/socket/unix/transport/queue.go | 13 +-
pkg/sentry/socket/unix/transport/unix.go | 48 +++----
pkg/sentry/socket/unix/unix.go | 38 ++---
pkg/sentry/socket/unix/unix_vfs2.go | 18 +--
pkg/sentry/strace/strace.go | 12 +-
pkg/sentry/syscalls/epoll.go | 18 +--
pkg/sentry/syscalls/linux/sys_aio.go | 8 +-
pkg/sentry/syscalls/linux/sys_eventfd.go | 2 +-
pkg/sentry/syscalls/linux/sys_file.go | 78 +++++-----
pkg/sentry/syscalls/linux/sys_futex.go | 6 +-
pkg/sentry/syscalls/linux/sys_getdents.go | 2 +-
pkg/sentry/syscalls/linux/sys_inotify.go | 10 +-
pkg/sentry/syscalls/linux/sys_lseek.go | 2 +-
pkg/sentry/syscalls/linux/sys_mmap.go | 4 +-
pkg/sentry/syscalls/linux/sys_mount.go | 2 +-
pkg/sentry/syscalls/linux/sys_pipe.go | 6 +-
pkg/sentry/syscalls/linux/sys_poll.go | 10 +-
pkg/sentry/syscalls/linux/sys_prctl.go | 4 +-
pkg/sentry/syscalls/linux/sys_read.go | 12 +-
pkg/sentry/syscalls/linux/sys_shm.go | 10 +-
pkg/sentry/syscalls/linux/sys_signal.go | 4 +-
pkg/sentry/syscalls/linux/sys_socket.go | 46 +++---
pkg/sentry/syscalls/linux/sys_splice.go | 12 +-
pkg/sentry/syscalls/linux/sys_stat.go | 8 +-
pkg/sentry/syscalls/linux/sys_sync.go | 8 +-
pkg/sentry/syscalls/linux/sys_thread.go | 6 +-
pkg/sentry/syscalls/linux/sys_timerfd.go | 6 +-
pkg/sentry/syscalls/linux/sys_write.go | 10 +-
pkg/sentry/syscalls/linux/sys_xattr.go | 8 +-
pkg/sentry/syscalls/linux/vfs2/aio.go | 8 +-
pkg/sentry/syscalls/linux/vfs2/epoll.go | 14 +-
pkg/sentry/syscalls/linux/vfs2/eventfd.go | 4 +-
pkg/sentry/syscalls/linux/vfs2/execve.go | 12 +-
pkg/sentry/syscalls/linux/vfs2/fd.go | 12 +-
pkg/sentry/syscalls/linux/vfs2/filesystem.go | 22 +--
pkg/sentry/syscalls/linux/vfs2/fscontext.go | 22 +--
pkg/sentry/syscalls/linux/vfs2/getdents.go | 2 +-
pkg/sentry/syscalls/linux/vfs2/inotify.go | 14 +-
pkg/sentry/syscalls/linux/vfs2/ioctl.go | 2 +-
pkg/sentry/syscalls/linux/vfs2/lock.go | 2 +-
pkg/sentry/syscalls/linux/vfs2/memfd.go | 2 +-
pkg/sentry/syscalls/linux/vfs2/mmap.go | 4 +-
pkg/sentry/syscalls/linux/vfs2/mount.go | 4 +-
pkg/sentry/syscalls/linux/vfs2/path.go | 12 +-
pkg/sentry/syscalls/linux/vfs2/pipe.go | 6 +-
pkg/sentry/syscalls/linux/vfs2/poll.go | 10 +-
pkg/sentry/syscalls/linux/vfs2/read_write.go | 48 +++----
pkg/sentry/syscalls/linux/vfs2/setstat.go | 20 +--
pkg/sentry/syscalls/linux/vfs2/signal.go | 4 +-
pkg/sentry/syscalls/linux/vfs2/socket.go | 46 +++---
pkg/sentry/syscalls/linux/vfs2/splice.go | 20 +--
pkg/sentry/syscalls/linux/vfs2/stat.go | 30 ++--
pkg/sentry/syscalls/linux/vfs2/sync.go | 6 +-
pkg/sentry/syscalls/linux/vfs2/timerfd.go | 8 +-
pkg/sentry/syscalls/linux/vfs2/xattr.go | 16 +--
pkg/sentry/vfs/anonfs.go | 8 +-
pkg/sentry/vfs/dentry.go | 37 ++---
pkg/sentry/vfs/epoll.go | 6 +-
pkg/sentry/vfs/file_description.go | 24 ++--
pkg/sentry/vfs/file_description_impl_util_test.go | 18 +--
pkg/sentry/vfs/filesystem.go | 6 +-
pkg/sentry/vfs/inotify.go | 34 ++---
pkg/sentry/vfs/mount.go | 54 +++----
pkg/sentry/vfs/pathname.go | 18 +--
pkg/sentry/vfs/resolving_path.go | 43 +++---
pkg/sentry/vfs/vfs.go | 160 ++++++++++-----------
pkg/tcpip/link/tun/BUILD | 1 +
pkg/tcpip/link/tun/device.go | 9 +-
runsc/boot/fs.go | 12 +-
runsc/boot/loader.go | 16 +--
runsc/boot/loader_test.go | 8 +-
runsc/boot/vfs.go | 10 +-
252 files changed, 1711 insertions(+), 1668 deletions(-)
(limited to 'pkg/sentry/socket/netstack/netstack.go')
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 74affc887..9888cce9c 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -24,6 +24,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/context",
"//pkg/log",
"//pkg/sync",
],
@@ -34,5 +35,8 @@ go_test(
size = "small",
srcs = ["refcounter_test.go"],
library = ":refs",
- deps = ["//pkg/sync"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index c45ba8200..61790221b 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -23,6 +23,7 @@ import (
"runtime"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -38,7 +39,7 @@ type RefCounter interface {
// Note that AtomicRefCounter.DecRef() does not support destructors.
// If a type has a destructor, it must implement its own DecRef()
// method and call AtomicRefCounter.DecRefWithDestructor(destructor).
- DecRef()
+ DecRef(ctx context.Context)
// TryIncRef attempts to increase the reference counter on the object,
// but may fail if all references have already been dropped. This
@@ -57,7 +58,7 @@ type RefCounter interface {
// A WeakRefUser is notified when the last non-weak reference is dropped.
type WeakRefUser interface {
// WeakRefGone is called when the last non-weak reference is dropped.
- WeakRefGone()
+ WeakRefGone(ctx context.Context)
}
// WeakRef is a weak reference.
@@ -123,7 +124,7 @@ func (w *WeakRef) Get() RefCounter {
// Drop drops this weak reference. You should always call drop when you are
// finished with the weak reference. You may not use this object after calling
// drop.
-func (w *WeakRef) Drop() {
+func (w *WeakRef) Drop(ctx context.Context) {
rc, ok := w.get()
if !ok {
// We've been zapped already. When the refcounter has called
@@ -145,7 +146,7 @@ func (w *WeakRef) Drop() {
// And now aren't on the object's list of weak references. So it won't
// zap us if this causes the reference count to drop to zero.
- rc.DecRef()
+ rc.DecRef(ctx)
// Return to the pool.
weakRefPool.Put(w)
@@ -427,7 +428,7 @@ func (r *AtomicRefCount) dropWeakRef(w *WeakRef) {
// A: TryIncRef [transform speculative to real]
//
//go:nosplit
-func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
+func (r *AtomicRefCount) DecRefWithDestructor(ctx context.Context, destroy func(context.Context)) {
switch v := atomic.AddInt64(&r.refCount, -1); {
case v < -1:
panic("Decrementing non-positive ref count")
@@ -448,7 +449,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
if user != nil {
r.mu.Unlock()
- user.WeakRefGone()
+ user.WeakRefGone(ctx)
r.mu.Lock()
}
}
@@ -456,7 +457,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
// Call the destructor.
if destroy != nil {
- destroy()
+ destroy(ctx)
}
}
}
@@ -464,6 +465,6 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
// DecRef decrements this object's reference count.
//
//go:nosplit
-func (r *AtomicRefCount) DecRef() {
- r.DecRefWithDestructor(nil)
+func (r *AtomicRefCount) DecRef(ctx context.Context) {
+ r.DecRefWithDestructor(ctx, nil)
}
diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go
index 1ab4a4440..6d0dd1018 100644
--- a/pkg/refs/refcounter_test.go
+++ b/pkg/refs/refcounter_test.go
@@ -18,6 +18,7 @@ import (
"reflect"
"testing"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -31,11 +32,11 @@ type testCounter struct {
destroyed bool
}
-func (t *testCounter) DecRef() {
- t.AtomicRefCount.DecRefWithDestructor(t.destroy)
+func (t *testCounter) DecRef(ctx context.Context) {
+ t.AtomicRefCount.DecRefWithDestructor(ctx, t.destroy)
}
-func (t *testCounter) destroy() {
+func (t *testCounter) destroy(context.Context) {
t.mu.Lock()
defer t.mu.Unlock()
t.destroyed = true
@@ -53,7 +54,7 @@ func newTestCounter() *testCounter {
func TestOneRef(t *testing.T) {
tc := newTestCounter()
- tc.DecRef()
+ tc.DecRef(context.Background())
if !tc.IsDestroyed() {
t.Errorf("object should have been destroyed")
@@ -63,8 +64,9 @@ func TestOneRef(t *testing.T) {
func TestTwoRefs(t *testing.T) {
tc := newTestCounter()
tc.IncRef()
- tc.DecRef()
- tc.DecRef()
+ ctx := context.Background()
+ tc.DecRef(ctx)
+ tc.DecRef(ctx)
if !tc.IsDestroyed() {
t.Errorf("object should have been destroyed")
@@ -74,12 +76,13 @@ func TestTwoRefs(t *testing.T) {
func TestMultiRefs(t *testing.T) {
tc := newTestCounter()
tc.IncRef()
- tc.DecRef()
+ ctx := context.Background()
+ tc.DecRef(ctx)
tc.IncRef()
- tc.DecRef()
+ tc.DecRef(ctx)
- tc.DecRef()
+ tc.DecRef(ctx)
if !tc.IsDestroyed() {
t.Errorf("object should have been destroyed")
@@ -89,19 +92,20 @@ func TestMultiRefs(t *testing.T) {
func TestWeakRef(t *testing.T) {
tc := newTestCounter()
w := NewWeakRef(tc, nil)
+ ctx := context.Background()
// Try resolving.
if x := w.Get(); x == nil {
t.Errorf("weak reference didn't resolve: expected %v, got nil", tc)
} else {
- x.DecRef()
+ x.DecRef(ctx)
}
// Try resolving again.
if x := w.Get(); x == nil {
t.Errorf("weak reference didn't resolve: expected %v, got nil", tc)
} else {
- x.DecRef()
+ x.DecRef(ctx)
}
// Shouldn't be destroyed yet. (Can't continue if this fails.)
@@ -110,7 +114,7 @@ func TestWeakRef(t *testing.T) {
}
// Drop the original reference.
- tc.DecRef()
+ tc.DecRef(ctx)
// Assert destroyed.
if !tc.IsDestroyed() {
@@ -126,7 +130,8 @@ func TestWeakRef(t *testing.T) {
func TestWeakRefDrop(t *testing.T) {
tc := newTestCounter()
w := NewWeakRef(tc, nil)
- w.Drop()
+ ctx := context.Background()
+ w.Drop(ctx)
// Just assert the list is empty.
if !tc.weakRefs.Empty() {
@@ -134,14 +139,14 @@ func TestWeakRefDrop(t *testing.T) {
}
// Drop the original reference.
- tc.DecRef()
+ tc.DecRef(ctx)
}
type testWeakRefUser struct {
weakRefGone func()
}
-func (u *testWeakRefUser) WeakRefGone() {
+func (u *testWeakRefUser) WeakRefGone(ctx context.Context) {
u.weakRefGone()
}
@@ -165,7 +170,8 @@ func TestCallback(t *testing.T) {
}})
// Drop the original reference, this must trigger the callback.
- tc.DecRef()
+ ctx := context.Background()
+ tc.DecRef(ctx)
if !called {
t.Fatalf("Callback not called")
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 1bae7cfaf..dfa936563 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -139,7 +139,6 @@ func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID
func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
// Import file descriptors.
fdTable := proc.Kernel.NewFDTable()
- defer fdTable.DecRef()
creds := auth.NewUserCredentials(
args.KUID,
@@ -177,6 +176,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
initArgs.MountNamespaceVFS2.IncRef()
}
ctx := initArgs.NewContext(proc.Kernel)
+ defer fdTable.DecRef(ctx)
if kernel.VFS2Enabled {
// Get the full path to the filename from the PATH env variable.
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
index af66fe4dc..511179e31 100644
--- a/pkg/sentry/devices/memdev/full.go
+++ b/pkg/sentry/devices/memdev/full.go
@@ -46,7 +46,7 @@ type fullFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *fullFD) Release() {
+func (fd *fullFD) Release(context.Context) {
// noop
}
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
index 92d3d71be..4918dbeeb 100644
--- a/pkg/sentry/devices/memdev/null.go
+++ b/pkg/sentry/devices/memdev/null.go
@@ -47,7 +47,7 @@ type nullFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *nullFD) Release() {
+func (fd *nullFD) Release(context.Context) {
// noop
}
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
index 6b81da5ef..5e7fe0280 100644
--- a/pkg/sentry/devices/memdev/random.go
+++ b/pkg/sentry/devices/memdev/random.go
@@ -56,7 +56,7 @@ type randomFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *randomFD) Release() {
+func (fd *randomFD) Release(context.Context) {
// noop
}
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index c6f15054d..2e631a252 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -48,7 +48,7 @@ type zeroFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *zeroFD) Release() {
+func (fd *zeroFD) Release(context.Context) {
// noop
}
diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go
index fbb7fd92c..fd4b79c46 100644
--- a/pkg/sentry/devices/ttydev/ttydev.go
+++ b/pkg/sentry/devices/ttydev/ttydev.go
@@ -55,7 +55,7 @@ type ttyFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *ttyFD) Release() {}
+func (fd *ttyFD) Release(context.Context) {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index dfbd069af..852ec3c5c 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -108,8 +108,8 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *tunFD) Release() {
- fd.device.Release()
+func (fd *tunFD) Release(ctx context.Context) {
+ fd.device.Release(ctx)
}
// PRead implements vfs.FileDescriptionImpl.PRead.
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index b8686adb4..1b7cb94c0 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -50,7 +50,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []
if err != nil {
return nil, err
}
- defer appFile.DecRef()
+ defer appFile.DecRef(ctx)
// Remember this in the TTY file, as we will
// use it for the other stdio FDs.
@@ -69,7 +69,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []
if err != nil {
return nil, err
}
- defer appFile.DecRef()
+ defer appFile.DecRef(ctx)
}
// Add the file to the FD map.
@@ -102,7 +102,7 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi
if err != nil {
return nil, err
}
- defer appFile.DecRef()
+ defer appFile.DecRef(ctx)
// Remember this in the TTY file, as we will use it for the other stdio
// FDs.
@@ -119,7 +119,7 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi
if err != nil {
return nil, err
}
- defer appFile.DecRef()
+ defer appFile.DecRef(ctx)
}
if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index ab1424c95..735452b07 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -201,7 +201,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
parentUpper := parent.Inode.overlay.upper
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
// Create the file in the upper filesystem and get an Inode for it.
@@ -212,7 +212,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
log.Warningf("copy up failed to create file: %v", err)
return syserror.EIO
}
- defer childFile.DecRef()
+ defer childFile.DecRef(ctx)
childUpperInode = childFile.Dirent.Inode
case Directory:
@@ -226,7 +226,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
- defer childUpper.DecRef()
+ defer childUpper.DecRef(ctx)
childUpperInode = childUpper.Inode
case Symlink:
@@ -246,7 +246,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
- defer childUpper.DecRef()
+ defer childUpper.DecRef(ctx)
childUpperInode = childUpper.Inode
default:
@@ -352,14 +352,14 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
if err != nil {
return err
}
- defer upperFile.DecRef()
+ defer upperFile.DecRef(ctx)
// Get a handle to the lower filesystem, which we will read from.
lowerFile, err := overlayFile(ctx, lower, FileFlags{Read: true})
if err != nil {
return err
}
- defer lowerFile.DecRef()
+ defer lowerFile.DecRef(ctx)
// Use a buffer pool to minimize allocations.
buf := copyUpBuffers.Get().([]byte)
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index 91792d9fe..c7a11eec1 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -126,7 +126,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile {
if err != nil {
t.Fatalf("failed to create file %q: %v", name, err)
}
- defer f.DecRef()
+ defer f.DecRef(ctx)
relname, _ := f.Dirent.FullName(lowerRoot)
@@ -171,7 +171,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile {
if err != nil {
t.Fatalf("failed to find %q: %v", f.name, err)
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
if err != nil {
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index dc7ad075a..ec474e554 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -80,8 +80,8 @@ type netTunFileOperations struct {
var _ fs.FileOperations = (*netTunFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (fops *netTunFileOperations) Release() {
- fops.device.Release()
+func (fops *netTunFileOperations) Release(ctx context.Context) {
+ fops.device.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 65be12175..a2f751068 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -325,7 +325,7 @@ func (d *Dirent) SyncAll(ctx context.Context) {
for _, w := range d.children {
if child := w.Get(); child != nil {
child.(*Dirent).SyncAll(ctx)
- child.DecRef()
+ child.DecRef(ctx)
}
}
}
@@ -451,7 +451,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// which don't hold a hard reference on their parent (their parent holds a
// hard reference on them, and they contain virtually no state). But this is
// good house-keeping.
- child.DecRef()
+ child.DecRef(ctx)
return nil, syscall.ENOENT
}
@@ -468,20 +468,20 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// their pins on the child. Inotify doesn't properly support filesystems that
// revalidate dirents (since watches are lost on revalidation), but if we fail
// to unpin the watches child will never be GCed.
- cd.Inode.Watches.Unpin(cd)
+ cd.Inode.Watches.Unpin(ctx, cd)
// This child needs to be revalidated, fallthrough to unhash it. Make sure
// to not leak a reference from Get().
//
// Note that previous lookups may still have a reference to this stale child;
// this can't be helped, but we can ensure that *new* lookups are up-to-date.
- child.DecRef()
+ child.DecRef(ctx)
}
// Either our weak reference expired or we need to revalidate it. Unhash child first, we're
// about to replace it.
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
@@ -512,12 +512,12 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// There are active references to the existing child, prefer it to the one we
// retrieved from Lookup. Likely the Lookup happened very close to the insertion
// of child, so considering one stale over the other is fairly arbitrary.
- c.DecRef()
+ c.DecRef(ctx)
// The child that was installed could be negative.
if cd.IsNegative() {
// If so, don't leak a reference and short circuit.
- child.DecRef()
+ child.DecRef(ctx)
return nil, syscall.ENOENT
}
@@ -531,7 +531,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// we did the Inode.Lookup. Fully drop the weak reference and fallback to using the child
// we looked up.
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Give the looked up child a parent. We cannot kick out entries, since we just checked above
@@ -587,7 +587,7 @@ func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool {
return false
}
// Child exists.
- child.DecRef()
+ child.DecRef(ctx)
return true
}
@@ -622,7 +622,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
}
child := file.Dirent
- d.finishCreate(child, name)
+ d.finishCreate(ctx, child, name)
// Return the reference and the new file. When the last reference to
// the file is dropped, file.Dirent may no longer be cached.
@@ -631,7 +631,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
// finishCreate validates the created file, adds it as a child of this dirent,
// and notifies any watchers.
-func (d *Dirent) finishCreate(child *Dirent, name string) {
+func (d *Dirent) finishCreate(ctx context.Context, child *Dirent, name string) {
// Sanity check c, its name must be consistent.
if child.name != name {
panic(fmt.Sprintf("create from %q to %q returned unexpected name %q", d.name, name, child.name))
@@ -650,14 +650,14 @@ func (d *Dirent) finishCreate(child *Dirent, name string) {
panic(fmt.Sprintf("hashed child %q over a positive child", child.name))
}
// Don't leak a reference.
- old.DecRef()
+ old.DecRef(ctx)
// Drop d's reference.
- old.DecRef()
+ old.DecRef(ctx)
}
// Finally drop the useless weak reference on the floor.
- w.Drop()
+ w.Drop(ctx)
}
d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
@@ -686,17 +686,17 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c
panic(fmt.Sprintf("hashed over a positive child %q", old.(*Dirent).name))
}
// Don't leak a reference.
- old.DecRef()
+ old.DecRef(ctx)
// Drop d's reference.
- old.DecRef()
+ old.DecRef(ctx)
}
// Unhash the negative Dirent, name needs to exist now.
delete(d.children, name)
// Finally drop the useless weak reference on the floor.
- w.Drop()
+ w.Drop(ctx)
}
// Execute the create operation.
@@ -756,7 +756,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans
if e != nil {
return e
}
- d.finishCreate(childDir, name)
+ d.finishCreate(ctx, childDir, name)
return nil
})
if err == syscall.EEXIST {
@@ -901,7 +901,7 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent,
// references to children.
//
// Preconditions: d.mu must be held.
-func (d *Dirent) flush() {
+func (d *Dirent) flush(ctx context.Context) {
expired := make(map[string]*refs.WeakRef)
for n, w := range d.children {
// Call flush recursively on each child before removing our
@@ -912,7 +912,7 @@ func (d *Dirent) flush() {
if !cd.IsNegative() {
// Flush the child.
cd.mu.Lock()
- cd.flush()
+ cd.flush(ctx)
cd.mu.Unlock()
// Allow the file system to drop extra references on child.
@@ -920,13 +920,13 @@ func (d *Dirent) flush() {
}
// Don't leak a reference.
- child.DecRef()
+ child.DecRef(ctx)
}
// Check if the child dirent is closed, and mark it as expired if it is.
// We must call w.Get() again here, since the child could have been closed
// by the calls to flush() and cache.Remove() in the above if-block.
if child := w.Get(); child != nil {
- child.DecRef()
+ child.DecRef(ctx)
} else {
expired[n] = w
}
@@ -935,7 +935,7 @@ func (d *Dirent) flush() {
// Remove expired entries.
for n, w := range expired {
delete(d.children, n)
- w.Drop()
+ w.Drop(ctx)
}
}
@@ -977,7 +977,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
if !ok {
panic("mount must mount over an existing dirent")
}
- weakRef.Drop()
+ weakRef.Drop(ctx)
// Note that even though `d` is now hidden, it still holds a reference
// to its parent.
@@ -1002,13 +1002,13 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
if !ok {
panic("mount must mount over an existing dirent")
}
- weakRef.Drop()
+ weakRef.Drop(ctx)
// d is not reachable anymore, and hence not mounted anymore.
d.mounted = false
// Drop mount reference.
- d.DecRef()
+ d.DecRef(ctx)
return nil
}
@@ -1029,7 +1029,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
// Child does not exist.
return err
}
- defer child.DecRef()
+ defer child.DecRef(ctx)
// Remove cannot remove directories.
if IsDir(child.Inode.StableAttr) {
@@ -1055,7 +1055,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
atomic.StoreInt32(&child.deleted, 1)
if w, ok := d.children[name]; ok {
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Allow the file system to drop extra references on child.
@@ -1067,7 +1067,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
// inode may have other links. If this was the last link, the events for the
// watch removal will be queued by the inode destructor.
child.Inode.Watches.MarkUnlinked()
- child.Inode.Watches.Unpin(child)
+ child.Inode.Watches.Unpin(ctx, child)
d.Inode.Watches.Notify(name, linux.IN_DELETE, 0)
return nil
@@ -1100,7 +1100,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
// Child does not exist.
return err
}
- defer child.DecRef()
+ defer child.DecRef(ctx)
// RemoveDirectory can only remove directories.
if !IsDir(child.Inode.StableAttr) {
@@ -1121,7 +1121,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
atomic.StoreInt32(&child.deleted, 1)
if w, ok := d.children[name]; ok {
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Allow the file system to drop extra references on child.
@@ -1130,14 +1130,14 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
// Finally, let inotify know the child is being unlinked. Drop any extra
// refs from inotify to this child dirent.
child.Inode.Watches.MarkUnlinked()
- child.Inode.Watches.Unpin(child)
+ child.Inode.Watches.Unpin(ctx, child)
d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_DELETE, 0)
return nil
}
// destroy closes this node and all children.
-func (d *Dirent) destroy() {
+func (d *Dirent) destroy(ctx context.Context) {
if d.IsNegative() {
// Nothing to tear-down and no parent references to drop, since a negative
// Dirent does not take a references on its parent, has no Inode and no children.
@@ -1153,19 +1153,19 @@ func (d *Dirent) destroy() {
if c.(*Dirent).IsNegative() {
// The parent holds both weak and strong refs in the case of
// negative dirents.
- c.DecRef()
+ c.DecRef(ctx)
}
// Drop the reference we just acquired in WeakRef.Get.
- c.DecRef()
+ c.DecRef(ctx)
}
- w.Drop()
+ w.Drop(ctx)
}
d.children = nil
allDirents.remove(d)
// Drop our reference to the Inode.
- d.Inode.DecRef()
+ d.Inode.DecRef(ctx)
// Allow the Dirent to be GC'ed after this point, since the Inode may still
// be referenced after the Dirent is destroyed (for instance by filesystem
@@ -1175,7 +1175,7 @@ func (d *Dirent) destroy() {
// Drop the reference we have on our parent if we took one. renameMu doesn't need to be
// held because d can't be reparented without any references to it left.
if d.parent != nil {
- d.parent.DecRef()
+ d.parent.DecRef(ctx)
}
}
@@ -1201,14 +1201,14 @@ func (d *Dirent) TryIncRef() bool {
// DecRef decreases the Dirent's refcount and drops its reference on its mount.
//
// DecRef implements RefCounter.DecRef with destructor d.destroy.
-func (d *Dirent) DecRef() {
+func (d *Dirent) DecRef(ctx context.Context) {
if d.Inode != nil {
// Keep mount around, since DecRef may destroy d.Inode.
msrc := d.Inode.MountSource
- d.DecRefWithDestructor(d.destroy)
+ d.DecRefWithDestructor(ctx, d.destroy)
msrc.DecDirentRefs()
} else {
- d.DecRefWithDestructor(d.destroy)
+ d.DecRefWithDestructor(ctx, d.destroy)
}
}
@@ -1359,7 +1359,7 @@ func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error
if err != nil {
return err
}
- defer victim.DecRef()
+ defer victim.DecRef(ctx)
return d.mayDelete(ctx, victim)
}
@@ -1411,7 +1411,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
if err != nil {
return err
}
- defer renamed.DecRef()
+ defer renamed.DecRef(ctx)
// Check that the renamed dirent is deletable.
if err := oldParent.mayDelete(ctx, renamed); err != nil {
@@ -1453,13 +1453,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// Check that we can delete replaced.
if err := newParent.mayDelete(ctx, replaced); err != nil {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return err
}
// Target should not be an ancestor of source.
if oldParent.descendantOf(replaced) {
- replaced.DecRef()
+ replaced.DecRef(ctx)
// Note that Linux returns EINVAL if the source is an
// ancestor of target, but ENOTEMPTY if the target is
@@ -1470,7 +1470,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// Check that replaced is not a mount point.
if replaced.isMountPointLocked() {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syscall.EBUSY
}
@@ -1478,11 +1478,11 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
oldIsDir := IsDir(renamed.Inode.StableAttr)
newIsDir := IsDir(replaced.Inode.StableAttr)
if !newIsDir && oldIsDir {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syscall.ENOTDIR
}
if !oldIsDir && newIsDir {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syscall.EISDIR
}
@@ -1493,13 +1493,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// open across renames is currently broken for multiple
// reasons, so we flush all references on the replaced node and
// its children.
- replaced.Inode.Watches.Unpin(replaced)
+ replaced.Inode.Watches.Unpin(ctx, replaced)
replaced.mu.Lock()
- replaced.flush()
+ replaced.flush(ctx)
replaced.mu.Unlock()
// Done with replaced.
- replaced.DecRef()
+ replaced.DecRef(ctx)
}
if err := renamed.Inode.Rename(ctx, oldParent, renamed, newParent, newName, replaced != nil); err != nil {
@@ -1513,14 +1513,14 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// can't destroy oldParent (and try to retake its lock) because
// Rename's caller must be holding a reference.
newParent.IncRef()
- oldParent.DecRef()
+ oldParent.DecRef(ctx)
}
if w, ok := newParent.children[newName]; ok {
- w.Drop()
+ w.Drop(ctx)
delete(newParent.children, newName)
}
if w, ok := oldParent.children[oldName]; ok {
- w.Drop()
+ w.Drop(ctx)
delete(oldParent.children, oldName)
}
@@ -1551,7 +1551,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// Same as replaced.flush above.
renamed.mu.Lock()
- renamed.flush()
+ renamed.flush(ctx)
renamed.mu.Unlock()
return nil
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
index 33de32c69..7d9dd717e 100644
--- a/pkg/sentry/fs/dirent_cache.go
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -17,6 +17,7 @@ package fs
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -101,7 +102,7 @@ func (c *DirentCache) remove(d *Dirent) {
panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d))
}
c.list.Remove(d)
- d.DecRef()
+ d.DecRef(context.Background())
c.currentSize--
if c.limit != nil {
c.limit.dec()
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
index 98d69c6f2..176b894ba 100644
--- a/pkg/sentry/fs/dirent_refs_test.go
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -51,7 +51,7 @@ func TestWalkPositive(t *testing.T) {
t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1)
}
- d.DecRef()
+ d.DecRef(ctx)
if got := root.ReadRefs(); got != 1 {
t.Fatalf("root has a ref count of %d, want %d", got, 1)
@@ -61,7 +61,7 @@ func TestWalkPositive(t *testing.T) {
t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0)
}
- root.flush()
+ root.flush(ctx)
if got := len(root.children); got != 0 {
t.Fatalf("root has %d children, want %d", got, 0)
@@ -114,7 +114,7 @@ func TestWalkNegative(t *testing.T) {
t.Fatalf("child has a ref count of %d, want %d", got, 2)
}
- child.DecRef()
+ child.DecRef(ctx)
if got := child.(*Dirent).ReadRefs(); got != 1 {
t.Fatalf("child has a ref count of %d, want %d", got, 1)
@@ -124,7 +124,7 @@ func TestWalkNegative(t *testing.T) {
t.Fatalf("root has %d children, want %d", got, 1)
}
- root.DecRef()
+ root.DecRef(ctx)
if got := root.ReadRefs(); got != 0 {
t.Fatalf("root has a ref count of %d, want %d", got, 0)
@@ -351,9 +351,9 @@ func TestRemoveExtraRefs(t *testing.T) {
t.Fatalf("dirent has a ref count of %d, want %d", got, 1)
}
- d.DecRef()
+ d.DecRef(ctx)
- test.root.flush()
+ test.root.flush(ctx)
if got := len(test.root.children); got != 0 {
t.Errorf("root has %d children, want %d", got, 0)
@@ -403,8 +403,8 @@ func TestRenameExtraRefs(t *testing.T) {
t.Fatalf("Rename got error %v, want nil", err)
}
- oldParent.flush()
- newParent.flush()
+ oldParent.flush(ctx)
+ newParent.flush(ctx)
// Expect to have only active references.
if got := renamed.ReadRefs(); got != 1 {
diff --git a/pkg/sentry/fs/dirent_state.go b/pkg/sentry/fs/dirent_state.go
index f623d6c0e..67a35f0b2 100644
--- a/pkg/sentry/fs/dirent_state.go
+++ b/pkg/sentry/fs/dirent_state.go
@@ -18,6 +18,7 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
)
@@ -48,7 +49,7 @@ func (d *Dirent) saveChildren() map[string]*Dirent {
for name, w := range d.children {
if rc := w.Get(); rc != nil {
// Drop the reference count obtain in w.Get()
- rc.DecRef()
+ rc.DecRef(context.Background())
cd := rc.(*Dirent)
if cd.IsNegative() {
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 9fce177ad..b99199798 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -115,7 +115,7 @@ func (p *pipeOperations) Readiness(mask waiter.EventMask) (eventMask waiter.Even
}
// Release implements fs.FileOperations.Release.
-func (p *pipeOperations) Release() {
+func (p *pipeOperations) Release(context.Context) {
fdnotifier.RemoveFD(int32(p.file.FD()))
p.file.Close()
p.file = nil
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
index e556da48a..b9cec4b13 100644
--- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
@@ -182,7 +182,7 @@ func TestTryOpen(t *testing.T) {
// Cleanup the state of the pipe, and remove the fd from the
// fdnotifier. Sadly this needed to maintain the correctness
// of other tests because the fdnotifier is global.
- pipeOps.Release()
+ pipeOps.Release(ctx)
}
continue
}
@@ -191,7 +191,7 @@ func TestTryOpen(t *testing.T) {
}
if pipeOps != nil {
// Same as above.
- pipeOps.Release()
+ pipeOps.Release(ctx)
}
}
}
@@ -279,7 +279,7 @@ func TestPipeOpenUnblocksEventually(t *testing.T) {
pipeOps, err := Open(ctx, opener, flags)
if pipeOps != nil {
// Same as TestTryOpen.
- pipeOps.Release()
+ pipeOps.Release(ctx)
}
// Check that the partner opened the file successfully.
@@ -325,7 +325,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) {
ctx := contexttest.Context(t)
pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true})
if pipeOps != nil {
- pipeOps.Release()
+ pipeOps.Release(ctx)
t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY)
}
if err != syserror.ErrWouldBlock {
@@ -351,7 +351,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) {
if pipeOps == nil {
t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY)
}
- defer pipeOps.Release()
+ defer pipeOps.Release(ctx)
if err != nil {
t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err)
@@ -471,14 +471,14 @@ func TestPipeHangup(t *testing.T) {
f := <-fdchan
if f < 0 {
t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f)
- pipeOps.Release()
+ pipeOps.Release(ctx)
continue
}
if test.hangupSelf {
// Hangup self and assert that our partner got the expected hangup
// error.
- pipeOps.Release()
+ pipeOps.Release(ctx)
if test.flags.Read {
// Partner is writer.
@@ -490,7 +490,7 @@ func TestPipeHangup(t *testing.T) {
} else {
// Hangup our partner and expect us to get the hangup error.
syscall.Close(f)
- defer pipeOps.Release()
+ defer pipeOps.Release(ctx)
if test.flags.Read {
assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file)
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index a0082ecca..1c9e82562 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -98,10 +98,11 @@ func TestNewPipe(t *testing.T) {
}
f := fd.New(gfd)
- p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer)
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, test.flags, f, test.readAheadBuffer)
if p != nil {
// This is necessary to remove the fd from the global fd notifier.
- defer p.Release()
+ defer p.Release(ctx)
} else {
// If there is no p to DecRef on, because newPipeOperations failed, then the
// file still needs to be closed.
@@ -153,13 +154,14 @@ func TestPipeDestruction(t *testing.T) {
syscall.Close(fds[1])
// Test the read end, but it doesn't really matter which.
- p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil)
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, f, nil)
if err != nil {
f.Close()
t.Fatalf("newPipeOperations got error %v, want nil", err)
}
// Drop our only reference, which should trigger the destructor.
- p.Release()
+ p.Release(ctx)
if fdnotifier.HasFD(int32(fds[0])) {
t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0])
@@ -282,7 +284,7 @@ func TestPipeRequest(t *testing.T) {
if err != nil {
t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err)
}
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p)
@@ -334,7 +336,7 @@ func TestPipeReadAheadBuffer(t *testing.T) {
rfile.Close()
t.Fatalf("newPipeOperations got error %v, want nil", err)
}
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
@@ -380,7 +382,7 @@ func TestPipeReadsAccumulate(t *testing.T) {
}
// Don't forget to remove the fd from the fd notifier. Otherwise other tests will
// likely be borked, because it's global :(
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
@@ -448,7 +450,7 @@ func TestPipeWritesAccumulate(t *testing.T) {
}
// Don't forget to remove the fd from the fd notifier. Otherwise other tests
// will likely be borked, because it's global :(
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index ca41520b4..72ea70fcf 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -142,17 +142,17 @@ func NewFile(ctx context.Context, dirent *Dirent, flags FileFlags, fops FileOper
}
// DecRef destroys the File when it is no longer referenced.
-func (f *File) DecRef() {
- f.DecRefWithDestructor(func() {
+func (f *File) DecRef(ctx context.Context) {
+ f.DecRefWithDestructor(ctx, func(context.Context) {
// Drop BSD style locks.
lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng)
// Release resources held by the FileOperations.
- f.FileOperations.Release()
+ f.FileOperations.Release(ctx)
// Release a reference on the Dirent.
- f.Dirent.DecRef()
+ f.Dirent.DecRef(ctx)
// Only unregister if we are currently registered. There is nothing
// to register if f.async is nil (this happens when async mode is
@@ -460,7 +460,7 @@ func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
func (f *File) MappedName(ctx context.Context) string {
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
name, _ := f.Dirent.FullName(root)
return name
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index f5537411e..305c0f840 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -67,7 +67,7 @@ type SpliceOpts struct {
// - File.Flags(): This value may change during the operation.
type FileOperations interface {
// Release release resources held by FileOperations.
- Release()
+ Release(ctx context.Context)
// Waitable defines how this File can be waited on for read and
// write readiness.
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index dcc1df38f..9dc58d5ff 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -54,7 +54,7 @@ func overlayFile(ctx context.Context, inode *Inode, flags FileFlags) (*File, err
// Drop the extra reference on the Dirent. Now there's only one reference
// on the dirent, either owned by f (if non-nil), or the Dirent is about
// to be destroyed (if GetFile failed).
- dirent.DecRef()
+ dirent.DecRef(ctx)
return f, err
}
@@ -89,12 +89,12 @@ type overlayFileOperations struct {
}
// Release implements FileOperations.Release.
-func (f *overlayFileOperations) Release() {
+func (f *overlayFileOperations) Release(ctx context.Context) {
if f.upper != nil {
- f.upper.DecRef()
+ f.upper.DecRef(ctx)
}
if f.lower != nil {
- f.lower.DecRef()
+ f.lower.DecRef(ctx)
}
}
@@ -164,7 +164,7 @@ func (f *overlayFileOperations) Seek(ctx context.Context, file *File, whence See
func (f *overlayFileOperations) Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error) {
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &DirCtx{
@@ -497,7 +497,7 @@ func readdirOne(ctx context.Context, d *Dirent) (map[string]DentAttr, error) {
if err != nil {
return nil, err
}
- defer dir.DecRef()
+ defer dir.DecRef(ctx)
// Use a stub serializer to read the entries into memory.
stubSerializer := &CollectEntriesSerializer{}
@@ -521,10 +521,10 @@ type overlayMappingIdentity struct {
}
// DecRef implements AtomicRefCount.DecRef.
-func (omi *overlayMappingIdentity) DecRef() {
- omi.AtomicRefCount.DecRefWithDestructor(func() {
- omi.overlayFile.DecRef()
- omi.id.DecRef()
+func (omi *overlayMappingIdentity) DecRef(ctx context.Context) {
+ omi.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) {
+ omi.overlayFile.DecRef(ctx)
+ omi.id.DecRef(ctx)
})
}
@@ -544,7 +544,7 @@ func (omi *overlayMappingIdentity) InodeID() uint64 {
func (omi *overlayMappingIdentity) MappedName(ctx context.Context) string {
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
name, _ := omi.overlayFile.Dirent.FullName(root)
return name
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 08695391c..dc9efa5df 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -31,7 +31,7 @@ import (
type FileNoopRelease struct{}
// Release is a no-op.
-func (FileNoopRelease) Release() {}
+func (FileNoopRelease) Release(context.Context) {}
// SeekWithDirCursor is used to implement fs.FileOperations.Seek. If dirCursor
// is not nil and the seek was on a directory, the cursor will be updated.
@@ -296,7 +296,7 @@ func (sdfo *StaticDirFileOperations) IterateDir(ctx context.Context, d *fs.Diren
func (sdfo *StaticDirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index b2fcab127..c0bc63a32 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -114,7 +114,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF
}
// Release implements fs.FileOpeations.Release.
-func (f *fileOperations) Release() {
+func (f *fileOperations) Release(context.Context) {
f.handles.DecRef()
}
@@ -122,7 +122,7 @@ func (f *fileOperations) Release() {
func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go
index 2df2fe889..326fed954 100644
--- a/pkg/sentry/fs/gofer/gofer_test.go
+++ b/pkg/sentry/fs/gofer/gofer_test.go
@@ -232,7 +232,7 @@ func TestRevalidation(t *testing.T) {
// We must release the dirent, of the test will fail
// with a reference leak. This is tracked by p9test.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
// Walk again. Depending on the cache policy, we may
// get a new dirent.
@@ -246,7 +246,7 @@ func TestRevalidation(t *testing.T) {
if !test.preModificationWantReload && dirent != newDirent {
t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent)
}
- newDirent.DecRef() // See above.
+ newDirent.DecRef(ctx) // See above.
// Modify the underlying mocked file's modification
// time for the next walk that occurs.
@@ -287,7 +287,7 @@ func TestRevalidation(t *testing.T) {
if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds {
t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds)
}
- newDirent.DecRef() // See above.
+ newDirent.DecRef(ctx) // See above.
// Remove the file from the remote fs, subsequent walks
// should now fail to find anything.
@@ -303,7 +303,7 @@ func TestRevalidation(t *testing.T) {
t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err)
}
if err == nil {
- newDirent.DecRef() // See above.
+ newDirent.DecRef(ctx) // See above.
}
})
}
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index fc14249be..f324dbf26 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -47,7 +47,8 @@ type handles struct {
// DecRef drops a reference on handles.
func (h *handles) DecRef() {
- h.DecRefWithDestructor(func() {
+ ctx := context.Background()
+ h.DecRefWithDestructor(ctx, func(context.Context) {
if h.Host != nil {
if h.isHostBorrowed {
h.Host.Release()
@@ -57,7 +58,7 @@ func (h *handles) DecRef() {
}
}
}
- if err := h.File.close(context.Background()); err != nil {
+ if err := h.File.close(ctx); err != nil {
log.Warningf("error closing p9 file: %v", err)
}
})
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 51d7368a1..3a225fd39 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -441,8 +441,9 @@ func (i *inodeOperations) Release(ctx context.Context) {
// asynchronously.
//
// We use AsyncWithContext to avoid needing to allocate an extra
- // anonymous function on the heap.
- fs.AsyncWithContext(ctx, i.fileState.Release)
+ // anonymous function on the heap. We must use background context
+ // because the async work cannot happen on the task context.
+ fs.AsyncWithContext(context.Background(), i.fileState.Release)
}
// Mappable implements fs.InodeOperations.Mappable.
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index cf9800100..3c66dc3c2 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -168,7 +168,7 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string
// Construct the positive Dirent.
d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
- defer d.DecRef()
+ defer d.DecRef(ctx)
// Construct the new file, caching the handles if allowed.
h := handles{
@@ -371,7 +371,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
// Find out if file being deleted is a socket or pipe that needs to be
// removed from endpoint map.
if d, err := i.Lookup(ctx, dir, name); err == nil {
- defer d.DecRef()
+ defer d.DecRef(ctx)
if fs.IsSocket(d.Inode.StableAttr) || fs.IsPipe(d.Inode.StableAttr) {
switch iops := d.Inode.InodeOperations.(type) {
@@ -392,7 +392,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
return err
}
if key != nil {
- i.session().overrides.remove(*key)
+ i.session().overrides.remove(ctx, *key)
}
i.touchModificationAndStatusChangeTime(ctx, dir)
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index b5efc86f2..7cf3522ff 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -89,10 +89,10 @@ func (e *overrideMaps) addPipe(key device.MultiDeviceKey, d *fs.Dirent, inode *f
// remove deletes the key from the maps.
//
// Precondition: maps must have been locked with 'lock'.
-func (e *overrideMaps) remove(key device.MultiDeviceKey) {
+func (e *overrideMaps) remove(ctx context.Context, key device.MultiDeviceKey) {
endpoint := e.keyMap[key]
delete(e.keyMap, key)
- endpoint.dirent.DecRef()
+ endpoint.dirent.DecRef(ctx)
}
// lock blocks other addition and removal operations from happening while
@@ -197,7 +197,7 @@ type session struct {
}
// Destroy tears down the session.
-func (s *session) Destroy() {
+func (s *session) Destroy(ctx context.Context) {
s.client.Close()
}
@@ -329,7 +329,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
s.client, err = p9.NewClient(conn, s.msize, s.version)
if err != nil {
// Drop our reference on the session, it needs to be torn down.
- s.DecRef()
+ s.DecRef(ctx)
return nil, err
}
@@ -340,7 +340,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
ctx.UninterruptibleSleepFinish(false)
if err != nil {
// Same as above.
- s.DecRef()
+ s.DecRef(ctx)
return nil, err
}
@@ -348,7 +348,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
if err != nil {
s.attach.close(ctx)
// Same as above, but after we execute the Close request.
- s.DecRef()
+ s.DecRef(ctx)
return nil, err
}
@@ -393,13 +393,13 @@ func (s *session) fillKeyMap(ctx context.Context) error {
// fillPathMap populates paths for overrides from dirents in direntMap
// before save.
-func (s *session) fillPathMap() error {
+func (s *session) fillPathMap(ctx context.Context) error {
unlock := s.overrides.lock()
defer unlock()
for _, endpoint := range s.overrides.keyMap {
mountRoot := endpoint.dirent.MountRoot()
- defer mountRoot.DecRef()
+ defer mountRoot.DecRef(ctx)
dirPath, _ := endpoint.dirent.FullName(mountRoot)
if dirPath == "" {
return fmt.Errorf("error getting path from dirent")
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
index 2d398b753..48b423dd8 100644
--- a/pkg/sentry/fs/gofer/session_state.go
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -26,7 +26,8 @@ import (
// beforeSave is invoked by stateify.
func (s *session) beforeSave() {
if s.overrides != nil {
- if err := s.fillPathMap(); err != nil {
+ ctx := &dummyClockContext{context.Background()}
+ if err := s.fillPathMap(ctx); err != nil {
panic("failed to save paths to override map before saving" + err.Error())
}
}
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index 40f2c1cad..8a1c69ac2 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -134,14 +134,14 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
// We don't need the receiver.
c.CloseRecv()
- c.Release()
+ c.Release(ctx)
return c, nil
}
// Release implements transport.BoundEndpoint.Release.
-func (e *endpoint) Release() {
- e.inode.DecRef()
+func (e *endpoint) Release(ctx context.Context) {
+ e.inode.DecRef(ctx)
}
// Passcred implements transport.BoundEndpoint.Passcred.
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
index 39299b7e4..0d8d36afa 100644
--- a/pkg/sentry/fs/host/control.go
+++ b/pkg/sentry/fs/host/control.go
@@ -57,7 +57,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (c *scmRights) Release() {
+func (c *scmRights) Release(ctx context.Context) {
for _, fd := range c.fds {
syscall.Close(fd)
}
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index 3e48b8b2c..86d1a87f0 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -110,7 +110,7 @@ func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool
name := fmt.Sprintf("host:[%d]", inode.StableAttr.InodeID)
dirent := fs.NewDirent(ctx, inode, name)
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
if isTTY {
return newTTYFile(ctx, dirent, flags, iops), nil
@@ -169,7 +169,7 @@ func (f *fileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go
index c507f57eb..41a23b5da 100644
--- a/pkg/sentry/fs/host/inode_test.go
+++ b/pkg/sentry/fs/host/inode_test.go
@@ -36,7 +36,7 @@ func TestCloseFD(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create File: %v", err)
}
- file.DecRef()
+ file.DecRef(ctx)
s := make([]byte, 10)
if c, err := syscall.Read(p[0], s); c != 0 || err != nil {
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index cfb089e43..a2f3d5918 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -194,7 +194,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
}
// Send implements transport.ConnectedEndpoint.Send.
-func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -271,7 +271,7 @@ func (c *ConnectedEndpoint) EventUpdate() {
}
// Recv implements transport.Receiver.Recv.
-func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -318,7 +318,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek
}
// close releases all resources related to the endpoint.
-func (c *ConnectedEndpoint) close() {
+func (c *ConnectedEndpoint) close(context.Context) {
fdnotifier.RemoveFD(int32(c.file.FD()))
c.file.Close()
c.file = nil
@@ -374,8 +374,8 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
}
// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
-func (c *ConnectedEndpoint) Release() {
- c.ref.DecRefWithDestructor(c.close)
+func (c *ConnectedEndpoint) Release(ctx context.Context) {
+ c.ref.DecRefWithDestructor(ctx, c.close)
}
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index affdbcacb..9d58ea448 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -67,11 +67,12 @@ func TestSocketIsBlocking(t *testing.T) {
if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
t.Fatalf("Expected socket %v to be blocking", pair[1])
}
- sock, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ sock, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) failed => %v", pair[0], err)
}
- defer sock.DecRef()
+ defer sock.DecRef(ctx)
// Test that the socket now is non-blocking.
if fl, err = getFl(pair[0]); err != nil {
t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
@@ -93,11 +94,12 @@ func TestSocketWritev(t *testing.T) {
if err != nil {
t.Fatalf("host socket creation failed: %v", err)
}
- socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ socket, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer socket.DecRef()
+ defer socket.DecRef(ctx)
buf := []byte("hello world\n")
n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf))
if err != nil {
@@ -115,11 +117,12 @@ func TestSocketWritevLen0(t *testing.T) {
if err != nil {
t.Fatalf("host socket creation failed: %v", err)
}
- socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ socket, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer socket.DecRef()
+ defer socket.DecRef(ctx)
n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil))
if err != nil {
t.Fatalf("socket writev failed: %v", err)
@@ -136,11 +139,12 @@ func TestSocketSendMsgLen0(t *testing.T) {
if err != nil {
t.Fatalf("host socket creation failed: %v", err)
}
- sfile, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ sfile, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer sfile.DecRef()
+ defer sfile.DecRef(ctx)
s := sfile.FileOperations.(socket.Socket)
n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{})
@@ -158,18 +162,19 @@ func TestListen(t *testing.T) {
if err != nil {
t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
}
- sfile1, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ sfile1, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer sfile1.DecRef()
+ defer sfile1.DecRef(ctx)
socket1 := sfile1.FileOperations.(socket.Socket)
- sfile2, err := newSocket(contexttest.Context(t), pair[1], false)
+ sfile2, err := newSocket(ctx, pair[1], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[1], err)
}
- defer sfile2.DecRef()
+ defer sfile2.DecRef(ctx)
socket2 := sfile2.FileOperations.(socket.Socket)
// Socketpairs can not be listened to.
@@ -185,11 +190,11 @@ func TestListen(t *testing.T) {
if err != nil {
t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
}
- sfile3, err := newSocket(contexttest.Context(t), sock, false)
+ sfile3, err := newSocket(ctx, sock, false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", sock, err)
}
- defer sfile3.DecRef()
+ defer sfile3.DecRef(ctx)
socket3 := sfile3.FileOperations.(socket.Socket)
// This socket is not bound so we can't listen on it.
@@ -237,9 +242,10 @@ func TestRelease(t *testing.T) {
}
c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
want := &ConnectedEndpoint{queue: c.queue}
- want.ref.DecRef()
+ ctx := contexttest.Context(t)
+ want.ref.DecRef(ctx)
fdnotifier.AddFD(int32(c.file.FD()), nil)
- c.Release()
+ c.Release(ctx)
if !reflect.DeepEqual(c, want) {
t.Errorf("got = %#v, want = %#v", c, want)
}
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 82a02fcb2..b5229098c 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -113,12 +113,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
}
// Release implements fs.FileOperations.Release.
-func (t *TTYFileOperations) Release() {
+func (t *TTYFileOperations) Release(ctx context.Context) {
t.mu.Lock()
t.fgProcessGroup = nil
t.mu.Unlock()
- t.fileOperations.Release()
+ t.fileOperations.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go
index ce397a5e3..c143f4ce2 100644
--- a/pkg/sentry/fs/host/wait_test.go
+++ b/pkg/sentry/fs/host/wait_test.go
@@ -39,7 +39,7 @@ func TestWait(t *testing.T) {
t.Fatalf("NewFile failed: %v", err)
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
r := file.Readiness(waiter.EventIn)
if r != 0 {
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index a34fbc946..b79cd9877 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -96,13 +96,12 @@ func NewInode(ctx context.Context, iops InodeOperations, msrc *MountSource, satt
}
// DecRef drops a reference on the Inode.
-func (i *Inode) DecRef() {
- i.DecRefWithDestructor(i.destroy)
+func (i *Inode) DecRef(ctx context.Context) {
+ i.DecRefWithDestructor(ctx, i.destroy)
}
// destroy releases the Inode and releases the msrc reference taken.
-func (i *Inode) destroy() {
- ctx := context.Background()
+func (i *Inode) destroy(ctx context.Context) {
if err := i.WriteOut(ctx); err != nil {
// FIXME(b/65209558): Mark as warning again once noatime is
// properly supported.
@@ -122,12 +121,12 @@ func (i *Inode) destroy() {
i.Watches.targetDestroyed()
if i.overlay != nil {
- i.overlay.release()
+ i.overlay.release(ctx)
} else {
i.InodeOperations.Release(ctx)
}
- i.MountSource.DecRef()
+ i.MountSource.DecRef(ctx)
}
// Mappable calls i.InodeOperations.Mappable.
diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go
index efd3c962b..9911a00c2 100644
--- a/pkg/sentry/fs/inode_inotify.go
+++ b/pkg/sentry/fs/inode_inotify.go
@@ -17,6 +17,7 @@ package fs
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -136,11 +137,11 @@ func (w *Watches) Notify(name string, events, cookie uint32) {
}
// Unpin unpins dirent from all watches in this set.
-func (w *Watches) Unpin(d *Dirent) {
+func (w *Watches) Unpin(ctx context.Context, d *Dirent) {
w.mu.RLock()
defer w.mu.RUnlock()
for _, watch := range w.ws {
- watch.Unpin(d)
+ watch.Unpin(ctx, d)
}
}
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 537c8d257..dc2e353d9 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -85,7 +85,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
upperInode = child.Inode
upperInode.IncRef()
}
- child.DecRef()
+ child.DecRef(ctx)
}
// Are we done?
@@ -108,7 +108,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
entry, err := newOverlayEntry(ctx, upperInode, nil, false)
if err != nil {
// Don't leak resources.
- upperInode.DecRef()
+ upperInode.DecRef(ctx)
parent.copyMu.RUnlock()
return nil, false, err
}
@@ -129,7 +129,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
if err != nil && err != syserror.ENOENT {
// Don't leak resources.
if upperInode != nil {
- upperInode.DecRef()
+ upperInode.DecRef(ctx)
}
parent.copyMu.RUnlock()
return nil, false, err
@@ -152,7 +152,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
}
}
}
- child.DecRef()
+ child.DecRef(ctx)
}
}
@@ -183,7 +183,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// unnecessary because we don't need to copy-up and we will always
// operate (e.g. read/write) on the upper Inode.
if !IsDir(upperInode.StableAttr) {
- lowerInode.DecRef()
+ lowerInode.DecRef(ctx)
lowerInode = nil
}
}
@@ -194,10 +194,10 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// Well, not quite, we failed at the last moment, how depressing.
// Be sure not to leak resources.
if upperInode != nil {
- upperInode.DecRef()
+ upperInode.DecRef(ctx)
}
if lowerInode != nil {
- lowerInode.DecRef()
+ lowerInode.DecRef(ctx)
}
parent.copyMu.RUnlock()
return nil, false, err
@@ -248,7 +248,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st
// user) will clobber the real path for the underlying Inode.
upperFile.Dirent.Inode.IncRef()
upperDirent := NewTransientDirent(upperFile.Dirent.Inode)
- upperFile.Dirent.DecRef()
+ upperFile.Dirent.DecRef(ctx)
upperFile.Dirent = upperDirent
// Create the overlay inode and dirent. We need this to construct the
@@ -259,7 +259,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st
// The overlay file created below with NewFile will take a reference on
// the overlayDirent, and it should be the only thing holding a
// reference at the time of creation, so we must drop this reference.
- defer overlayDirent.DecRef()
+ defer overlayDirent.DecRef(ctx)
// Create a new overlay file that wraps the upper file.
flags.Pread = upperFile.Flags().Pread
@@ -399,7 +399,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
if !replaced.IsNegative() && IsDir(replaced.Inode.StableAttr) {
children, err := readdirOne(ctx, replaced)
if err != nil {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return err
}
@@ -407,12 +407,12 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
// included among the returned children, so we don't
// need to bother checking for them.
if len(children) > 0 {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syserror.ENOTEMPTY
}
}
- replaced.DecRef()
+ replaced.DecRef(ctx)
}
}
@@ -455,12 +455,12 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri
// Grab the inode and drop the dirent, we don't need it.
inode := d.Inode
inode.IncRef()
- d.DecRef()
+ d.DecRef(ctx)
// Create a new overlay entry and dirent for the socket.
entry, err := newOverlayEntry(ctx, inode, nil, false)
if err != nil {
- inode.DecRef()
+ inode.DecRef(ctx)
return nil, err
}
// Use the parent's MountSource, since that corresponds to the overlay,
@@ -672,7 +672,7 @@ func overlayGetlink(ctx context.Context, o *overlayEntry) (*Dirent, error) {
// ground and claim that jumping around the filesystem like this
// is not supported.
name, _ := dirent.FullName(nil)
- dirent.DecRef()
+ dirent.DecRef(ctx)
// Claim that the path is not accessible.
err = syserror.EACCES
diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go
index 389c219d6..aa9851b26 100644
--- a/pkg/sentry/fs/inode_overlay_test.go
+++ b/pkg/sentry/fs/inode_overlay_test.go
@@ -316,7 +316,7 @@ func TestCacheFlush(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
ctx = &rootContext{
Context: ctx,
@@ -345,7 +345,7 @@ func TestCacheFlush(t *testing.T) {
}
// Drop the file reference.
- file.DecRef()
+ file.DecRef(ctx)
// Dirent should have 2 refs left.
if got, want := dirent.ReadRefs(), 2; int(got) != want {
@@ -361,7 +361,7 @@ func TestCacheFlush(t *testing.T) {
}
// Drop our ref.
- dirent.DecRef()
+ dirent.DecRef(ctx)
// We should be back to zero refs.
if got, want := dirent.ReadRefs(), 0; int(got) != want {
@@ -398,7 +398,7 @@ func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags
if err != nil {
return nil, err
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
// Wrap the file's FileOperations in a dirFile.
fops := &dirFile{
FileOperations: file.FileOperations,
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index e3a715c1f..c5c07d564 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -80,7 +80,7 @@ func NewInotify(ctx context.Context) *Inotify {
// Release implements FileOperations.Release. Release removes all watches and
// frees all resources for an inotify instance.
-func (i *Inotify) Release() {
+func (i *Inotify) Release(ctx context.Context) {
// We need to hold i.mu to avoid a race with concurrent calls to
// Inotify.targetDestroyed from Watches. There's no risk of Watches
// accessing this Inotify after the destructor ends, because we remove all
@@ -93,7 +93,7 @@ func (i *Inotify) Release() {
// the owner's destructor.
w.target.Watches.Remove(w.ID())
// Don't leak any references to the target, held by pins in the watch.
- w.destroy()
+ w.destroy(ctx)
}
}
@@ -321,7 +321,7 @@ func (i *Inotify) AddWatch(target *Dirent, mask uint32) int32 {
//
// RmWatch looks up an inotify watch for the given 'wd' and configures the
// target dirent to stop sending events to this inotify instance.
-func (i *Inotify) RmWatch(wd int32) error {
+func (i *Inotify) RmWatch(ctx context.Context, wd int32) error {
i.mu.Lock()
// Find the watch we were asked to removed.
@@ -346,7 +346,7 @@ func (i *Inotify) RmWatch(wd int32) error {
i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
// Remove all pins.
- watch.destroy()
+ watch.destroy(ctx)
return nil
}
diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go
index 900cba3ca..605423d22 100644
--- a/pkg/sentry/fs/inotify_watch.go
+++ b/pkg/sentry/fs/inotify_watch.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -105,12 +106,12 @@ func (w *Watch) Pin(d *Dirent) {
// Unpin drops any extra refs held on dirent due to a previous Pin
// call. Calling Unpin multiple times for the same dirent, or on a dirent
// without a corresponding Pin call is a no-op.
-func (w *Watch) Unpin(d *Dirent) {
+func (w *Watch) Unpin(ctx context.Context, d *Dirent) {
w.mu.Lock()
defer w.mu.Unlock()
if w.pins[d] {
delete(w.pins, d)
- d.DecRef()
+ d.DecRef(ctx)
}
}
@@ -125,11 +126,11 @@ func (w *Watch) TargetDestroyed() {
// this watch. Destroy does not cause any new events to be generated. The caller
// is responsible for ensuring there are no outstanding references to this
// watch.
-func (w *Watch) destroy() {
+func (w *Watch) destroy(ctx context.Context) {
w.mu.Lock()
defer w.mu.Unlock()
for d := range w.pins {
- d.DecRef()
+ d.DecRef(ctx)
}
w.pins = nil
}
diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go
index 37bae6810..ee69b10e8 100644
--- a/pkg/sentry/fs/mount.go
+++ b/pkg/sentry/fs/mount.go
@@ -51,7 +51,7 @@ type MountSourceOperations interface {
DirentOperations
// Destroy destroys the MountSource.
- Destroy()
+ Destroy(ctx context.Context)
// Below are MountSourceOperations that do not conform to Linux.
@@ -165,16 +165,16 @@ func (msrc *MountSource) DecDirentRefs() {
}
}
-func (msrc *MountSource) destroy() {
+func (msrc *MountSource) destroy(ctx context.Context) {
if c := msrc.DirentRefs(); c != 0 {
panic(fmt.Sprintf("MountSource with non-zero direntRefs is being destroyed: %d", c))
}
- msrc.MountSourceOperations.Destroy()
+ msrc.MountSourceOperations.Destroy(ctx)
}
// DecRef drops a reference on the MountSource.
-func (msrc *MountSource) DecRef() {
- msrc.DecRefWithDestructor(msrc.destroy)
+func (msrc *MountSource) DecRef(ctx context.Context) {
+ msrc.DecRefWithDestructor(ctx, msrc.destroy)
}
// FlushDirentRefs drops all references held by the MountSource on Dirents.
@@ -264,7 +264,7 @@ func (*SimpleMountSourceOperations) ResetInodeMappings() {}
func (*SimpleMountSourceOperations) SaveInodeMapping(*Inode, string) {}
// Destroy implements MountSourceOperations.Destroy.
-func (*SimpleMountSourceOperations) Destroy() {}
+func (*SimpleMountSourceOperations) Destroy(context.Context) {}
// Info defines attributes of a filesystem.
type Info struct {
diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go
index 78e35b1e6..7badc75d6 100644
--- a/pkg/sentry/fs/mount_overlay.go
+++ b/pkg/sentry/fs/mount_overlay.go
@@ -115,9 +115,9 @@ func (o *overlayMountSourceOperations) SaveInodeMapping(inode *Inode, path strin
}
// Destroy drops references on the upper and lower MountSource.
-func (o *overlayMountSourceOperations) Destroy() {
- o.upper.DecRef()
- o.lower.DecRef()
+func (o *overlayMountSourceOperations) Destroy(ctx context.Context) {
+ o.upper.DecRef(ctx)
+ o.lower.DecRef(ctx)
}
// type overlayFilesystem is the filesystem for overlay mounts.
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index a3d10770b..6c296f5d0 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -18,6 +18,7 @@ import (
"fmt"
"testing"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
)
@@ -32,13 +33,13 @@ func cacheReallyContains(cache *DirentCache, d *Dirent) bool {
return false
}
-func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
+func mountPathsAre(ctx context.Context, root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
if groot := g.Root(); groot != nil {
name, _ := groot.FullName(root)
- groot.DecRef()
+ groot.DecRef(ctx)
gotStr[i] = name
gotPaths[name] = struct{}{}
}
@@ -69,7 +70,7 @@ func TestMountSourceOnlyCachedOnce(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
rootDirent := mm.Root()
- defer rootDirent.DecRef()
+ defer rootDirent.DecRef(ctx)
// Get a child of the root which we will mount over. Note that the
// MockInodeOperations causes Walk to always succeed.
@@ -125,7 +126,7 @@ func TestAllMountsUnder(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
rootDirent := mm.Root()
- defer rootDirent.DecRef()
+ defer rootDirent.DecRef(ctx)
// Add mounts at the following paths:
paths := []string{
@@ -150,14 +151,14 @@ func TestAllMountsUnder(t *testing.T) {
if err := mm.Mount(ctx, d, submountInode); err != nil {
t.Fatalf("could not mount at %q: %v", p, err)
}
- d.DecRef()
+ d.DecRef(ctx)
}
// mm root should contain all submounts (and does not include the root mount).
rootMnt := mm.FindMount(rootDirent)
submounts := mm.AllMountsUnder(rootMnt)
allPaths := append(paths, "/")
- if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil {
t.Error(err)
}
@@ -181,9 +182,9 @@ func TestAllMountsUnder(t *testing.T) {
if err != nil {
t.Fatalf("could not find path %q in mount manager: %v", "/foo", err)
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
submounts = mm.AllMountsUnder(mm.FindMount(d))
- if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil {
t.Error(err)
}
@@ -193,9 +194,9 @@ func TestAllMountsUnder(t *testing.T) {
if err != nil {
t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err)
}
- defer waldo.DecRef()
+ defer waldo.DecRef(ctx)
submounts = mm.AllMountsUnder(mm.FindMount(waldo))
- if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, "/waldo"); err != nil {
t.Error(err)
}
}
@@ -212,7 +213,7 @@ func TestUnmount(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
rootDirent := mm.Root()
- defer rootDirent.DecRef()
+ defer rootDirent.DecRef(ctx)
// Add mounts at the following paths:
paths := []string{
@@ -240,7 +241,7 @@ func TestUnmount(t *testing.T) {
if err := mm.Mount(ctx, d, submountInode); err != nil {
t.Fatalf("could not mount at %q: %v", p, err)
}
- d.DecRef()
+ d.DecRef(ctx)
}
allPaths := make([]string, len(paths)+1)
@@ -259,13 +260,13 @@ func TestUnmount(t *testing.T) {
if err := mm.Unmount(ctx, d, false); err != nil {
t.Fatalf("could not unmount at %q: %v", p, err)
}
- d.DecRef()
+ d.DecRef(ctx)
// Remove the path that has been unmounted and the check that the remaining
// mounts are still there.
allPaths = allPaths[:len(allPaths)-1]
submounts := mm.AllMountsUnder(rootMnt)
- if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil {
t.Error(err)
}
}
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 3f2bd0e87..d741c4339 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -234,7 +234,7 @@ func (mns *MountNamespace) flushMountSourceRefsLocked() {
// After destroy is called, the MountNamespace may continue to be referenced (for
// example via /proc/mounts), but should free all resources and shouldn't have
// Find* methods called.
-func (mns *MountNamespace) destroy() {
+func (mns *MountNamespace) destroy(ctx context.Context) {
mns.mu.Lock()
defer mns.mu.Unlock()
@@ -247,13 +247,13 @@ func (mns *MountNamespace) destroy() {
for _, mp := range mns.mounts {
// Drop the mount reference on all mounted dirents.
for ; mp != nil; mp = mp.previous {
- mp.root.DecRef()
+ mp.root.DecRef(ctx)
}
}
mns.mounts = nil
// Drop reference on the root.
- mns.root.DecRef()
+ mns.root.DecRef(ctx)
// Ensure that root cannot be accessed via this MountNamespace any
// more.
@@ -265,8 +265,8 @@ func (mns *MountNamespace) destroy() {
}
// DecRef implements RefCounter.DecRef with destructor mns.destroy.
-func (mns *MountNamespace) DecRef() {
- mns.DecRefWithDestructor(mns.destroy)
+func (mns *MountNamespace) DecRef(ctx context.Context) {
+ mns.DecRefWithDestructor(ctx, mns.destroy)
}
// withMountLocked prevents further walks to `node`, because `node` is about to
@@ -312,7 +312,7 @@ func (mns *MountNamespace) Mount(ctx context.Context, mountPoint *Dirent, inode
if err != nil {
return err
}
- defer replacement.DecRef()
+ defer replacement.DecRef(ctx)
// Set the mount's root dirent and id.
parentMnt := mns.findMountLocked(mountPoint)
@@ -394,7 +394,7 @@ func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly
panic(fmt.Sprintf("Last mount in the chain must be a undo mount: %+v", prev))
}
// Drop mount reference taken at the end of MountNamespace.Mount.
- prev.root.DecRef()
+ prev.root.DecRef(ctx)
} else {
mns.mounts[prev.root] = prev
}
@@ -496,11 +496,11 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path
// non-directory root is hopeless.
if current != root {
if !IsDir(current.Inode.StableAttr) {
- current.DecRef() // Drop reference from above.
+ current.DecRef(ctx) // Drop reference from above.
return nil, syserror.ENOTDIR
}
if err := current.Inode.CheckPermission(ctx, PermMask{Execute: true}); err != nil {
- current.DecRef() // Drop reference from above.
+ current.DecRef(ctx) // Drop reference from above.
return nil, err
}
}
@@ -511,12 +511,12 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path
// Allow failed walks to cache the dirent, because no
// children will acquire a reference at the end.
current.maybeExtendReference()
- current.DecRef()
+ current.DecRef(ctx)
return nil, err
}
// Drop old reference.
- current.DecRef()
+ current.DecRef(ctx)
if remainder != "" {
// Ensure it's resolved, unless it's the last level.
@@ -570,11 +570,11 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
case nil:
// Make sure we didn't exhaust the traversal budget.
if *remainingTraversals == 0 {
- target.DecRef()
+ target.DecRef(ctx)
return nil, syscall.ELOOP
}
- node.DecRef() // Drop the original reference.
+ node.DecRef(ctx) // Drop the original reference.
return target, nil
case syscall.ENOLINK:
@@ -582,7 +582,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
return node, nil
case ErrResolveViaReadlink:
- defer node.DecRef() // See above.
+ defer node.DecRef(ctx) // See above.
// First, check if we should traverse.
if *remainingTraversals == 0 {
@@ -608,7 +608,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
return d, err
default:
- node.DecRef() // Drop for err; see above.
+ node.DecRef(ctx) // Drop for err; see above.
// Propagate the error.
return nil, err
diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go
index a69b41468..975d6cbc9 100644
--- a/pkg/sentry/fs/mounts_test.go
+++ b/pkg/sentry/fs/mounts_test.go
@@ -51,7 +51,7 @@ func TestFindLink(t *testing.T) {
}
root := mm.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
foo, err := root.Walk(ctx, root, "foo")
if err != nil {
t.Fatalf("Error walking to foo: %v", err)
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index a8ae7d81d..35013a21b 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -107,7 +107,7 @@ func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags Mount
msrc := newOverlayMountSource(ctx, upper.MountSource, lower.MountSource, flags)
overlay, err := newOverlayEntry(ctx, upper, lower, true)
if err != nil {
- msrc.DecRef()
+ msrc.DecRef(ctx)
return nil, err
}
@@ -130,7 +130,7 @@ func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode,
msrc := newOverlayMountSource(ctx, upperMS, lower.MountSource, flags)
overlay, err := newOverlayEntry(ctx, nil, lower, true)
if err != nil {
- msrc.DecRef()
+ msrc.DecRef(ctx)
return nil, err
}
return newOverlayInode(ctx, overlay, msrc), nil
@@ -230,16 +230,16 @@ func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExist
}, nil
}
-func (o *overlayEntry) release() {
+func (o *overlayEntry) release(ctx context.Context) {
// We drop a reference on upper and lower file system Inodes
// rather than releasing them, because in-memory filesystems
// may hold an extra reference to these Inodes so that they
// stay in memory.
if o.upper != nil {
- o.upper.DecRef()
+ o.upper.DecRef(ctx)
}
if o.lower != nil {
- o.lower.DecRef()
+ o.lower.DecRef(ctx)
}
}
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
index 35972e23c..45523adf8 100644
--- a/pkg/sentry/fs/proc/fds.go
+++ b/pkg/sentry/fs/proc/fds.go
@@ -56,11 +56,11 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF
// readDescriptors reads fds in the task starting at offset, and calls the
// toDentAttr callback for each to get a DentAttr, which it then emits. This is
// a helper for implementing fs.InodeOperations.Readdir.
-func readDescriptors(t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) {
+func readDescriptors(ctx context.Context, t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) {
var fds []int32
t.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.GetFDs()
+ fds = fdTable.GetFDs(ctx)
}
})
@@ -116,7 +116,7 @@ func (f *fd) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error
func (f *fd) Readlink(ctx context.Context, _ *fs.Inode) (string, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
n, _ := f.file.Dirent.FullName(root)
return n, nil
@@ -135,13 +135,7 @@ func (f *fd) Truncate(context.Context, *fs.Inode, int64) error {
func (f *fd) Release(ctx context.Context) {
f.Symlink.Release(ctx)
- f.file.DecRef()
-}
-
-// Close releases the reference on the file.
-func (f *fd) Close() error {
- f.file.DecRef()
- return nil
+ f.file.DecRef(ctx)
}
// fdDir is an InodeOperations for /proc/TID/fd.
@@ -227,7 +221,7 @@ func (f *fdDirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySer
if f.isInfoFile {
typ = fs.Symlink
}
- return readDescriptors(f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr {
+ return readDescriptors(ctx, f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr {
return fs.GenericDentAttr(typ, device.ProcDevice)
})
}
@@ -261,7 +255,7 @@ func (fdid *fdInfoDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs
// locks, and other data. For now we only have flags.
// See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
flags := file.Flags().ToLinux() | fdFlags.ToLinuxFileFlags()
- file.DecRef()
+ file.DecRef(ctx)
contents := []byte(fmt.Sprintf("flags:\t0%o\n", flags))
return newStaticProcInode(ctx, dir.MountSource, contents)
})
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index 1fc9c703c..6a63c47b3 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -47,7 +47,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
// The task has been destroyed. Nothing to show here.
return
}
- defer rootDir.DecRef()
+ defer rootDir.DecRef(t)
mnt := t.MountNamespace().FindMount(rootDir)
if mnt == nil {
@@ -64,7 +64,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
continue // No longer valid.
}
mountPath, desc := mroot.FullName(rootDir)
- mroot.DecRef()
+ mroot.DecRef(t)
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
@@ -97,7 +97,7 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
if mroot == nil {
return // No longer valid.
}
- defer mroot.DecRef()
+ defer mroot.DecRef(ctx)
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
@@ -216,7 +216,7 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
if root == nil {
return // No longer valid.
}
- defer root.DecRef()
+ defer root.DecRef(ctx)
flags := root.Inode.MountSource.Flags
opts := "rw"
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index bd18177d4..83a43aa26 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -419,7 +419,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
}
sfile := s.(*fs.File)
if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
- s.DecRef()
+ s.DecRef(ctx)
// Not a unix socket.
continue
}
@@ -479,7 +479,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
}
fmt.Fprintf(&buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
data := []seqfile.SeqData{
@@ -574,7 +574,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
}
if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
- s.DecRef()
+ s.DecRef(ctx)
// Not tcp4 sockets.
continue
}
@@ -664,7 +664,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne
fmt.Fprintf(&buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
data := []seqfile.SeqData{
@@ -752,7 +752,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
}
if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
- s.DecRef()
+ s.DecRef(ctx)
// Not udp4 socket.
continue
}
@@ -822,7 +822,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
fmt.Fprintf(&buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
data := []seqfile.SeqData{
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index c659224a7..77e0e1d26 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -213,7 +213,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
// Add dot and dotdot.
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dot, dotdot := file.Dirent.GetDotAttrs(root)
names = append(names, ".", "..")
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 4bbe90198..9cf7f2a62 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -185,7 +185,7 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry
// Serialize "." and "..".
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dot, dotdot := file.Dirent.GetDotAttrs(root)
if err := dirCtx.DirEmit(".", dot); err != nil {
@@ -295,7 +295,7 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
if err != nil {
return "", err
}
- defer exec.DecRef()
+ defer exec.DecRef(ctx)
return exec.PathnameWithDeleted(ctx), nil
}
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
index bfa304552..f4fcddecb 100644
--- a/pkg/sentry/fs/ramfs/dir.go
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -219,7 +219,7 @@ func (d *Dir) Remove(ctx context.Context, _ *fs.Inode, name string) error {
}
// Remove our reference on the inode.
- inode.DecRef()
+ inode.DecRef(ctx)
return nil
}
@@ -250,7 +250,7 @@ func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) err
}
// Remove our reference on the inode.
- inode.DecRef()
+ inode.DecRef(ctx)
return nil
}
@@ -326,7 +326,7 @@ func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.F
// Create the Dirent and corresponding file.
created := fs.NewDirent(ctx, inode, name)
- defer created.DecRef()
+ defer created.DecRef(ctx)
return created.Inode.GetFile(ctx, created, flags)
}
@@ -412,11 +412,11 @@ func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, ol
}
// Release implements fs.InodeOperation.Release.
-func (d *Dir) Release(_ context.Context) {
+func (d *Dir) Release(ctx context.Context) {
// Drop references on all children.
d.mu.Lock()
for _, i := range d.children {
- i.DecRef()
+ i.DecRef(ctx)
}
d.mu.Unlock()
}
@@ -456,7 +456,7 @@ func (dfo *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirC
func (dfo *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
@@ -473,13 +473,13 @@ func hasChildren(ctx context.Context, inode *fs.Inode) (bool, error) {
// dropped when that dirent is destroyed.
inode.IncRef()
d := fs.NewTransientDirent(inode)
- defer d.DecRef()
+ defer d.DecRef(ctx)
file, err := inode.GetFile(ctx, d, fs.FileFlags{Read: true})
if err != nil {
return false, err
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
ser := &fs.CollectEntriesSerializer{}
if err := file.Readdir(ctx, ser); err != nil {
@@ -530,7 +530,7 @@ func Rename(ctx context.Context, oldParent fs.InodeOperations, oldName string, n
if err != nil {
return err
}
- inode.DecRef()
+ inode.DecRef(ctx)
}
// Be careful, we may have already grabbed this mutex above.
diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go
index a6ed8b2c5..3e0d1e07e 100644
--- a/pkg/sentry/fs/ramfs/tree_test.go
+++ b/pkg/sentry/fs/ramfs/tree_test.go
@@ -67,7 +67,7 @@ func TestMakeDirectoryTree(t *testing.T) {
continue
}
root := mm.Root()
- defer mm.DecRef()
+ defer mm.DecRef(ctx)
for _, p := range test.subdirs {
maxTraversals := uint(0)
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 88c344089..f362ca9b6 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -55,7 +55,7 @@ type TimerOperations struct {
func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[timerfd]")
// Release the initial dirent reference after NewFile takes a reference.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
tops := &TimerOperations{}
tops.timer = ktime.NewTimer(c, tops)
// Timerfds reject writes, but the Write flag must be set in order to
@@ -65,7 +65,7 @@ func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
}
// Release implements fs.FileOperations.Release.
-func (t *TimerOperations) Release() {
+func (t *TimerOperations) Release(context.Context) {
t.timer.Destroy()
}
diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go
index aaba35502..d4d613ea9 100644
--- a/pkg/sentry/fs/tmpfs/file_test.go
+++ b/pkg/sentry/fs/tmpfs/file_test.go
@@ -46,7 +46,7 @@ func newFile(ctx context.Context) *fs.File {
func TestGrow(t *testing.T) {
ctx := contexttest.Context(t)
f := newFile(ctx)
- defer f.DecRef()
+ defer f.DecRef(ctx)
abuf := bytes.Repeat([]byte{'a'}, 68)
n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0)
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 108654827..463f6189e 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -132,7 +132,7 @@ func (d *dirInodeOperations) Release(ctx context.Context) {
d.mu.Lock()
defer d.mu.Unlock()
- d.master.DecRef()
+ d.master.DecRef(ctx)
if len(d.slaves) != 0 {
panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
}
@@ -263,7 +263,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e
}
// masterClose is called when the master end of t is closed.
-func (d *dirInodeOperations) masterClose(t *Terminal) {
+func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) {
d.mu.Lock()
defer d.mu.Unlock()
@@ -277,7 +277,7 @@ func (d *dirInodeOperations) masterClose(t *Terminal) {
panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d))
}
- s.DecRef()
+ s.DecRef(ctx)
delete(d.slaves, t.n)
d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10))
}
@@ -322,7 +322,7 @@ func (df *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCt
func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go
index 8fe05ebe5..2d4d44bf3 100644
--- a/pkg/sentry/fs/tty/fs.go
+++ b/pkg/sentry/fs/tty/fs.go
@@ -108,4 +108,4 @@ func (superOperations) ResetInodeMappings() {}
func (superOperations) SaveInodeMapping(*fs.Inode, string) {}
// Destroy implements MountSourceOperations.Destroy.
-func (superOperations) Destroy() {}
+func (superOperations) Destroy(context.Context) {}
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index fe07fa929..e00746017 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -75,7 +75,7 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn
}
// Release implements fs.InodeOperations.Release.
-func (mi *masterInodeOperations) Release(ctx context.Context) {
+func (mi *masterInodeOperations) Release(context.Context) {
}
// Truncate implements fs.InodeOperations.Truncate.
@@ -120,9 +120,9 @@ type masterFileOperations struct {
var _ fs.FileOperations = (*masterFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (mf *masterFileOperations) Release() {
- mf.d.masterClose(mf.t)
- mf.t.DecRef()
+func (mf *masterFileOperations) Release(ctx context.Context) {
+ mf.d.masterClose(ctx, mf.t)
+ mf.t.DecRef(ctx)
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index 9871f6fc6..7c7292687 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -71,7 +71,7 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne
// Release implements fs.InodeOperations.Release.
func (si *slaveInodeOperations) Release(ctx context.Context) {
- si.t.DecRef()
+ si.t.DecRef(ctx)
}
// Truncate implements fs.InodeOperations.Truncate.
@@ -106,7 +106,7 @@ type slaveFileOperations struct {
var _ fs.FileOperations = (*slaveFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (sf *slaveFileOperations) Release() {
+func (sf *slaveFileOperations) Release(context.Context) {
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
index 397e96045..2f5a43b84 100644
--- a/pkg/sentry/fs/user/path.go
+++ b/pkg/sentry/fs/user/path.go
@@ -82,7 +82,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
// Caller has no root. Don't bother traversing anything.
return "", syserror.ENOENT
}
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, p := range paths {
if !path.IsAbs(p) {
// Relative paths aren't safe, no one should be using them.
@@ -100,7 +100,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
if err != nil {
return "", err
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
// Check that it is a regular file.
if !fs.IsRegular(d.Inode.StableAttr) {
@@ -121,7 +121,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, p := range paths {
if !path.IsAbs(p) {
// Relative paths aren't safe, no one should be using them.
@@ -148,7 +148,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam
if err != nil {
return "", err
}
- dentry.DecRef()
+ dentry.DecRef(ctx)
return binPath, nil
}
diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go
index f4d525523..936fd3932 100644
--- a/pkg/sentry/fs/user/user.go
+++ b/pkg/sentry/fs/user/user.go
@@ -62,7 +62,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
// doesn't exist we will return the default home directory.
return defaultHome, nil
}
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
// Check read permissions on the file.
if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Read: true}); err != nil {
@@ -81,7 +81,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
if err != nil {
return "", err
}
- defer f.DecRef()
+ defer f.DecRef(ctx)
r := &fileReader{
Ctx: ctx,
@@ -105,7 +105,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.
const defaultHome = "/"
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
creds := auth.CredentialsFromContext(ctx)
@@ -123,7 +123,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.
if err != nil {
return defaultHome, nil
}
- defer fd.DecRef()
+ defer fd.DecRef(ctx)
r := &fileReaderVFS2{
ctx: ctx,
diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go
index 7d8e9ac7c..12b786224 100644
--- a/pkg/sentry/fs/user/user_test.go
+++ b/pkg/sentry/fs/user/user_test.go
@@ -39,7 +39,7 @@ func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode
if err != nil {
return err
}
- defer etc.DecRef()
+ defer etc.DecRef(ctx)
switch mode.FileType() {
case 0:
// Don't create anything.
@@ -49,7 +49,7 @@ func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode
if err != nil {
return err
}
- defer passwd.DecRef()
+ defer passwd.DecRef(ctx)
if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil {
return err
}
@@ -110,9 +110,9 @@ func TestGetExecUserHome(t *testing.T) {
if err != nil {
t.Fatalf("NewMountNamespace failed: %v", err)
}
- defer mns.DecRef()
+ defer mns.DecRef(ctx)
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
ctx = fs.WithRoot(ctx, root)
if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil {
diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go
index 8e7590721..7e61209ee 100644
--- a/pkg/sentry/fsbridge/bridge.go
+++ b/pkg/sentry/fsbridge/bridge.go
@@ -44,7 +44,7 @@ type File interface {
IncRef()
// DecRef decrements reference.
- DecRef()
+ DecRef(ctx context.Context)
}
// Lookup provides a common interface to open files.
diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go
index 093ce1fb3..9785fd62a 100644
--- a/pkg/sentry/fsbridge/fs.go
+++ b/pkg/sentry/fsbridge/fs.go
@@ -49,7 +49,7 @@ func (f *fsFile) PathnameWithDeleted(ctx context.Context) string {
// global there.
return ""
}
- defer root.DecRef()
+ defer root.DecRef(ctx)
name, _ := f.file.Dirent.FullName(root)
return name
@@ -87,8 +87,8 @@ func (f *fsFile) IncRef() {
}
// DecRef implements File.
-func (f *fsFile) DecRef() {
- f.file.DecRef()
+func (f *fsFile) DecRef(ctx context.Context) {
+ f.file.DecRef(ctx)
}
// fsLookup implements Lookup interface using fs.File.
@@ -124,7 +124,7 @@ func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptio
if err != nil {
return nil, err
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
return nil, syserror.ELOOP
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 89168220a..323506d33 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -43,7 +43,7 @@ func NewVFSFile(file *vfs.FileDescription) File {
// PathnameWithDeleted implements File.
func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string {
root := vfs.RootFromContext(ctx)
- defer root.DecRef()
+ defer root.DecRef(ctx)
vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry())
@@ -86,8 +86,8 @@ func (f *VFSFile) IncRef() {
}
// DecRef implements File.
-func (f *VFSFile) DecRef() {
- f.file.DecRef()
+func (f *VFSFile) DecRef(ctx context.Context) {
+ f.file.DecRef(ctx)
}
// FileDescription returns the FileDescription represented by f. It does not
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index e6fda2b4f..7169e91af 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -103,9 +103,9 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
// rootInode is the root directory inode for the devpts mounts.
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 1081fff52..3bb397f71 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -60,7 +60,7 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf
}
fd.LockFD.Init(&mi.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
- mi.DecRef()
+ mi.DecRef(ctx)
return nil, err
}
return &fd.vfsfd, nil
@@ -98,9 +98,9 @@ type masterFileDescription struct {
var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
-func (mfd *masterFileDescription) Release() {
+func (mfd *masterFileDescription) Release(ctx context.Context) {
mfd.inode.root.masterClose(mfd.t)
- mfd.inode.DecRef()
+ mfd.inode.DecRef(ctx)
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index a91cae3ef..32e4e1908 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -56,7 +56,7 @@ func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs
}
fd.LockFD.Init(&si.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
- si.DecRef()
+ si.DecRef(ctx)
return nil, err
}
return &fd.vfsfd, nil
@@ -103,8 +103,8 @@ type slaveFileDescription struct {
var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil)
// Release implements fs.FileOperations.Release.
-func (sfd *slaveFileDescription) Release() {
- sfd.inode.DecRef()
+func (sfd *slaveFileDescription) Release(ctx context.Context) {
+ sfd.inode.DecRef(ctx)
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index d0e06cdc0..2ed5fa8a9 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -92,9 +92,9 @@ func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth
}
// Release must be called when a is no longer in use.
-func (a *Accessor) Release() {
- a.root.DecRef()
- a.mntns.DecRef()
+func (a *Accessor) Release(ctx context.Context) {
+ a.root.DecRef(ctx)
+ a.mntns.DecRef(ctx)
}
// accessorContext implements context.Context by extending an existing
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
index b6d52c015..747867cca 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -30,7 +30,7 @@ func TestDevtmpfs(t *testing.T) {
creds := auth.CredentialsFromContext(ctx)
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
// Register tmpfs just so that we can have a root filesystem that isn't
@@ -48,9 +48,9 @@ func TestDevtmpfs(t *testing.T) {
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
devpop := vfs.PathOperation{
Root: root,
Start: root,
@@ -69,7 +69,7 @@ func TestDevtmpfs(t *testing.T) {
if err != nil {
t.Fatalf("failed to create devtmpfs.Accessor: %v", err)
}
- defer a.Release()
+ defer a.Release(ctx)
// Create "userspace-initialized" files using a devtmpfs.Accessor.
if err := a.UserspaceInit(ctx); err != nil {
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index d12d78b84..812171fa3 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -59,9 +59,9 @@ type EventFileDescription struct {
var _ vfs.FileDescriptionImpl = (*EventFileDescription)(nil)
// New creates a new event fd.
-func New(vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) {
+func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) {
vd := vfsObj.NewAnonVirtualDentry("[eventfd]")
- defer vd.DecRef()
+ defer vd.DecRef(ctx)
efd := &EventFileDescription{
val: initVal,
semMode: semMode,
@@ -107,7 +107,7 @@ func (efd *EventFileDescription) HostFD() (int, error) {
}
// Release implements FileDescriptionImpl.Release()
-func (efd *EventFileDescription) Release() {
+func (efd *EventFileDescription) Release(context.Context) {
efd.mu.Lock()
defer efd.mu.Unlock()
if efd.hostfd >= 0 {
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_test.go b/pkg/sentry/fsimpl/eventfd/eventfd_test.go
index 20e3adffc..49916fa81 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd_test.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd_test.go
@@ -36,16 +36,16 @@ func TestEventFD(t *testing.T) {
for _, initVal := range initVals {
ctx := contexttest.Context(t)
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
// Make a new eventfd that is writable.
- eventfd, err := New(vfsObj, initVal, false, linux.O_RDWR)
+ eventfd, err := New(ctx, vfsObj, initVal, false, linux.O_RDWR)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
- defer eventfd.DecRef()
+ defer eventfd.DecRef(ctx)
// Register a callback for a write event.
w, ch := waiter.NewChannelEntry(nil)
@@ -74,16 +74,16 @@ func TestEventFD(t *testing.T) {
func TestEventFDStat(t *testing.T) {
ctx := contexttest.Context(t)
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
// Make a new eventfd that is writable.
- eventfd, err := New(vfsObj, 0, false, linux.O_RDWR)
+ eventfd, err := New(ctx, vfsObj, 0, false, linux.O_RDWR)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
- defer eventfd.DecRef()
+ defer eventfd.DecRef(ctx)
statx, err := eventfd.Stat(ctx, vfs.StatOptions{
Mask: linux.STATX_BASIC_STATS,
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 89caee3df..8f7d5a9bb 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -53,7 +53,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
return nil, nil, nil, nil, err
}
vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -68,7 +68,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
root := mntns.Root()
tearDown := func() {
- root.DecRef()
+ root.DecRef(ctx)
if err := f.Close(); err != nil {
b.Fatalf("tearDown failed: %v", err)
@@ -169,7 +169,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount point: %v", err)
}
- defer mountPoint.DecRef()
+ defer mountPoint.DecRef(ctx)
// Create extfs submount.
mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop)
diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
index 55902322a..7a1b4219f 100644
--- a/pkg/sentry/fsimpl/ext/dentry.go
+++ b/pkg/sentry/fsimpl/ext/dentry.go
@@ -15,6 +15,7 @@
package ext
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -55,7 +56,7 @@ func (d *dentry) TryIncRef() bool {
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef() {
+func (d *dentry) DecRef(ctx context.Context) {
// FIXME(b/134676337): filesystem.mu may not be locked as required by
// inode.decRef().
d.inode.decRef()
@@ -64,7 +65,7 @@ func (d *dentry) DecRef() {
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
//
// TODO(b/134676337): Implement inotify.
-func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {}
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {}
// Watches implements vfs.DentryImpl.Watches.
//
@@ -76,4 +77,4 @@ func (d *dentry) Watches() *vfs.Watches {
// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
//
// TODO(b/134676337): Implement inotify.
-func (d *dentry) OnZeroWatches() {}
+func (d *dentry) OnZeroWatches(context.Context) {}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 357512c7e..0fc01668d 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -142,7 +142,7 @@ type directoryFD struct {
var _ vfs.FileDescriptionImpl = (*directoryFD)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *directoryFD) Release() {
+func (fd *directoryFD) Release(ctx context.Context) {
if fd.iter == nil {
return
}
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index dac6effbf..08ffc2834 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -123,32 +123,32 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs.vfsfs.Init(vfsObj, &fsType, &fs)
fs.sb, err = readSuperBlock(dev)
if err != nil {
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
if fs.sb.Magic() != linux.EXT_SUPER_MAGIC {
// mount(2) specifies that EINVAL should be returned if the superblock is
// invalid.
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EINVAL
}
// Refuse to mount if the filesystem is incompatible.
if !isCompatible(fs.sb) {
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EINVAL
}
fs.bgs, err = readBlockGroups(dev, fs.sb)
if err != nil {
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode)
if err != nil {
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
rootInode.incRef()
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 64e9a579f..2dbaee287 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -65,7 +65,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -80,7 +80,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
root := mntns.Root()
tearDown := func() {
- root.DecRef()
+ root.DecRef(ctx)
if err := f.Close(); err != nil {
t.Fatalf("tearDown failed: %v", err)
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 557963e03..c714ddf73 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -84,7 +84,7 @@ var _ vfs.FilesystemImpl = (*filesystem)(nil)
// - filesystem.mu must be locked (for writing if write param is true).
// - !rp.Done().
// - inode == vfsd.Impl().(*Dentry).inode.
-func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
+func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
if !inode.isDir() {
return nil, nil, syserror.ENOTDIR
}
@@ -100,7 +100,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
}
d := vfsd.Impl().(*dentry)
if name == ".." {
- isRoot, err := rp.CheckRoot(vfsd)
+ isRoot, err := rp.CheckRoot(ctx, vfsd)
if err != nil {
return nil, nil, err
}
@@ -108,7 +108,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
rp.Advance()
return vfsd, inode, nil
}
- if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
return nil, nil, err
}
rp.Advance()
@@ -143,7 +143,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
child.name = name
dir.childCache[name] = child
}
- if err := rp.CheckMount(&child.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
return nil, nil, err
}
if child.inode.isSymlink() && rp.ShouldFollowSymlink() {
@@ -167,12 +167,12 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
//
// Preconditions:
// - filesystem.mu must be locked (for writing if write param is true).
-func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+func walkLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
vfsd := rp.Start()
inode := vfsd.Impl().(*dentry).inode
for !rp.Done() {
var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write)
if err != nil {
return nil, nil, err
}
@@ -196,12 +196,12 @@ func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error)
// Preconditions:
// - filesystem.mu must be locked (for writing if write param is true).
// - !rp.Done().
-func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+func walkParentLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
vfsd := rp.Start()
inode := vfsd.Impl().(*dentry).inode
for !rp.Final() {
var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write)
if err != nil {
return nil, nil, err
}
@@ -216,7 +216,7 @@ func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, e
// the rp till the parent of the last component which should be an existing
// directory. If parent is false then resolves rp entirely. Attemps to resolve
// the path as far as it can with a read lock and upgrades the lock if needed.
-func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
+func (fs *filesystem) walk(ctx context.Context, rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
var (
vfsd *vfs.Dentry
inode *inode
@@ -227,9 +227,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in
// of disk. This reduces congestion (allows concurrent walks).
fs.mu.RLock()
if parent {
- vfsd, inode, err = walkParentLocked(rp, false)
+ vfsd, inode, err = walkParentLocked(ctx, rp, false)
} else {
- vfsd, inode, err = walkLocked(rp, false)
+ vfsd, inode, err = walkLocked(ctx, rp, false)
}
fs.mu.RUnlock()
@@ -238,9 +238,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in
// walk is fine as this is a read only filesystem.
fs.mu.Lock()
if parent {
- vfsd, inode, err = walkParentLocked(rp, true)
+ vfsd, inode, err = walkParentLocked(ctx, rp, true)
} else {
- vfsd, inode, err = walkLocked(rp, true)
+ vfsd, inode, err = walkLocked(ctx, rp, true)
}
fs.mu.Unlock()
}
@@ -283,7 +283,7 @@ func (fs *filesystem) statTo(stat *linux.Statfs) {
// AccessAt implements vfs.Filesystem.Impl.AccessAt.
func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -292,7 +292,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
- vfsd, inode, err := fs.walk(rp, false)
+ vfsd, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
}
@@ -312,7 +312,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
- vfsd, inode, err := fs.walk(rp, true)
+ vfsd, inode, err := fs.walk(ctx, rp, true)
if err != nil {
return nil, err
}
@@ -322,7 +322,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// OpenAt implements vfs.FilesystemImpl.OpenAt.
func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- vfsd, inode, err := fs.walk(rp, false)
+ vfsd, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
}
@@ -336,7 +336,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return "", err
}
@@ -349,7 +349,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// StatAt implements vfs.FilesystemImpl.StatAt.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return linux.Statx{}, err
}
@@ -360,7 +360,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
- if _, _, err := fs.walk(rp, false); err != nil {
+ if _, _, err := fs.walk(ctx, rp, false); err != nil {
return linux.Statfs{}, err
}
@@ -370,7 +370,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
}
@@ -390,7 +390,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EEXIST
}
- if _, _, err := fs.walk(rp, true); err != nil {
+ if _, _, err := fs.walk(ctx, rp, true); err != nil {
return err
}
@@ -403,7 +403,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
- if _, _, err := fs.walk(rp, true); err != nil {
+ if _, _, err := fs.walk(ctx, rp, true); err != nil {
return err
}
@@ -416,7 +416,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
- _, _, err := fs.walk(rp, true)
+ _, _, err := fs.walk(ctx, rp, true)
if err != nil {
return err
}
@@ -430,7 +430,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return syserror.ENOENT
}
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -440,7 +440,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -454,7 +454,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -468,7 +468,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return syserror.EEXIST
}
- _, _, err := fs.walk(rp, true)
+ _, _, err := fs.walk(ctx, rp, true)
if err != nil {
return err
}
@@ -478,7 +478,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -492,7 +492,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
}
@@ -506,7 +506,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
}
@@ -515,7 +515,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return "", err
}
@@ -524,7 +524,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -533,7 +533,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index 66d14bb95..e73e740d6 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -79,7 +79,7 @@ type regularFileFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {}
+func (fd *regularFileFD) Release(context.Context) {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index 62efd4095..2fd0d1fa8 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -73,7 +73,7 @@ type symlinkFD struct {
var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *symlinkFD) Release() {}
+func (fd *symlinkFD) Release(context.Context) {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index 2225076bc..e522ff9a0 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -99,7 +99,7 @@ type DeviceFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *DeviceFD) Release() {
+func (fd *DeviceFD) Release(context.Context) {
fd.fs.conn.connected = false
}
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 84c222ad6..1ffe7ccd2 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -356,12 +356,12 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque
vfsObj := &vfs.VirtualFilesystem{}
fuseDev := &DeviceFD{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(system.Ctx); err != nil {
return nil, nil, err
}
vd := vfsObj.NewAnonVirtualDentry("genCountFD")
- defer vd.DecRef()
+ defer vd.DecRef(system.Ctx)
if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, nil, err
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 200a93bbf..a1405f7c3 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -191,9 +191,9 @@ func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
// inode implements kernfs.Inode.
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 8c7c8e1b3..1679066ba 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -122,7 +122,7 @@ type directoryFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *directoryFD) Release() {
+func (fd *directoryFD) Release(context.Context) {
}
// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
@@ -139,7 +139,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.dirents = ds
}
- d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent)
if d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 00e3c99cd..e6af37d0d 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -55,7 +55,7 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// Sync regular files.
for _, d := range ds {
err := d.syncSharedHandle(ctx)
- d.DecRef()
+ d.DecRef(ctx)
if err != nil && retErr == nil {
retErr = err
}
@@ -65,7 +65,7 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
err := sffd.Sync(ctx)
- sffd.vfsfd.DecRef()
+ sffd.vfsfd.DecRef(ctx)
if err != nil && retErr == nil {
retErr = err
}
@@ -133,7 +133,7 @@ afterSymlink:
return d, nil
}
if name == ".." {
- if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
return nil, err
} else if isRoot || d.parent == nil {
rp.Advance()
@@ -146,7 +146,7 @@ afterSymlink:
//
// Call rp.CheckMount() before updating d.parent's metadata, since if
// we traverse to another mount then d.parent's metadata is irrelevant.
- if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
return nil, err
}
if d != d.parent && !d.cachedMetadataAuthoritative() {
@@ -164,7 +164,7 @@ afterSymlink:
if child == nil {
return nil, syserror.ENOENT
}
- if err := rp.CheckMount(&child.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
return nil, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
@@ -239,7 +239,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// has 0 references, drop it). Wait to update parent.children until we
// know what to replace the existing dentry with (i.e. one of the
// returns below), to avoid a redundant map access.
- vfsObj.InvalidateDentry(&child.vfsd)
+ vfsObj.InvalidateDentry(ctx, &child.vfsd)
if child.isSynthetic() {
// Normally we don't mark invalidated dentries as deleted since
// they may still exist (but at a different path), and also for
@@ -332,7 +332,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by
@@ -384,7 +384,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if dir {
ev |= linux.IN_ISDIR
}
- parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
return nil
}
if fs.opts.interop == InteropModeShared {
@@ -405,7 +405,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if dir {
ev |= linux.IN_ISDIR
}
- parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
return nil
}
if child := parent.children[name]; child != nil {
@@ -426,7 +426,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if dir {
ev |= linux.IN_ISDIR
}
- parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
return nil
}
@@ -434,7 +434,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by
@@ -470,7 +470,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
@@ -600,17 +600,17 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// Generate inotify events for rmdir or unlink.
if dir {
- parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
} else {
var cw *vfs.Watches
if child != nil {
cw = &child.watches
}
- vfs.InotifyRemoveChild(cw, &parent.watches, name)
+ vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name)
}
if child != nil {
- vfsObj.CommitDeleteDentry(&child.vfsd)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
child.setDeleted()
if child.isSynthetic() {
parent.syntheticChildren--
@@ -637,7 +637,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) {
+func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
return
@@ -645,20 +645,20 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) {
if len(**ds) != 0 {
fs.renameMu.Lock()
for _, d := range **ds {
- d.checkCachingLocked()
+ d.checkCachingLocked(ctx)
}
fs.renameMu.Unlock()
}
putDentrySlice(*ds)
}
-func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) {
+func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
if *ds == nil {
fs.renameMu.Unlock()
return
}
for _, d := range **ds {
- d.checkCachingLocked()
+ d.checkCachingLocked(ctx)
}
fs.renameMu.Unlock()
putDentrySlice(*ds)
@@ -668,7 +668,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) {
func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
@@ -680,7 +680,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -701,7 +701,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by
@@ -812,7 +812,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if !start.cachedMetadataAuthoritative() {
@@ -1126,7 +1126,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
childVFSFD = &fd.vfsfd
}
- d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
+ d.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
return childVFSFD, nil
}
@@ -1134,7 +1134,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
@@ -1154,7 +1154,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
var ds *[]*dentry
fs.renameMu.Lock()
- defer fs.renameMuUnlockAndCheckCaching(&ds)
+ defer fs.renameMuUnlockAndCheckCaching(ctx, &ds)
newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
if err != nil {
return err
@@ -1244,7 +1244,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return nil
}
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
return err
}
@@ -1269,7 +1269,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
// Update the dentry tree.
- vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD)
+ vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD)
if replaced != nil {
replaced.setDeleted()
if replaced.isSynthetic() {
@@ -1331,17 +1331,17 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.renameMu.RLock()
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
- d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
}
return nil
}
@@ -1350,7 +1350,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return linux.Statx{}, err
@@ -1367,7 +1367,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return linux.Statfs{}, err
@@ -1417,7 +1417,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -1443,7 +1443,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -1455,7 +1455,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
@@ -1469,16 +1469,16 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
fs.renameMu.RLock()
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil {
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
@@ -1488,16 +1488,16 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
fs.renameMu.RLock()
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
if err := d.removexattr(ctx, rp.Credentials(), name); err != nil {
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
- fs.renameMuRUnlockAndCheckCaching(&ds)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index e20de84b5..2e5575d8d 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -482,7 +482,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
// Set the root's reference count to 2. One reference is returned to the
@@ -495,8 +495,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
- ctx := context.Background()
+func (fs *filesystem) Release(ctx context.Context) {
mf := fs.mfp.MemoryFile()
fs.syncMu.Lock()
@@ -1089,10 +1088,10 @@ func (d *dentry) TryIncRef() bool {
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef() {
+func (d *dentry) DecRef(ctx context.Context) {
if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
d.fs.renameMu.Lock()
- d.checkCachingLocked()
+ d.checkCachingLocked(ctx)
d.fs.renameMu.Unlock()
} else if refs < 0 {
panic("gofer.dentry.DecRef() called without holding a reference")
@@ -1109,7 +1108,7 @@ func (d *dentry) decRefLocked() {
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
-func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
if d.isDir() {
events |= linux.IN_ISDIR
}
@@ -1117,9 +1116,9 @@ func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
d.fs.renameMu.RLock()
// The ordering below is important, Linux always notifies the parent first.
if d.parent != nil {
- d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted())
+ d.parent.watches.Notify(ctx, d.name, events, cookie, et, d.isDeleted())
}
- d.watches.Notify("", events, cookie, et, d.isDeleted())
+ d.watches.Notify(ctx, "", events, cookie, et, d.isDeleted())
d.fs.renameMu.RUnlock()
}
@@ -1131,10 +1130,10 @@ func (d *dentry) Watches() *vfs.Watches {
// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
//
// If no watches are left on this dentry and it has no references, cache it.
-func (d *dentry) OnZeroWatches() {
+func (d *dentry) OnZeroWatches(ctx context.Context) {
if atomic.LoadInt64(&d.refs) == 0 {
d.fs.renameMu.Lock()
- d.checkCachingLocked()
+ d.checkCachingLocked(ctx)
d.fs.renameMu.Unlock()
}
}
@@ -1149,7 +1148,7 @@ func (d *dentry) OnZeroWatches() {
// do nothing.
//
// Preconditions: d.fs.renameMu must be locked for writing.
-func (d *dentry) checkCachingLocked() {
+func (d *dentry) checkCachingLocked(ctx context.Context) {
// Dentries with a non-zero reference count must be retained. (The only way
// to obtain a reference on a dentry with zero references is via path
// resolution, which requires renameMu, so if d.refs is zero then it will
@@ -1171,14 +1170,14 @@ func (d *dentry) checkCachingLocked() {
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
if d.isDeleted() {
- d.watches.HandleDeletion()
+ d.watches.HandleDeletion(ctx)
}
if d.cached {
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentriesLen--
d.cached = false
}
- d.destroyLocked()
+ d.destroyLocked(ctx)
return
}
// If d still has inotify watches and it is not deleted or invalidated, we
@@ -1213,7 +1212,7 @@ func (d *dentry) checkCachingLocked() {
if !victim.vfsd.IsDead() {
// Note that victim can't be a mount point (in any mount
// namespace), since VFS holds references on mount points.
- d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(&victim.vfsd)
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
delete(victim.parent.children, victim.name)
// We're only deleting the dentry, not the file it
// represents, so we don't need to update
@@ -1221,7 +1220,7 @@ func (d *dentry) checkCachingLocked() {
}
victim.parent.dirMu.Unlock()
}
- victim.destroyLocked()
+ victim.destroyLocked(ctx)
}
// Whether or not victim was destroyed, we brought fs.cachedDentriesLen
// back down to fs.opts.maxCachedDentries, so we don't loop.
@@ -1233,7 +1232,7 @@ func (d *dentry) checkCachingLocked() {
//
// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is
// not a child dentry.
-func (d *dentry) destroyLocked() {
+func (d *dentry) destroyLocked(ctx context.Context) {
switch atomic.LoadInt64(&d.refs) {
case 0:
// Mark the dentry destroyed.
@@ -1244,7 +1243,6 @@ func (d *dentry) destroyLocked() {
panic("dentry.destroyLocked() called with references on the dentry")
}
- ctx := context.Background()
d.handleMu.Lock()
if !d.handle.file.isNil() {
mf := d.fs.mfp.MemoryFile()
@@ -1276,7 +1274,7 @@ func (d *dentry) destroyLocked() {
// d.fs.renameMu.
if d.parent != nil {
if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkCachingLocked()
+ d.parent.checkCachingLocked(ctx)
} else if refs < 0 {
panic("gofer.dentry.DecRef() called without holding a reference")
}
@@ -1514,7 +1512,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return err
}
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
- fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent)
+ fd.dentry().InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
}
return nil
}
@@ -1535,7 +1533,7 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption
if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
return err
}
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
@@ -1545,7 +1543,7 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
return err
}
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index adff39490..56d80bcf8 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -50,7 +50,7 @@ func TestDestroyIdempotent(t *testing.T) {
}
parent.cacheNewChildLocked(child, "child")
- child.checkCachingLocked()
+ child.checkCachingLocked(ctx)
if got := atomic.LoadInt64(&child.refs); got != -1 {
t.Fatalf("child.refs=%d, want: -1", got)
}
@@ -58,6 +58,6 @@ func TestDestroyIdempotent(t *testing.T) {
if got := atomic.LoadInt64(&parent.refs); got != -1 {
t.Fatalf("parent.refs=%d, want: -1", got)
}
- child.checkCachingLocked()
- child.checkCachingLocked()
+ child.checkCachingLocked(ctx)
+ child.checkCachingLocked(ctx)
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 09f142cfc..420e8efe2 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -48,7 +48,7 @@ type regularFileFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {
+func (fd *regularFileFD) Release(context.Context) {
}
// OnClose implements vfs.FileDescriptionImpl.OnClose.
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index d6dbe9092..85d2bee72 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -108,7 +108,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
// We don't need the receiver.
c.CloseRecv()
- c.Release()
+ c.Release(ctx)
return c, nil
}
@@ -136,8 +136,8 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla
}
// Release implements transport.BoundEndpoint.Release.
-func (e *endpoint) Release() {
- e.dentry.DecRef()
+func (e *endpoint) Release(ctx context.Context) {
+ e.dentry.DecRef(ctx)
}
// Passcred implements transport.BoundEndpoint.Passcred.
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 811528982..fc269ef2b 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -80,11 +80,11 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *specialFileFD) Release() {
+func (fd *specialFileFD) Release(ctx context.Context) {
if fd.haveQueue {
fdnotifier.RemoveFD(fd.handle.fd)
}
- fd.handle.close(context.Background())
+ fd.handle.close(ctx)
fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
fs.syncMu.Lock()
delete(fs.specialFileFDs, fd)
diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go
index b9082a20f..0135e4428 100644
--- a/pkg/sentry/fsimpl/host/control.go
+++ b/pkg/sentry/fsimpl/host/control.go
@@ -58,7 +58,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (c *scmRights) Release() {
+func (c *scmRights) Release(ctx context.Context) {
for _, fd := range c.fds {
syscall.Close(fd)
}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index c894f2ca0..bf922c566 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -117,7 +117,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
d.Init(i)
// i.open will take a reference on d.
- defer d.DecRef()
+ defer d.DecRef(ctx)
// For simplicity, fileDescription.offset is set to 0. Technically, we
// should only set to 0 on files that are not seekable (sockets, pipes,
@@ -168,9 +168,9 @@ type filesystem struct {
devMinor uint32
}
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
@@ -431,12 +431,12 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
// DecRef implements kernfs.Inode.
-func (i *inode) DecRef() {
- i.AtomicRefCount.DecRefWithDestructor(i.Destroy)
+func (i *inode) DecRef(ctx context.Context) {
+ i.AtomicRefCount.DecRefWithDestructor(ctx, i.Destroy)
}
// Destroy implements kernfs.Inode.
-func (i *inode) Destroy() {
+func (i *inode) Destroy(context.Context) {
if i.wouldBlock {
fdnotifier.RemoveFD(int32(i.hostFD))
}
@@ -542,7 +542,7 @@ func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux
}
// Release implements vfs.FileDescriptionImpl.
-func (f *fileDescription) Release() {
+func (f *fileDescription) Release(context.Context) {
// noop
}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index fd16bd92d..4979dd0a9 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -139,7 +139,7 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable
}
// Send implements transport.ConnectedEndpoint.Send.
-func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -216,7 +216,7 @@ func (c *ConnectedEndpoint) EventUpdate() {
}
// Recv implements transport.Receiver.Recv.
-func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -317,8 +317,8 @@ func (c *ConnectedEndpoint) destroyLocked() {
// Release implements transport.ConnectedEndpoint.Release and
// transport.Receiver.Release.
-func (c *ConnectedEndpoint) Release() {
- c.ref.DecRefWithDestructor(func() {
+func (c *ConnectedEndpoint) Release(ctx context.Context) {
+ c.ref.DecRefWithDestructor(ctx, func(context.Context) {
c.mu.Lock()
c.destroyLocked()
c.mu.Unlock()
@@ -347,8 +347,8 @@ func (e *SCMConnectedEndpoint) Init() error {
// Release implements transport.ConnectedEndpoint.Release and
// transport.Receiver.Release.
-func (e *SCMConnectedEndpoint) Release() {
- e.ref.DecRefWithDestructor(func() {
+func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
+ e.ref.DecRefWithDestructor(ctx, func(context.Context) {
e.mu.Lock()
if err := syscall.Close(e.fd); err != nil {
log.Warningf("Failed to close host fd %d: %v", err)
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index 4ee9270cc..d372c60cb 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -67,12 +67,12 @@ func (t *TTYFileDescription) ForegroundProcessGroup() *kernel.ProcessGroup {
}
// Release implements fs.FileOperations.Release.
-func (t *TTYFileDescription) Release() {
+func (t *TTYFileDescription) Release(ctx context.Context) {
t.mu.Lock()
t.fgProcessGroup = nil
t.mu.Unlock()
- t.fileDescription.Release()
+ t.fileDescription.Release(ctx)
}
// PRead implements vfs.FileDescriptionImpl.
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index c6c4472e7..12adf727a 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -122,7 +122,7 @@ func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, of
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *DynamicBytesFD) Release() {}
+func (fd *DynamicBytesFD) Release(context.Context) {}
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 1d37ccb98..fcee6200a 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -113,7 +113,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *GenericDirectoryFD) Release() {}
+func (fd *GenericDirectoryFD) Release(context.Context) {}
func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
return fd.vfsfd.VirtualDentry().Mount().Filesystem()
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 61a36cff9..d7edb6342 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -56,13 +56,13 @@ afterSymlink:
return vfsd, nil
}
if name == ".." {
- if isRoot, err := rp.CheckRoot(vfsd); err != nil {
+ if isRoot, err := rp.CheckRoot(ctx, vfsd); err != nil {
return nil, err
} else if isRoot || d.parent == nil {
rp.Advance()
return vfsd, nil
}
- if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
return nil, err
}
rp.Advance()
@@ -77,7 +77,7 @@ afterSymlink:
if err != nil {
return nil, err
}
- if err := rp.CheckMount(&next.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &next.vfsd); err != nil {
return nil, err
}
// Resolve any symlink at current path component.
@@ -88,7 +88,7 @@ afterSymlink:
}
if targetVD.Ok() {
err := rp.HandleJump(targetVD)
- targetVD.DecRef()
+ targetVD.DecRef(ctx)
if err != nil {
return nil, err
}
@@ -116,7 +116,7 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// Cached dentry exists, revalidate.
if !child.inode.Valid(ctx) {
delete(parent.children, name)
- vfsObj.InvalidateDentry(&child.vfsd)
+ vfsObj.InvalidateDentry(ctx, &child.vfsd)
fs.deferDecRef(&child.vfsd) // Reference from Lookup.
child = nil
}
@@ -234,7 +234,7 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Den
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *Filesystem) Release() {
+func (fs *Filesystem) Release(context.Context) {
}
// Sync implements vfs.FilesystemImpl.Sync.
@@ -246,7 +246,7 @@ func (fs *Filesystem) Sync(ctx context.Context) error {
// AccessAt implements vfs.Filesystem.Impl.AccessAt.
func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
fs.mu.RLock()
- defer fs.processDeferredDecRefs()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
_, inode, err := fs.walkExistingLocked(ctx, rp)
@@ -259,7 +259,7 @@ func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
fs.mu.RLock()
- defer fs.processDeferredDecRefs()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
@@ -282,7 +282,7 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
fs.mu.RLock()
- defer fs.processDeferredDecRefs()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
vfsd, _, err := fs.walkParentDirLocked(ctx, rp)
if err != nil {
@@ -300,7 +300,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
fs.mu.Lock()
defer fs.mu.Unlock()
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -337,7 +337,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
fs.mu.Lock()
defer fs.mu.Unlock()
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -365,7 +365,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
fs.mu.Lock()
defer fs.mu.Unlock()
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -397,7 +397,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// Do not create new file.
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
- defer fs.processDeferredDecRefs()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
@@ -429,7 +429,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
}
afterTrailingSymlink:
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return nil, err
}
@@ -483,7 +483,7 @@ afterTrailingSymlink:
}
if targetVD.Ok() {
err := rp.HandleJump(targetVD)
- targetVD.DecRef()
+ targetVD.DecRef(ctx)
if err != nil {
return nil, err
}
@@ -507,7 +507,7 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
fs.mu.RLock()
d, inode, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
@@ -526,7 +526,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
fs.mu.Lock()
- defer fs.processDeferredDecRefsLocked()
+ defer fs.processDeferredDecRefsLocked(ctx)
defer fs.mu.Unlock()
// Resolve the destination directory first to verify that it's on this
@@ -584,7 +584,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
virtfs := rp.VirtualFilesystem()
// We can't deadlock here due to lock ordering because we're protected from
@@ -615,7 +615,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
dstDir.children = make(map[string]*Dentry)
}
dstDir.children[pc] = src
- virtfs.CommitRenameReplaceDentry(srcVFSD, replaced)
+ virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaced)
return nil
}
@@ -624,7 +624,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
fs.mu.Lock()
defer fs.mu.Unlock()
vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -648,7 +648,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer parentDentry.dirMu.Unlock()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
@@ -656,7 +656,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
virtfs.AbortDeleteDentry(vfsd)
return err
}
- virtfs.CommitDeleteDentry(vfsd)
+ virtfs.CommitDeleteDentry(ctx, vfsd)
return nil
}
@@ -665,7 +665,7 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.mu.RLock()
_, inode, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -680,7 +680,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
fs.mu.RLock()
_, inode, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statx{}, err
}
@@ -692,7 +692,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statfs{}, err
}
@@ -708,7 +708,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
fs.mu.Lock()
defer fs.mu.Unlock()
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -733,7 +733,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
fs.mu.Lock()
defer fs.mu.Unlock()
vfsd, _, err := fs.walkExistingLocked(ctx, rp)
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -753,7 +753,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
parentDentry.dirMu.Lock()
defer parentDentry.dirMu.Unlock()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
@@ -761,7 +761,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
virtfs.AbortDeleteDentry(vfsd)
return err
}
- virtfs.CommitDeleteDentry(vfsd)
+ virtfs.CommitDeleteDentry(ctx, vfsd)
return nil
}
@@ -770,7 +770,7 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
fs.mu.RLock()
_, inode, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
@@ -785,7 +785,7 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
@@ -798,7 +798,7 @@ func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
@@ -811,7 +811,7 @@ func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -824,7 +824,7 @@ func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
- fs.processDeferredDecRefs()
+ fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 579e627f0..c3efcf3ec 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -40,7 +40,7 @@ func (InodeNoopRefCount) IncRef() {
}
// DecRef implements Inode.DecRef.
-func (InodeNoopRefCount) DecRef() {
+func (InodeNoopRefCount) DecRef(context.Context) {
}
// TryIncRef implements Inode.TryIncRef.
@@ -49,7 +49,7 @@ func (InodeNoopRefCount) TryIncRef() bool {
}
// Destroy implements Inode.Destroy.
-func (InodeNoopRefCount) Destroy() {
+func (InodeNoopRefCount) Destroy(context.Context) {
}
// InodeDirectoryNoNewChildren partially implements the Inode interface.
@@ -366,12 +366,12 @@ func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
}
// DecRef implements Inode.DecRef.
-func (o *OrderedChildren) DecRef() {
- o.AtomicRefCount.DecRefWithDestructor(o.Destroy)
+func (o *OrderedChildren) DecRef(ctx context.Context) {
+ o.AtomicRefCount.DecRefWithDestructor(ctx, o.Destroy)
}
// Destroy cleans up resources referenced by this OrderedChildren.
-func (o *OrderedChildren) Destroy() {
+func (o *OrderedChildren) Destroy(context.Context) {
o.mu.Lock()
defer o.mu.Unlock()
o.order.Reset()
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 46f207664..080118841 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -116,17 +116,17 @@ func (fs *Filesystem) deferDecRef(d *vfs.Dentry) {
// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
// droppedDentries list. See comment on Filesystem.mu.
-func (fs *Filesystem) processDeferredDecRefs() {
+func (fs *Filesystem) processDeferredDecRefs(ctx context.Context) {
fs.mu.Lock()
- fs.processDeferredDecRefsLocked()
+ fs.processDeferredDecRefsLocked(ctx)
fs.mu.Unlock()
}
// Precondition: fs.mu must be held for writing.
-func (fs *Filesystem) processDeferredDecRefsLocked() {
+func (fs *Filesystem) processDeferredDecRefsLocked(ctx context.Context) {
fs.droppedDentriesMu.Lock()
for _, d := range fs.droppedDentries {
- d.DecRef()
+ d.DecRef(ctx)
}
fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
fs.droppedDentriesMu.Unlock()
@@ -212,16 +212,16 @@ func (d *Dentry) isSymlink() bool {
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *Dentry) DecRef() {
- d.AtomicRefCount.DecRefWithDestructor(d.destroy)
+func (d *Dentry) DecRef(ctx context.Context) {
+ d.AtomicRefCount.DecRefWithDestructor(ctx, d.destroy)
}
// Precondition: Dentry must be removed from VFS' dentry cache.
-func (d *Dentry) destroy() {
- d.inode.DecRef() // IncRef from Init.
+func (d *Dentry) destroy(ctx context.Context) {
+ d.inode.DecRef(ctx) // IncRef from Init.
d.inode = nil
if d.parent != nil {
- d.parent.DecRef() // IncRef from Dentry.InsertChild.
+ d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild.
}
}
@@ -230,7 +230,7 @@ func (d *Dentry) destroy() {
// Although Linux technically supports inotify on pseudo filesystems (inotify
// is implemented at the vfs layer), it is not particularly useful. It is left
// unimplemented until someone actually needs it.
-func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {}
+func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {}
// Watches implements vfs.DentryImpl.Watches.
func (d *Dentry) Watches() *vfs.Watches {
@@ -238,7 +238,7 @@ func (d *Dentry) Watches() *vfs.Watches {
}
// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
-func (d *Dentry) OnZeroWatches() {}
+func (d *Dentry) OnZeroWatches(context.Context) {}
// InsertChild inserts child into the vfs dentry cache with the given name under
// this dentry. This does not update the directory inode, so calling this on
@@ -326,12 +326,12 @@ type Inode interface {
type inodeRefs interface {
IncRef()
- DecRef()
+ DecRef(ctx context.Context)
TryIncRef() bool
// Destroy is called when the inode reaches zero references. Destroy release
// all resources (references) on objects referenced by the inode, including
// any child dentries.
- Destroy()
+ Destroy(ctx context.Context)
}
type inodeMetadata interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index dc407eb1d..c5d5afedf 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -46,7 +46,7 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
v := &vfs.VirtualFilesystem{}
- if err := v.Init(); err != nil {
+ if err := v.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
@@ -163,7 +163,7 @@ func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*
dir := d.fs.newDir(creds, opts.Mode, nil)
dirVFSD := dir.VFSDentry()
if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil {
- dir.DecRef()
+ dir.DecRef(ctx)
return nil, err
}
d.IncLinks(1)
@@ -175,7 +175,7 @@ func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*
f := d.fs.newFile(creds, "")
fVFSD := f.VFSDentry()
if err := d.OrderedChildren.Insert(name, fVFSD); err != nil {
- f.DecRef()
+ f.DecRef(ctx)
return nil, err
}
return fVFSD, nil
@@ -213,7 +213,7 @@ func TestBasic(t *testing.T) {
})
})
defer sys.Destroy()
- sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef()
+ sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef(sys.Ctx)
}
func TestMkdirGetDentry(t *testing.T) {
@@ -228,7 +228,7 @@ func TestMkdirGetDentry(t *testing.T) {
if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil {
t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err)
}
- sys.GetDentryOrDie(pop).DecRef()
+ sys.GetDentryOrDie(pop).DecRef(sys.Ctx)
}
func TestReadStaticFile(t *testing.T) {
@@ -246,7 +246,7 @@ func TestReadStaticFile(t *testing.T) {
if err != nil {
t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(sys.Ctx)
content, err := sys.ReadToEnd(fd)
if err != nil {
@@ -273,7 +273,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
}
// Close the file. The file should persist.
- fd.DecRef()
+ fd.DecRef(sys.Ctx)
fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
Flags: linux.O_RDONLY,
@@ -281,7 +281,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
if err != nil {
t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
}
- fd.DecRef()
+ fd.DecRef(sys.Ctx)
}
func TestDirFDReadWrite(t *testing.T) {
@@ -297,7 +297,7 @@ func TestDirFDReadWrite(t *testing.T) {
if err != nil {
t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(sys.Ctx)
// Read/Write should fail for directory FDs.
if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR {
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 8f8dcfafe..b3d19ff82 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -98,7 +98,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
if err != nil {
return err
}
- defer oldFD.DecRef()
+ defer oldFD.DecRef(ctx)
newFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &newpop, &vfs.OpenOptions{
Flags: linux.O_WRONLY | linux.O_CREAT | linux.O_EXCL,
Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
@@ -106,7 +106,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
if err != nil {
return err
}
- defer newFD.DecRef()
+ defer newFD.DecRef(ctx)
bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
for {
readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
@@ -241,13 +241,13 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
Mask: linux.STATX_INO,
})
if err != nil {
- d.upperVD.DecRef()
+ d.upperVD.DecRef(ctx)
d.upperVD = vfs.VirtualDentry{}
cleanupUndoCopyUp()
return err
}
if upperStat.Mask&linux.STATX_INO == 0 {
- d.upperVD.DecRef()
+ d.upperVD.DecRef(ctx)
d.upperVD = vfs.VirtualDentry{}
cleanupUndoCopyUp()
return syserror.EREMOTE
diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go
index f5c2462a5..fccb94105 100644
--- a/pkg/sentry/fsimpl/overlay/directory.go
+++ b/pkg/sentry/fsimpl/overlay/directory.go
@@ -46,7 +46,7 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string
readdirErr = err
return false
}
- defer layerFD.DecRef()
+ defer layerFD.DecRef(ctx)
// Reuse slice allocated for maybeWhiteouts from a previous layer to
// reduce allocations.
@@ -108,7 +108,7 @@ type directoryFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *directoryFD) Release() {
+func (fd *directoryFD) Release(ctx context.Context) {
}
// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
@@ -177,7 +177,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
readdirErr = err
return false
}
- defer layerFD.DecRef()
+ defer layerFD.DecRef(ctx)
// Reuse slice allocated for maybeWhiteouts from a previous layer to
// reduce allocations.
@@ -282,6 +282,6 @@ func (fd *directoryFD) Sync(ctx context.Context) error {
return err
}
err = upperFD.Sync(ctx)
- upperFD.DecRef()
+ upperFD.DecRef(ctx)
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 6b705e955..986b36ead 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -77,7 +77,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
-func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) {
+func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
return
@@ -85,20 +85,20 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) {
if len(**ds) != 0 {
fs.renameMu.Lock()
for _, d := range **ds {
- d.checkDropLocked()
+ d.checkDropLocked(ctx)
}
fs.renameMu.Unlock()
}
putDentrySlice(*ds)
}
-func (fs *filesystem) renameMuUnlockAndCheckDrop(ds **[]*dentry) {
+func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
if *ds == nil {
fs.renameMu.Unlock()
return
}
for _, d := range **ds {
- d.checkDropLocked()
+ d.checkDropLocked(ctx)
}
fs.renameMu.Unlock()
putDentrySlice(*ds)
@@ -126,13 +126,13 @@ afterSymlink:
return d, nil
}
if name == ".." {
- if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
return nil, err
} else if isRoot || d.parent == nil {
rp.Advance()
return d, nil
}
- if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
return nil, err
}
rp.Advance()
@@ -142,7 +142,7 @@ afterSymlink:
if err != nil {
return nil, err
}
- if err := rp.CheckMount(&child.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
return nil, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
@@ -272,11 +272,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
})
if lookupErr != nil {
- child.destroyLocked()
+ child.destroyLocked(ctx)
return nil, lookupErr
}
if !existsOnAnyLayer {
- child.destroyLocked()
+ child.destroyLocked(ctx)
return nil, syserror.ENOENT
}
@@ -430,7 +430,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
start := rp.Start().Impl().(*dentry)
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
@@ -501,7 +501,7 @@ func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.V
func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
@@ -513,7 +513,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -532,7 +532,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -553,7 +553,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
start := rp.Start().Impl().(*dentry)
d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
@@ -720,7 +720,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if rp.Done() {
@@ -825,7 +825,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
fd.LockFD.Init(&d.locks)
layerFDOpts := layerFD.Options()
if err := fd.vfsfd.Init(fd, layerFlags, mnt, &d.vfsd, &layerFDOpts); err != nil {
- layerFD.DecRef()
+ layerFD.DecRef(ctx)
return nil, err
}
return &fd.vfsfd, nil
@@ -920,7 +920,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
fd.LockFD.Init(&child.locks)
upperFDOpts := upperFD.Options()
if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil {
- upperFD.DecRef()
+ upperFD.DecRef(ctx)
// Don't bother with cleanup; the file was created successfully, we
// just can't open it anymore for some reason.
return nil, err
@@ -932,7 +932,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
@@ -952,7 +952,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
var ds *[]*dentry
fs.renameMu.Lock()
- defer fs.renameMuUnlockAndCheckDrop(&ds)
+ defer fs.renameMuUnlockAndCheckDrop(ctx, &ds)
newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
if err != nil {
return err
@@ -979,7 +979,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
start := rp.Start().Impl().(*dentry)
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
@@ -1001,7 +1001,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
@@ -1086,7 +1086,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
- vfsObj.CommitDeleteDentry(&child.vfsd)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
delete(parent.children, name)
ds = appendDentry(ds, child)
parent.dirents = nil
@@ -1097,7 +1097,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
@@ -1132,7 +1132,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return linux.Statx{}, err
@@ -1160,7 +1160,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
_, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return linux.Statfs{}, err
@@ -1211,7 +1211,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
start := rp.Start().Impl().(*dentry)
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
@@ -1233,7 +1233,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
@@ -1298,7 +1298,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
if child != nil {
- vfsObj.CommitDeleteDentry(&child.vfsd)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
delete(parent.children, name)
ds = appendDentry(ds, child)
}
@@ -1310,7 +1310,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
_, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -1324,7 +1324,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
_, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
@@ -1336,7 +1336,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
_, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
@@ -1348,7 +1348,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
_, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
index c0749e711..d3060a481 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -81,11 +81,11 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip
oldOff, oldOffErr := fd.cachedFD.Seek(ctx, 0, linux.SEEK_CUR)
if oldOffErr == nil {
if _, err := upperFD.Seek(ctx, oldOff, linux.SEEK_SET); err != nil {
- upperFD.DecRef()
+ upperFD.DecRef(ctx)
return nil, err
}
}
- fd.cachedFD.DecRef()
+ fd.cachedFD.DecRef(ctx)
fd.copiedUp = true
fd.cachedFD = upperFD
fd.cachedFlags = statusFlags
@@ -99,8 +99,8 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *nonDirectoryFD) Release() {
- fd.cachedFD.DecRef()
+func (fd *nonDirectoryFD) Release(ctx context.Context) {
+ fd.cachedFD.DecRef(ctx)
fd.cachedFD = nil
}
@@ -138,7 +138,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux
Mask: layerMask,
Sync: opts.Sync,
})
- wrappedFD.DecRef()
+ wrappedFD.DecRef(ctx)
if err != nil {
return linux.Statx{}, err
}
@@ -187,7 +187,7 @@ func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, off
if err != nil {
return 0, err
}
- defer wrappedFD.DecRef()
+ defer wrappedFD.DecRef(ctx)
return wrappedFD.PRead(ctx, dst, offset, opts)
}
@@ -209,7 +209,7 @@ func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, of
if err != nil {
return 0, err
}
- defer wrappedFD.DecRef()
+ defer wrappedFD.DecRef(ctx)
return wrappedFD.PWrite(ctx, src, offset, opts)
}
@@ -250,7 +250,7 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
return err
}
wrappedFD.IncRef()
- defer wrappedFD.DecRef()
+ defer wrappedFD.DecRef(ctx)
fd.mu.Unlock()
return wrappedFD.Sync(ctx)
}
@@ -261,6 +261,6 @@ func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOp
if err != nil {
return err
}
- defer wrappedFD.DecRef()
+ defer wrappedFD.DecRef(ctx)
return wrappedFD.ConfigureMMap(ctx, opts)
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index e720d4825..75cc006bf 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -123,7 +123,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// filesystem with any number of lower layers.
} else {
vfsroot := vfs.RootFromContext(ctx)
- defer vfsroot.DecRef()
+ defer vfsroot.DecRef(ctx)
upperPathname, ok := mopts["upperdir"]
if ok {
delete(mopts, "upperdir")
@@ -147,13 +147,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
return nil, nil, err
}
- defer upperRoot.DecRef()
+ defer upperRoot.DecRef(ctx)
privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
if err != nil {
ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
return nil, nil, err
}
- defer privateUpperRoot.DecRef()
+ defer privateUpperRoot.DecRef(ctx)
fsopts.UpperRoot = privateUpperRoot
}
lowerPathnamesStr, ok := mopts["lowerdir"]
@@ -190,13 +190,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
- defer lowerRoot.DecRef()
+ defer lowerRoot.DecRef(ctx)
privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
if err != nil {
ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
- defer privateLowerRoot.DecRef()
+ defer privateLowerRoot.DecRef(ctx)
fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
}
}
@@ -264,19 +264,19 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Mask: rootStatMask,
})
if err != nil {
- root.destroyLocked()
- fs.vfsfs.DecRef()
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
if rootStat.Mask&rootStatMask != rootStatMask {
- root.destroyLocked()
- fs.vfsfs.DecRef()
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EREMOTE
}
if isWhiteout(&rootStat) {
ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
- root.destroyLocked()
- fs.vfsfs.DecRef()
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EINVAL
}
root.mode = uint32(rootStat.Mode)
@@ -319,17 +319,17 @@ func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forc
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
vfsObj := fs.vfsfs.VirtualFilesystem()
vfsObj.PutAnonBlockDevMinor(fs.dirDevMinor)
for _, lowerDevMinor := range fs.lowerDevMinors {
vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
}
if fs.opts.UpperRoot.Ok() {
- fs.opts.UpperRoot.DecRef()
+ fs.opts.UpperRoot.DecRef(ctx)
}
for _, lowerRoot := range fs.opts.LowerRoots {
- lowerRoot.DecRef()
+ lowerRoot.DecRef(ctx)
}
}
@@ -452,10 +452,10 @@ func (d *dentry) TryIncRef() bool {
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef() {
+func (d *dentry) DecRef(ctx context.Context) {
if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
d.fs.renameMu.Lock()
- d.checkDropLocked()
+ d.checkDropLocked(ctx)
d.fs.renameMu.Unlock()
} else if refs < 0 {
panic("overlay.dentry.DecRef() called without holding a reference")
@@ -466,7 +466,7 @@ func (d *dentry) DecRef() {
// becomes deleted.
//
// Preconditions: d.fs.renameMu must be locked for writing.
-func (d *dentry) checkDropLocked() {
+func (d *dentry) checkDropLocked(ctx context.Context) {
// Dentries with a positive reference count must be retained. (The only way
// to obtain a reference on a dentry with zero references is via path
// resolution, which requires renameMu, so if d.refs is zero then it will
@@ -476,14 +476,14 @@ func (d *dentry) checkDropLocked() {
return
}
// Refs is still zero; destroy it.
- d.destroyLocked()
+ d.destroyLocked(ctx)
return
}
// destroyLocked destroys the dentry.
//
// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
-func (d *dentry) destroyLocked() {
+func (d *dentry) destroyLocked(ctx context.Context) {
switch atomic.LoadInt64(&d.refs) {
case 0:
// Mark the dentry destroyed.
@@ -495,10 +495,10 @@ func (d *dentry) destroyLocked() {
}
if d.upperVD.Ok() {
- d.upperVD.DecRef()
+ d.upperVD.DecRef(ctx)
}
for _, lowerVD := range d.lowerVDs {
- lowerVD.DecRef()
+ lowerVD.DecRef(ctx)
}
if d.parent != nil {
@@ -510,7 +510,7 @@ func (d *dentry) destroyLocked() {
// Drop the reference held by d on its parent without recursively
// locking d.fs.renameMu.
if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkDropLocked()
+ d.parent.checkDropLocked(ctx)
} else if refs < 0 {
panic("overlay.dentry.DecRef() called without holding a reference")
}
@@ -518,7 +518,7 @@ func (d *dentry) destroyLocked() {
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
-func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {
+func (d *dentry) InotifyWithParent(ctx context.Context, events uint32, cookie uint32, et vfs.EventType) {
// TODO(gvisor.dev/issue/1479): Implement inotify.
}
@@ -531,7 +531,7 @@ func (d *dentry) Watches() *vfs.Watches {
// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
//
// TODO(gvisor.dev/issue/1479): Implement inotify.
-func (d *dentry) OnZeroWatches() {}
+func (d *dentry) OnZeroWatches(context.Context) {}
// iterLayers invokes yield on each layer comprising d, from top to bottom. If
// any call to yield returns false, iterLayer stops iteration.
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 811f80a5f..2ca793db9 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -63,9 +63,9 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
@@ -160,6 +160,6 @@ func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vf
inode := newInode(ctx, fs)
var d kernfs.Dentry
d.Init(inode)
- defer d.DecRef()
+ defer d.DecRef(ctx)
return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
}
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 609210253..2463d51cd 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -77,9 +77,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
// dynamicInode is an overfitted interface for common Inodes with
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index fea29e5f0..f0d3f7f5e 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -43,12 +43,12 @@ func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags)
return file, flags
}
-func taskFDExists(t *kernel.Task, fd int32) bool {
+func taskFDExists(ctx context.Context, t *kernel.Task, fd int32) bool {
file, _ := getTaskFD(t, fd)
if file == nil {
return false
}
- file.DecRef()
+ file.DecRef(ctx)
return true
}
@@ -68,7 +68,7 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.GetFDs()
+ fds = fdTable.GetFDs(ctx)
}
})
@@ -135,7 +135,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
return nil, syserror.ENOENT
}
fd := int32(fdInt)
- if !taskFDExists(i.task, fd) {
+ if !taskFDExists(ctx, i.task, fd) {
return nil, syserror.ENOENT
}
taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno())
@@ -204,9 +204,9 @@ func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
if file == nil {
return "", syserror.ENOENT
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
root := vfs.RootFromContext(ctx)
- defer root.DecRef()
+ defer root.DecRef(ctx)
return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
}
@@ -215,7 +215,7 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen
if file == nil {
return vfs.VirtualDentry{}, "", syserror.ENOENT
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
vd := file.VirtualDentry()
vd.IncRef()
return vd, "", nil
@@ -258,7 +258,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry,
return nil, syserror.ENOENT
}
fd := int32(fdInt)
- if !taskFDExists(i.task, fd) {
+ if !taskFDExists(ctx, i.task, fd) {
return nil, syserror.ENOENT
}
data := &fdInfoData{
@@ -297,7 +297,7 @@ func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
if file == nil {
return syserror.ENOENT
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
// TODO(b/121266871): Include pos, locks, and other data. For now we only
// have flags.
// See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 859b7d727..830b78949 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -677,7 +677,7 @@ func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
- defer exec.DecRef()
+ defer exec.DecRef(ctx)
return exec.PathnameWithDeleted(ctx), nil
}
@@ -692,7 +692,7 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent
if err != nil {
return vfs.VirtualDentry{}, "", err
}
- defer exec.DecRef()
+ defer exec.DecRef(ctx)
vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
vd.IncRef()
@@ -748,7 +748,7 @@ func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// Root has been destroyed. Don't try to read mounts.
return nil
}
- defer rootDir.DecRef()
+ defer rootDir.DecRef(ctx)
i.task.Kernel().VFS().GenerateProcMountInfo(ctx, rootDir, buf)
return nil
}
@@ -779,7 +779,7 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// Root has been destroyed. Don't try to read mounts.
return nil
}
- defer rootDir.DecRef()
+ defer rootDir.DecRef(ctx)
i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf)
return nil
}
@@ -825,7 +825,7 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
dentry.Init(&namespaceInode{})
vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
vd.IncRef()
- dentry.DecRef()
+ dentry.DecRef(ctx)
return vd, "", nil
}
@@ -887,8 +887,8 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
}
// Release implements FileDescriptionImpl.
-func (fd *namespaceFD) Release() {
- fd.inode.DecRef()
+func (fd *namespaceFD) Release(ctx context.Context) {
+ fd.inode.DecRef(ctx)
}
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 6bde27376..a4c884bf9 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -212,7 +212,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
continue
}
if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
- s.DecRef()
+ s.DecRef(ctx)
// Not a unix socket.
continue
}
@@ -281,7 +281,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
fmt.Fprintf(buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
return nil
}
@@ -359,7 +359,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
- s.DecRef()
+ s.DecRef(ctx)
// Not tcp4 sockets.
continue
}
@@ -455,7 +455,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
fmt.Fprintf(buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
return nil
@@ -524,7 +524,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
- s.DecRef()
+ s.DecRef(ctx)
// Not udp4 socket.
continue
}
@@ -600,7 +600,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
return nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 19abb5034..3c9297dee 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -218,7 +218,7 @@ func TestTasks(t *testing.T) {
if err != nil {
t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(s.Ctx)
buf := make([]byte, 1)
bufIOSeq := usermem.BytesIOSequence(buf)
if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
@@ -336,7 +336,7 @@ func TestTasksOffset(t *testing.T) {
if err != nil {
t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
}
- defer fd.DecRef()
+ defer fd.DecRef(s.Ctx)
if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil {
t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
}
@@ -441,7 +441,7 @@ func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.F
t.Errorf("vfsfs.OpenAt(%v) failed: %v", absPath, err)
continue
}
- defer child.DecRef()
+ defer child.DecRef(ctx)
stat, err := child.Stat(ctx, vfs.StatOptions{})
if err != nil {
t.Errorf("Stat(%v) failed: %v", absPath, err)
@@ -476,7 +476,7 @@ func TestTree(t *testing.T) {
if err != nil {
t.Fatalf("failed to create test file: %v", err)
}
- defer file.DecRef()
+ defer file.DecRef(s.Ctx)
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
@@ -501,5 +501,5 @@ func TestTree(t *testing.T) {
t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err)
}
iterateDir(ctx, t, s, fd)
- fd.DecRef()
+ fd.DecRef(ctx)
}
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index 242ba9b5d..6297e1df4 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -54,7 +54,7 @@ var _ vfs.FileDescriptionImpl = (*SignalFileDescription)(nil)
// New creates a new signal fd.
func New(vfsObj *vfs.VirtualFilesystem, target *kernel.Task, mask linux.SignalSet, flags uint32) (*vfs.FileDescription, error) {
vd := vfsObj.NewAnonVirtualDentry("[signalfd]")
- defer vd.DecRef()
+ defer vd.DecRef(target)
sfd := &SignalFileDescription{
target: target,
mask: mask,
@@ -133,4 +133,4 @@ func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) {
}
// Release implements FileDescriptionImpl.Release()
-func (sfd *SignalFileDescription) Release() {}
+func (sfd *SignalFileDescription) Release(context.Context) {}
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index ee0828a15..c61818ff6 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -67,9 +67,9 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 01ce30a4d..f81b0c38f 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -87,9 +87,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
- fs.Filesystem.Release()
+ fs.Filesystem.Release(ctx)
}
// dir implements kernfs.Inode.
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 242d5fd12..9fd38b295 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -59,7 +59,7 @@ func TestReadCPUFile(t *testing.T) {
if err != nil {
t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(s.Ctx)
content, err := s.ReadToEnd(fd)
if err != nil {
t.Fatalf("Read failed: %v", err)
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index e743e8114..1e57744e8 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -127,7 +127,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
return nil, err
}
m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
- m.SetExecutable(fsbridge.NewVFSFile(exe))
+ m.SetExecutable(ctx, fsbridge.NewVFSFile(exe))
config := &kernel.TaskConfig{
Kernel: k,
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index 0556af877..568132121 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -97,8 +97,8 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System {
// Destroy release resources associated with a test system.
func (s *System) Destroy() {
- s.Root.DecRef()
- s.MntNs.DecRef() // Reference on MntNs passed to NewSystem.
+ s.Root.DecRef(s.Ctx)
+ s.MntNs.DecRef(s.Ctx) // Reference on MntNs passed to NewSystem.
}
// ReadToEnd reads the contents of fd until EOF to a string.
@@ -149,7 +149,7 @@ func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector {
if err != nil {
s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(s.Ctx)
collector := &DirentCollector{}
if err := fd.IterDirents(s.Ctx, collector); err != nil {
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index 2dc90d484..86beaa0a8 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -47,9 +47,9 @@ var _ vfs.FileDescriptionImpl = (*TimerFileDescription)(nil)
var _ ktime.TimerListener = (*TimerFileDescription)(nil)
// New returns a new timer fd.
-func New(vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) {
+func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) {
vd := vfsObj.NewAnonVirtualDentry("[timerfd]")
- defer vd.DecRef()
+ defer vd.DecRef(ctx)
tfd := &TimerFileDescription{}
tfd.timer = ktime.NewTimer(clock, tfd)
if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
@@ -129,7 +129,7 @@ func (tfd *TimerFileDescription) ResumeTimer() {
}
// Release implements FileDescriptionImpl.Release()
-func (tfd *TimerFileDescription) Release() {
+func (tfd *TimerFileDescription) Release(context.Context) {
tfd.timer.Destroy()
}
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 2fb5c4d84..d263147c2 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -83,7 +83,7 @@ func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent
}
err = fn(root, d)
- d.DecRef()
+ d.DecRef(ctx)
return err
}
@@ -105,17 +105,17 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create mount namespace: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create nested directories with given depth.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
d := root
d.IncRef()
- defer d.DecRef()
+ defer d.DecRef(ctx)
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil {
@@ -125,7 +125,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- d.DecRef()
+ d.DecRef(ctx)
d = next
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -136,7 +136,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- file.DecRef()
+ file.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -176,7 +176,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
b.Fatalf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -186,14 +186,14 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create nested directories with given depth.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
vd := root
vd.IncRef()
for i := depth; i > 0; i-- {
@@ -212,7 +212,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- vd.DecRef()
+ vd.DecRef(ctx)
vd = nextVD
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -228,12 +228,12 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
Mode: 0644,
})
- vd.DecRef()
+ vd.DecRef(ctx)
vd = vfs.VirtualDentry{}
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -278,14 +278,14 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create mount namespace: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create and mount the submount.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil {
b.Fatalf("failed to create mount point: %v", err)
}
@@ -293,7 +293,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount point: %v", err)
}
- defer mountPoint.DecRef()
+ defer mountPoint.DecRef(ctx)
submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil)
if err != nil {
b.Fatalf("failed to create tmpfs submount: %v", err)
@@ -309,7 +309,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount root: %v", err)
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil {
@@ -319,7 +319,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- d.DecRef()
+ d.DecRef(ctx)
d = next
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -330,7 +330,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- file.DecRef()
+ file.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -370,7 +370,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
b.Fatalf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -380,14 +380,14 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create the mount point.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
pop := vfs.PathOperation{
Root: root,
Start: root,
@@ -403,7 +403,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount point: %v", err)
}
- defer mountPoint.DecRef()
+ defer mountPoint.DecRef(ctx)
// Create and mount the submount.
if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
@@ -432,7 +432,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- vd.DecRef()
+ vd.DecRef(ctx)
vd = nextVD
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -448,11 +448,11 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
Mode: 0644,
})
- vd.DecRef()
+ vd.DecRef(ctx)
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- fd.DecRef()
+ fd.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 0a1ad4765..78b4fc5be 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -95,7 +95,7 @@ type directoryFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *directoryFD) Release() {
+func (fd *directoryFD) Release(ctx context.Context) {
if fd.iter != nil {
dir := fd.inode().impl.(*directory)
dir.iterMu.Lock()
@@ -110,7 +110,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fs := fd.filesystem()
dir := fd.inode().impl.(*directory)
- defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ defer fd.dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent)
// fs.mu is required to read d.parent and dentry.name.
fs.mu.RLock()
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index ef210a69b..fb77f95cc 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -40,7 +40,7 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// stepLocked is loosely analogous to fs/namei.c:walk_component().
//
// Preconditions: filesystem.mu must be locked. !rp.Done().
-func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
+func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
dir, ok := d.inode.impl.(*directory)
if !ok {
return nil, syserror.ENOTDIR
@@ -55,13 +55,13 @@ afterSymlink:
return d, nil
}
if name == ".." {
- if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
return nil, err
} else if isRoot || d.parent == nil {
rp.Advance()
return d, nil
}
- if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
return nil, err
}
rp.Advance()
@@ -74,7 +74,7 @@ afterSymlink:
if !ok {
return nil, syserror.ENOENT
}
- if err := rp.CheckMount(&child.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
return nil, err
}
if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
@@ -98,9 +98,9 @@ afterSymlink:
// fs/namei.c:path_parentat().
//
// Preconditions: filesystem.mu must be locked. !rp.Done().
-func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) {
+func walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*directory, error) {
for !rp.Final() {
- next, err := stepLocked(rp, d)
+ next, err := stepLocked(ctx, rp, d)
if err != nil {
return nil, err
}
@@ -118,10 +118,10 @@ func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) {
// resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
//
// Preconditions: filesystem.mu must be locked.
-func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) {
+func resolveLocked(ctx context.Context, rp *vfs.ResolvingPath) (*dentry, error) {
d := rp.Start().Impl().(*dentry)
for !rp.Done() {
- next, err := stepLocked(rp, d)
+ next, err := stepLocked(ctx, rp, d)
if err != nil {
return nil, err
}
@@ -141,10 +141,10 @@ func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) {
//
// Preconditions: !rp.Done(). For the final path component in rp,
// !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error {
fs.mu.Lock()
defer fs.mu.Unlock()
- parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return err
}
@@ -182,7 +182,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if dir {
ev |= linux.IN_ISDIR
}
- parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ parentDir.inode.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
parentDir.inode.touchCMtime()
return nil
}
@@ -191,7 +191,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return err
}
@@ -202,7 +202,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return nil, err
}
@@ -222,7 +222,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- dir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ dir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return nil, err
}
@@ -232,7 +232,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error {
if rp.Mount() != vd.Mount() {
return syserror.EXDEV
}
@@ -251,7 +251,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EMLINK
}
i.incLinksLocked()
- i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */)
+ i.watches.Notify(ctx, "", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */)
parentDir.insertChildLocked(fs.newDentry(i), name)
return nil
})
@@ -259,7 +259,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
- return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parentDir *directory, name string) error {
creds := rp.Credentials()
if parentDir.inode.nlink == maxLinks {
return syserror.EMLINK
@@ -273,7 +273,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error {
creds := rp.Credentials()
var childInode *inode
switch opts.Mode.FileType() {
@@ -308,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return nil, err
}
@@ -330,7 +330,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
return start.open(ctx, rp, &opts, false /* afterCreate */)
}
afterTrailingSymlink:
- parentDir, err := walkParentDirLocked(rp, start)
+ parentDir, err := walkParentDirLocked(ctx, rp, start)
if err != nil {
return nil, err
}
@@ -368,7 +368,7 @@ afterTrailingSymlink:
if err != nil {
return nil, err
}
- parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
+ parentDir.inode.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
parentDir.inode.touchCMtime()
return fd, nil
}
@@ -376,7 +376,7 @@ afterTrailingSymlink:
return nil, syserror.EEXIST
}
// Is the file mounted over?
- if err := rp.CheckMount(&child.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
return nil, err
}
// Do we need to resolve a trailing symlink?
@@ -445,7 +445,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return "", err
}
@@ -467,7 +467,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Resolve newParent first to verify that it's on this Mount.
fs.mu.Lock()
defer fs.mu.Unlock()
- newParentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return err
}
@@ -555,7 +555,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var replacedVFSD *vfs.Dentry
if replaced != nil {
replacedVFSD = &replaced.vfsd
@@ -566,17 +566,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if replaced != nil {
newParentDir.removeChildLocked(replaced)
if replaced.inode.isDir() {
- newParentDir.inode.decLinksLocked() // from replaced's ".."
+ newParentDir.inode.decLinksLocked(ctx) // from replaced's ".."
}
- replaced.inode.decLinksLocked()
+ replaced.inode.decLinksLocked(ctx)
}
oldParentDir.removeChildLocked(renamed)
newParentDir.insertChildLocked(renamed, newName)
- vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD)
+ vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD)
oldParentDir.inode.touchCMtime()
if oldParentDir != newParentDir {
if renamed.inode.isDir() {
- oldParentDir.inode.decLinksLocked()
+ oldParentDir.inode.decLinksLocked(ctx)
newParentDir.inode.incLinksLocked()
}
newParentDir.inode.touchCMtime()
@@ -591,7 +591,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
defer fs.mu.Unlock()
- parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return err
}
@@ -626,17 +626,17 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer mnt.EndWrite()
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
return err
}
parentDir.removeChildLocked(child)
- parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ parentDir.inode.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
// Remove links for child, child/., and child/..
- child.inode.decLinksLocked()
- child.inode.decLinksLocked()
- parentDir.inode.decLinksLocked()
- vfsObj.CommitDeleteDentry(&child.vfsd)
+ child.inode.decLinksLocked(ctx)
+ child.inode.decLinksLocked(ctx)
+ parentDir.inode.decLinksLocked(ctx)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
parentDir.inode.touchCMtime()
return nil
}
@@ -644,7 +644,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
return err
@@ -656,7 +656,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.mu.RUnlock()
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
- d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
}
return nil
}
@@ -665,7 +665,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return linux.Statx{}, err
}
@@ -678,7 +678,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- if _, err := resolveLocked(rp); err != nil {
+ if _, err := resolveLocked(ctx, rp); err != nil {
return linux.Statfs{}, err
}
statfs := linux.Statfs{
@@ -695,7 +695,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error {
creds := rp.Credentials()
child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target))
parentDir.insertChildLocked(child, name)
@@ -707,7 +707,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
defer fs.mu.Unlock()
- parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return err
}
@@ -738,7 +738,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer mnt.EndWrite()
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
return err
}
@@ -746,11 +746,11 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
// Generate inotify events. Note that this must take place before the link
// count of the child is decremented, or else the watches may be dropped
// before these events are added.
- vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name)
+ vfs.InotifyRemoveChild(ctx, &child.inode.watches, &parentDir.inode.watches, name)
parentDir.removeChildLocked(child)
- child.inode.decLinksLocked()
- vfsObj.CommitDeleteDentry(&child.vfsd)
+ child.inode.decLinksLocked(ctx)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
parentDir.inode.touchCMtime()
return nil
}
@@ -759,7 +759,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return nil, err
}
@@ -778,7 +778,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return nil, err
}
@@ -789,7 +789,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
return "", err
}
@@ -799,7 +799,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
fs.mu.RLock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
return err
@@ -810,14 +810,14 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
}
fs.mu.RUnlock()
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
- d, err := resolveLocked(rp)
+ d, err := resolveLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
return err
@@ -828,7 +828,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
}
fs.mu.RUnlock()
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 1614f2c39..ec2701d8b 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -32,7 +32,7 @@ const fileName = "mypipe"
func TestSeparateFDs(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the read side. This is done in a concurrently because opening
// One end the pipe blocks until the other end is opened.
@@ -55,13 +55,13 @@ func TestSeparateFDs(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
}
- defer wfd.DecRef()
+ defer wfd.DecRef(ctx)
rfd, ok := <-rfdchan
if !ok {
t.Fatalf("failed to open pipe for reading %q", fileName)
}
- defer rfd.DecRef()
+ defer rfd.DecRef(ctx)
const msg = "vamos azul"
checkEmpty(ctx, t, rfd)
@@ -71,7 +71,7 @@ func TestSeparateFDs(t *testing.T) {
func TestNonblockingRead(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the read side as nonblocking.
pop := vfs.PathOperation{
@@ -85,7 +85,7 @@ func TestNonblockingRead(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
}
- defer rfd.DecRef()
+ defer rfd.DecRef(ctx)
// Open the write side.
openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
@@ -93,7 +93,7 @@ func TestNonblockingRead(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
}
- defer wfd.DecRef()
+ defer wfd.DecRef(ctx)
const msg = "geh blau"
checkEmpty(ctx, t, rfd)
@@ -103,7 +103,7 @@ func TestNonblockingRead(t *testing.T) {
func TestNonblockingWriteError(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the write side as nonblocking, which should return ENXIO.
pop := vfs.PathOperation{
@@ -121,7 +121,7 @@ func TestNonblockingWriteError(t *testing.T) {
func TestSingleFD(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the pipe as readable and writable.
pop := vfs.PathOperation{
@@ -135,7 +135,7 @@ func TestSingleFD(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(ctx)
const msg = "forza blu"
checkEmpty(ctx, t, fd)
@@ -152,7 +152,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create VFS.
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index abbaa5d60..0710b65db 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -270,7 +270,7 @@ type regularFileFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {
+func (fd *regularFileFD) Release(context.Context) {
// noop
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 2545d88e9..68e615e8b 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -185,7 +185,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
case linux.S_IFDIR:
root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry
default:
- fs.vfsfs.DecRef()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
}
return &fs.vfsfs, &root.vfsd, nil
@@ -197,7 +197,7 @@ func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *au
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
+func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
}
@@ -249,12 +249,12 @@ func (d *dentry) TryIncRef() bool {
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef() {
- d.inode.decRef()
+func (d *dentry) DecRef(ctx context.Context) {
+ d.inode.decRef(ctx)
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
-func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
if d.inode.isDir() {
events |= linux.IN_ISDIR
}
@@ -266,9 +266,9 @@ func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
d.inode.fs.mu.RLock()
// The ordering below is important, Linux always notifies the parent first.
if d.parent != nil {
- d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted)
+ d.parent.inode.watches.Notify(ctx, d.name, events, cookie, et, deleted)
}
- d.inode.watches.Notify("", events, cookie, et, deleted)
+ d.inode.watches.Notify(ctx, "", events, cookie, et, deleted)
d.inode.fs.mu.RUnlock()
}
@@ -278,7 +278,7 @@ func (d *dentry) Watches() *vfs.Watches {
}
// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
-func (d *dentry) OnZeroWatches() {}
+func (d *dentry) OnZeroWatches(context.Context) {}
// inode represents a filesystem object.
type inode struct {
@@ -359,12 +359,12 @@ func (i *inode) incLinksLocked() {
// remove a reference on i as well.
//
// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
-func (i *inode) decLinksLocked() {
+func (i *inode) decLinksLocked(ctx context.Context) {
if i.nlink == 0 {
panic("tmpfs.inode.decLinksLocked() called with no existing links")
}
if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 {
- i.decRef()
+ i.decRef(ctx)
}
}
@@ -386,9 +386,9 @@ func (i *inode) tryIncRef() bool {
}
}
-func (i *inode) decRef() {
+func (i *inode) decRef(ctx context.Context) {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- i.watches.HandleDeletion()
+ i.watches.HandleDeletion(ctx)
if regFile, ok := i.impl.(*regularFile); ok {
// Release memory used by regFile to store data. Since regFile is
// no longer usable, we don't need to grab any locks or update any
@@ -701,7 +701,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
- d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
}
return nil
}
@@ -724,7 +724,7 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption
}
// Generate inotify events.
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
@@ -736,13 +736,13 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
}
// Generate inotify events.
- d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
// NewMemfd creates a new tmpfs regular file and file description that can back
// an anonymous fd created by memfd_create.
-func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name string) (*vfs.FileDescription, error) {
+func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) {
fs, ok := mount.Filesystem().Impl().(*filesystem)
if !ok {
panic("NewMemfd() called with non-tmpfs mount")
@@ -757,7 +757,7 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s
}
d := fs.newDentry(inode)
- defer d.DecRef()
+ defer d.DecRef(ctx)
d.name = name
// Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
index a240fb276..6f3e3ae6f 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -34,7 +34,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr
creds := auth.CredentialsFromContext(ctx)
vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
}
@@ -47,8 +47,8 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr
}
root := mntns.Root()
return vfsObj, root, func() {
- root.DecRef()
- mntns.DecRef()
+ root.DecRef(ctx)
+ mntns.DecRef(ctx)
}, nil
}
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 920fe4329..52ed5cea2 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -17,6 +17,7 @@ package kernel
import (
"syscall"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
@@ -31,7 +32,7 @@ type abstractEndpoint struct {
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (e *abstractEndpoint) WeakRefGone() {
+func (e *abstractEndpoint) WeakRefGone(context.Context) {
e.ns.mu.Lock()
if e.ns.endpoints[e.name].ep == e.ep {
delete(e.ns.endpoints, e.name)
@@ -64,9 +65,9 @@ type boundEndpoint struct {
}
// Release implements transport.BoundEndpoint.Release.
-func (e *boundEndpoint) Release() {
- e.rc.DecRef()
- e.BoundEndpoint.Release()
+func (e *boundEndpoint) Release(ctx context.Context) {
+ e.rc.DecRef(ctx)
+ e.BoundEndpoint.Release(ctx)
}
// BoundEndpoint retrieves the endpoint bound to the given name. The return
@@ -93,13 +94,13 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp
//
// When the last reference managed by rc is dropped, ep may be removed from the
// namespace.
-func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error {
+func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, rc refs.RefCounter) error {
a.mu.Lock()
defer a.mu.Unlock()
if ep, ok := a.endpoints[name]; ok {
if rc := ep.wr.Get(); rc != nil {
- rc.DecRef()
+ rc.DecRef(ctx)
return syscall.EADDRINUSE
}
}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 4c0f1e41f..15519f0df 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -76,8 +76,8 @@ type pollEntry struct {
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
// weakReferenceGone is called when the file in the weak reference is destroyed.
// The poll entry is removed in response to this.
-func (p *pollEntry) WeakRefGone() {
- p.epoll.RemoveEntry(p.id)
+func (p *pollEntry) WeakRefGone(ctx context.Context) {
+ p.epoll.RemoveEntry(ctx, p.id)
}
// EventPoll holds all the state associated with an event poll object, that is,
@@ -144,14 +144,14 @@ func NewEventPoll(ctx context.Context) *fs.File {
// name matches fs/eventpoll.c:epoll_create1.
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]"))
// Release the initial dirent reference after NewFile takes a reference.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{
files: make(map[FileIdentifier]*pollEntry),
})
}
// Release implements fs.FileOperations.Release.
-func (e *EventPoll) Release() {
+func (e *EventPoll) Release(ctx context.Context) {
// We need to take the lock now because files may be attempting to
// remove entries in parallel if they get destroyed.
e.mu.Lock()
@@ -160,7 +160,7 @@ func (e *EventPoll) Release() {
// Go through all entries and clean up.
for _, entry := range e.files {
entry.id.File.EventUnregister(&entry.waiter)
- entry.file.Drop()
+ entry.file.Drop(ctx)
}
e.files = nil
}
@@ -423,7 +423,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter
}
// RemoveEntry a files from the collection of observed files.
-func (e *EventPoll) RemoveEntry(id FileIdentifier) error {
+func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -445,7 +445,7 @@ func (e *EventPoll) RemoveEntry(id FileIdentifier) error {
// Remove file from map, and drop weak reference.
delete(e.files, id)
- entry.file.Drop()
+ entry.file.Drop(ctx)
return nil
}
diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go
index 22630e9c5..55b505593 100644
--- a/pkg/sentry/kernel/epoll/epoll_test.go
+++ b/pkg/sentry/kernel/epoll/epoll_test.go
@@ -26,7 +26,8 @@ func TestFileDestroyed(t *testing.T) {
f := filetest.NewTestFile(t)
id := FileIdentifier{f, 12}
- efile := NewEventPoll(contexttest.Context(t))
+ ctx := contexttest.Context(t)
+ efile := NewEventPoll(ctx)
e := efile.FileOperations.(*EventPoll)
if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil {
t.Fatalf("addEntry failed: %v", err)
@@ -44,7 +45,7 @@ func TestFileDestroyed(t *testing.T) {
}
// Destroy the file. Check that we get no more events.
- f.DecRef()
+ f.DecRef(ctx)
evt = e.ReadEvents(1)
if len(evt) != 0 {
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 87951adeb..bbf568dfc 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -70,7 +70,7 @@ func New(ctx context.Context, initVal uint64, semMode bool) *fs.File {
// name matches fs/eventfd.c:eventfd_file_create.
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[eventfd]")
// Release the initial dirent reference after NewFile takes a reference.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{
val: initVal,
semMode: semMode,
@@ -106,7 +106,7 @@ func (e *EventOperations) HostFD() (int, error) {
}
// Release implements fs.FileOperations.Release.
-func (e *EventOperations) Release() {
+func (e *EventOperations) Release(context.Context) {
e.mu.Lock()
defer e.mu.Unlock()
if e.hostfd >= 0 {
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 4b7d234a4..ce53af69b 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -98,7 +98,7 @@ type FDTable struct {
func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
m := make(map[int32]descriptor)
- f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
m[fd] = descriptor{
file: file,
fileVFS2: fileVFS2,
@@ -109,6 +109,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
}
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
+ ctx := context.Background()
f.init() // Initialize table.
for fd, d := range m {
f.setAll(fd, d.file, d.fileVFS2, d.flags)
@@ -118,9 +119,9 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
// reference taken by set above.
switch {
case d.file != nil:
- d.file.DecRef()
+ d.file.DecRef(ctx)
case d.fileVFS2 != nil:
- d.fileVFS2.DecRef()
+ d.fileVFS2.DecRef(ctx)
}
}
}
@@ -144,14 +145,15 @@ func (f *FDTable) drop(file *fs.File) {
d.InotifyEvent(ev, 0)
// Drop the table reference.
- file.DecRef()
+ file.DecRef(context.Background())
}
// dropVFS2 drops the table reference.
func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
// Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the
// entire file.
- err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET)
+ ctx := context.Background()
+ err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET)
if err != nil && err != syserror.ENOLCK {
panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
}
@@ -161,10 +163,10 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
if file.IsWritable() {
ev = linux.IN_CLOSE_WRITE
}
- file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent)
// Drop the table's reference.
- file.DecRef()
+ file.DecRef(ctx)
}
// NewFDTable allocates a new FDTable that may be used by tasks in k.
@@ -175,15 +177,15 @@ func (k *Kernel) NewFDTable() *FDTable {
}
// destroy removes all of the file descriptors from the map.
-func (f *FDTable) destroy() {
- f.RemoveIf(func(*fs.File, *vfs.FileDescription, FDFlags) bool {
+func (f *FDTable) destroy(ctx context.Context) {
+ f.RemoveIf(ctx, func(*fs.File, *vfs.FileDescription, FDFlags) bool {
return true
})
}
// DecRef implements RefCounter.DecRef with destructor f.destroy.
-func (f *FDTable) DecRef() {
- f.DecRefWithDestructor(f.destroy)
+func (f *FDTable) DecRef(ctx context.Context) {
+ f.DecRefWithDestructor(ctx, f.destroy)
}
// Size returns the number of file descriptor slots currently allocated.
@@ -195,7 +197,7 @@ func (f *FDTable) Size() int {
// forEach iterates over all non-nil files in sorted order.
//
// It is the caller's responsibility to acquire an appropriate lock.
-func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
+func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
// retries tracks the number of failed TryIncRef attempts for the same FD.
retries := 0
fd := int32(0)
@@ -214,7 +216,7 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes
continue // Race caught.
}
fn(fd, file, nil, flags)
- file.DecRef()
+ file.DecRef(ctx)
case fileVFS2 != nil:
if !fileVFS2.TryIncRef() {
retries++
@@ -224,7 +226,7 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes
continue // Race caught.
}
fn(fd, nil, fileVFS2, flags)
- fileVFS2.DecRef()
+ fileVFS2.DecRef(ctx)
}
retries = 0
fd++
@@ -234,7 +236,8 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes
// String is a stringer for FDTable.
func (f *FDTable) String() string {
var buf strings.Builder
- f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ ctx := context.Background()
+ f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
switch {
case file != nil:
n, _ := file.Dirent.FullName(nil /* root */)
@@ -242,7 +245,7 @@ func (f *FDTable) String() string {
case fileVFS2 != nil:
vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem()
- name, err := vfsObj.PathnameWithDeleted(context.Background(), vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
+ name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
if err != nil {
fmt.Fprintf(&buf, "\n", err)
return
@@ -541,9 +544,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
//
// Precondition: The caller must be running on the task goroutine, or Task.mu
// must be locked.
-func (f *FDTable) GetFDs() []int32 {
+func (f *FDTable) GetFDs(ctx context.Context) []int32 {
fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
- f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
fds = append(fds, fd)
})
return fds
@@ -552,9 +555,9 @@ func (f *FDTable) GetFDs() []int32 {
// GetRefs returns a stable slice of references to all files and bumps the
// reference count on each. The caller must use DecRef on each reference when
// they're done using the slice.
-func (f *FDTable) GetRefs() []*fs.File {
+func (f *FDTable) GetRefs(ctx context.Context) []*fs.File {
files := make([]*fs.File, 0, f.Size())
- f.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
file.IncRef() // Acquire a reference for caller.
files = append(files, file)
})
@@ -564,9 +567,9 @@ func (f *FDTable) GetRefs() []*fs.File {
// GetRefsVFS2 returns a stable slice of references to all files and bumps the
// reference count on each. The caller must use DecRef on each reference when
// they're done using the slice.
-func (f *FDTable) GetRefsVFS2() []*vfs.FileDescription {
+func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription {
files := make([]*vfs.FileDescription, 0, f.Size())
- f.forEach(func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) {
+ f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) {
file.IncRef() // Acquire a reference for caller.
files = append(files, file)
})
@@ -574,10 +577,10 @@ func (f *FDTable) GetRefsVFS2() []*vfs.FileDescription {
}
// Fork returns an independent FDTable.
-func (f *FDTable) Fork() *FDTable {
+func (f *FDTable) Fork(ctx context.Context) *FDTable {
clone := f.k.NewFDTable()
- f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
// The set function here will acquire an appropriate table
// reference for the clone. We don't need anything else.
switch {
@@ -622,11 +625,11 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
}
// RemoveIf removes all FDs where cond is true.
-func (f *FDTable) RemoveIf(cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) {
+func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) {
f.mu.Lock()
defer f.mu.Unlock()
- f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
if cond(file, fileVFS2, flags) {
f.set(fd, nil, FDFlags{}) // Clear from table.
// Update current available position.
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 29f95a2c4..e3f30ba2a 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -154,7 +154,7 @@ func TestFDTable(t *testing.T) {
if ref == nil {
t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success")
}
- ref.DecRef()
+ ref.DecRef(ctx)
if ref, _ := fdTable.Remove(1); ref != nil {
t.Fatalf("r.Remove(1) for a removed FD: got success, want failure")
@@ -191,7 +191,7 @@ func BenchmarkFDLookupAndDecRef(b *testing.B) {
b.StartTimer() // Benchmark.
for i := 0; i < b.N; i++ {
tf, _ := fdTable.Get(fds[i%len(fds)])
- tf.DecRef()
+ tf.DecRef(ctx)
}
})
}
@@ -219,7 +219,7 @@ func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) {
defer wg.Done()
for i := 0; i < each; i++ {
tf, _ := fdTable.Get(fds[i%len(fds)])
- tf.DecRef()
+ tf.DecRef(ctx)
}
}()
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 47f78df9a..8f2d36d5a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -17,6 +17,7 @@ package kernel
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -89,28 +90,28 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext {
// Note that there may still be calls to WorkingDirectory() or RootDirectory()
// (that return nil). This is because valid references may still be held via
// proc files or other mechanisms.
-func (f *FSContext) destroy() {
+func (f *FSContext) destroy(ctx context.Context) {
// Hold f.mu so that we don't race with RootDirectory() and
// WorkingDirectory().
f.mu.Lock()
defer f.mu.Unlock()
if VFS2Enabled {
- f.rootVFS2.DecRef()
+ f.rootVFS2.DecRef(ctx)
f.rootVFS2 = vfs.VirtualDentry{}
- f.cwdVFS2.DecRef()
+ f.cwdVFS2.DecRef(ctx)
f.cwdVFS2 = vfs.VirtualDentry{}
} else {
- f.root.DecRef()
+ f.root.DecRef(ctx)
f.root = nil
- f.cwd.DecRef()
+ f.cwd.DecRef(ctx)
f.cwd = nil
}
}
// DecRef implements RefCounter.DecRef with destructor f.destroy.
-func (f *FSContext) DecRef() {
- f.DecRefWithDestructor(f.destroy)
+func (f *FSContext) DecRef(ctx context.Context) {
+ f.DecRefWithDestructor(ctx, f.destroy)
}
// Fork forks this FSContext.
@@ -165,7 +166,7 @@ func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
// This will take an extra reference on the Dirent.
//
// This is not a valid call after destroy.
-func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
+func (f *FSContext) SetWorkingDirectory(ctx context.Context, d *fs.Dirent) {
if d == nil {
panic("FSContext.SetWorkingDirectory called with nil dirent")
}
@@ -180,21 +181,21 @@ func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
old := f.cwd
f.cwd = d
d.IncRef()
- old.DecRef()
+ old.DecRef(ctx)
}
// SetWorkingDirectoryVFS2 sets the current working directory.
// This will take an extra reference on the VirtualDentry.
//
// This is not a valid call after destroy.
-func (f *FSContext) SetWorkingDirectoryVFS2(d vfs.VirtualDentry) {
+func (f *FSContext) SetWorkingDirectoryVFS2(ctx context.Context, d vfs.VirtualDentry) {
f.mu.Lock()
defer f.mu.Unlock()
old := f.cwdVFS2
f.cwdVFS2 = d
d.IncRef()
- old.DecRef()
+ old.DecRef(ctx)
}
// RootDirectory returns the current filesystem root.
@@ -226,7 +227,7 @@ func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
// This will take an extra reference on the Dirent.
//
// This is not a valid call after free.
-func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
+func (f *FSContext) SetRootDirectory(ctx context.Context, d *fs.Dirent) {
if d == nil {
panic("FSContext.SetRootDirectory called with nil dirent")
}
@@ -241,13 +242,13 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old := f.root
f.root = d
d.IncRef()
- old.DecRef()
+ old.DecRef(ctx)
}
// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
//
// This is not a valid call after free.
-func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+func (f *FSContext) SetRootDirectoryVFS2(ctx context.Context, vd vfs.VirtualDentry) {
if !vd.Ok() {
panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
}
@@ -263,7 +264,7 @@ func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
vd.IncRef()
f.rootVFS2 = vd
f.mu.Unlock()
- old.DecRef()
+ old.DecRef(ctx)
}
// Umask returns the current umask.
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index c5021f2db..daa2dae76 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -51,6 +51,7 @@ go_test(
srcs = ["futex_test.go"],
library = ":futex",
deps = [
+ "//pkg/context",
"//pkg/sync",
"//pkg/usermem",
],
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index bcc1b29a8..e4dcc4d40 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -19,6 +19,7 @@ package futex
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -66,9 +67,9 @@ type Key struct {
Offset uint64
}
-func (k *Key) release() {
+func (k *Key) release(t Target) {
if k.MappingIdentity != nil {
- k.MappingIdentity.DecRef()
+ k.MappingIdentity.DecRef(t)
}
k.Mappable = nil
k.MappingIdentity = nil
@@ -94,6 +95,8 @@ func (k *Key) matches(k2 *Key) bool {
// Target abstracts memory accesses and keys.
type Target interface {
+ context.Context
+
// SwapUint32 gives access to usermem.IO.SwapUint32.
SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
@@ -296,7 +299,7 @@ func (b *bucket) wakeWaiterLocked(w *Waiter) {
// bucket "to".
//
// Preconditions: b and to must be locked.
-func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int {
+func (b *bucket) requeueLocked(t Target, to *bucket, key, nkey *Key, n int) int {
done := 0
for w := b.waiters.Front(); done < n && w != nil; {
if !w.key.matches(key) {
@@ -308,7 +311,7 @@ func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int {
requeued := w
w = w.Next() // Next iteration.
b.waiters.Remove(requeued)
- requeued.key.release()
+ requeued.key.release(t)
requeued.key = nkey.clone()
to.waiters.PushBack(requeued)
requeued.bucket.Store(to)
@@ -456,7 +459,7 @@ func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32
r := b.wakeLocked(&k, bitmask, n)
b.mu.Unlock()
- k.release()
+ k.release(t)
return r, nil
}
@@ -465,12 +468,12 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch
if err != nil {
return 0, err
}
- defer k1.release()
+ defer k1.release(t)
k2, err := getKey(t, naddr, private)
if err != nil {
return 0, err
}
- defer k2.release()
+ defer k2.release(t)
b1, b2 := m.lockBuckets(&k1, &k2)
defer b1.mu.Unlock()
@@ -488,7 +491,7 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch
done := b1.wakeLocked(&k1, ^uint32(0), nwake)
// Requeue the number required.
- b1.requeueLocked(b2, &k1, &k2, nreq)
+ b1.requeueLocked(t, b2, &k1, &k2, nreq)
return done, nil
}
@@ -515,12 +518,12 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwak
if err != nil {
return 0, err
}
- defer k1.release()
+ defer k1.release(t)
k2, err := getKey(t, addr2, private)
if err != nil {
return 0, err
}
- defer k2.release()
+ defer k2.release(t)
b1, b2 := m.lockBuckets(&k1, &k2)
defer b1.mu.Unlock()
@@ -571,7 +574,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo
// Perform our atomic check.
if err := check(t, addr, val); err != nil {
b.mu.Unlock()
- w.key.release()
+ w.key.release(t)
return err
}
@@ -585,7 +588,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo
// WaitComplete must be called when a Waiter previously added by WaitPrepare is
// no longer eligible to be woken.
-func (m *Manager) WaitComplete(w *Waiter) {
+func (m *Manager) WaitComplete(w *Waiter, t Target) {
// Remove w from the bucket it's in.
for {
b := w.bucket.Load()
@@ -617,7 +620,7 @@ func (m *Manager) WaitComplete(w *Waiter) {
}
// Release references held by the waiter.
- w.key.release()
+ w.key.release(t)
}
// LockPI attempts to lock the futex following the Priority-inheritance futex
@@ -648,13 +651,13 @@ func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, pri
success, err := m.lockPILocked(w, t, addr, tid, b, try)
if err != nil {
- w.key.release()
+ w.key.release(t)
b.mu.Unlock()
return false, err
}
if success || try {
// Release waiter if it's not going to be a wait.
- w.key.release()
+ w.key.release(t)
}
b.mu.Unlock()
return success, nil
@@ -730,7 +733,7 @@ func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool
err = m.unlockPILocked(t, addr, tid, b, &k)
- k.release()
+ k.release(t)
b.mu.Unlock()
return err
}
diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go
index 7c5c7665b..d0128c548 100644
--- a/pkg/sentry/kernel/futex/futex_test.go
+++ b/pkg/sentry/kernel/futex/futex_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"unsafe"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -29,28 +30,33 @@ import (
// testData implements the Target interface, and allows us to
// treat the address passed for futex operations as an index in
// a byte slice for testing simplicity.
-type testData []byte
+type testData struct {
+ context.Context
+ data []byte
+}
const sizeofInt32 = 4
func newTestData(size uint) testData {
- return make([]byte, size)
+ return testData{
+ data: make([]byte, size),
+ }
}
func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
- val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new)
+ val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new)
return val, nil
}
func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
- if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) {
+ if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) {
return old, nil
}
- return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil
}
func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) {
- return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil
}
func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) {
@@ -83,7 +89,7 @@ func TestFutexWake(t *testing.T) {
// Start waiting for wakeup.
w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w)
+ defer m.WaitComplete(w, d)
// Perform a wakeup.
if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 {
@@ -106,7 +112,7 @@ func TestFutexWakeBitmask(t *testing.T) {
// Start waiting for wakeup.
w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff)
- defer m.WaitComplete(w)
+ defer m.WaitComplete(w, d)
// Perform a wakeup using the wrong bitmask.
if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 {
@@ -141,7 +147,7 @@ func TestFutexWakeTwo(t *testing.T) {
var ws [3]*Waiter
for i := range ws {
ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(ws[i])
+ defer m.WaitComplete(ws[i], d)
}
// Perform two wakeups.
@@ -174,9 +180,9 @@ func TestFutexWakeUnrelated(t *testing.T) {
// Start two waiters waiting for wakeup on different addresses.
w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform two wakeups on the second address.
if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 {
@@ -216,9 +222,9 @@ func TestWakeOpFirstNonEmpty(t *testing.T) {
// Add two waiters on address 0.
w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform 10 wakeups on address 0.
if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 {
@@ -244,9 +250,9 @@ func TestWakeOpSecondNonEmpty(t *testing.T) {
// Add two waiters on address sizeofInt32.
w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform 10 wakeups on address sizeofInt32 (contingent on
// d.Op(0), which should succeed).
@@ -273,9 +279,9 @@ func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) {
// Add two waiters on address sizeofInt32.
w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform 10 wakeups on address sizeofInt32 (contingent on
// d.Op(1), which should fail).
@@ -302,15 +308,15 @@ func TestWakeOpAllNonEmpty(t *testing.T) {
// Add two waiters on address 0.
w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Add two waiters on address sizeofInt32.
w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w3)
+ defer m.WaitComplete(w3, d)
w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w4)
+ defer m.WaitComplete(w4, d)
// Perform 10 wakeups on address 0 (unconditionally), and 10
// wakeups on address sizeofInt32 (contingent on d.Op(0), which
@@ -344,15 +350,15 @@ func TestWakeOpAllNonEmptyFailingOp(t *testing.T) {
// Add two waiters on address 0.
w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Add two waiters on address sizeofInt32.
w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w3)
+ defer m.WaitComplete(w3, d)
w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w4)
+ defer m.WaitComplete(w4, d)
// Perform 10 wakeups on address 0 (unconditionally), and 10
// wakeups on address sizeofInt32 (contingent on d.Op(1), which
@@ -388,7 +394,7 @@ func TestWakeOpSameAddress(t *testing.T) {
var ws [4]*Waiter
for i := range ws {
ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(ws[i])
+ defer m.WaitComplete(ws[i], d)
}
// Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup
@@ -422,7 +428,7 @@ func TestWakeOpSameAddressFailingOp(t *testing.T) {
var ws [4]*Waiter
for i := range ws {
ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(ws[i])
+ defer m.WaitComplete(ws[i], d)
}
// Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup
@@ -472,7 +478,7 @@ func (t *testMutex) Lock() {
for {
// Attempt to grab the lock.
if atomic.CompareAndSwapUint32(
- (*uint32)(unsafe.Pointer(&t.d[t.a])),
+ (*uint32)(unsafe.Pointer(&t.d.data[t.a])),
testMutexUnlocked,
testMutexLocked) {
// Lock held.
@@ -490,7 +496,7 @@ func (t *testMutex) Lock() {
panic("WaitPrepare returned unexpected error: " + err.Error())
}
<-w.C
- t.m.WaitComplete(w)
+ t.m.WaitComplete(w, t.d)
}
}
@@ -498,7 +504,7 @@ func (t *testMutex) Lock() {
// This will notify any waiters via the futex manager.
func (t *testMutex) Unlock() {
// Unlock.
- atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked)
+ atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked)
// Notify all waiters.
t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 15dae0f5b..316df249d 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -376,7 +376,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.netlinkPorts = port.New()
if VFS2Enabled {
- if err := k.vfs.Init(); err != nil {
+ ctx := k.SupervisorContext()
+ if err := k.vfs.Init(ctx); err != nil {
return fmt.Errorf("failed to initialize VFS: %v", err)
}
@@ -384,19 +385,19 @@ func (k *Kernel) Init(args InitKernelArgs) error {
if err != nil {
return fmt.Errorf("failed to create pipefs filesystem: %v", err)
}
- defer pipeFilesystem.DecRef()
+ defer pipeFilesystem.DecRef(ctx)
pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{})
if err != nil {
return fmt.Errorf("failed to create pipefs mount: %v", err)
}
k.pipeMount = pipeMount
- tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(k.SupervisorContext(), &k.vfs, auth.NewRootCredentials(k.rootUserNamespace))
+ tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(ctx, &k.vfs, auth.NewRootCredentials(k.rootUserNamespace))
if err != nil {
return fmt.Errorf("failed to create tmpfs filesystem: %v", err)
}
- defer tmpfsFilesystem.DecRef()
- defer tmpfsRoot.DecRef()
+ defer tmpfsFilesystem.DecRef(ctx)
+ defer tmpfsRoot.DecRef(ctx)
shmMount, err := k.vfs.NewDisconnectedMount(tmpfsFilesystem, tmpfsRoot, &vfs.MountOptions{})
if err != nil {
return fmt.Errorf("failed to create tmpfs mount: %v", err)
@@ -407,7 +408,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
if err != nil {
return fmt.Errorf("failed to create sockfs filesystem: %v", err)
}
- defer socketFilesystem.DecRef()
+ defer socketFilesystem.DecRef(ctx)
socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{})
if err != nil {
return fmt.Errorf("failed to create sockfs mount: %v", err)
@@ -430,8 +431,8 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
defer k.extMu.Unlock()
// Stop time.
- k.pauseTimeLocked()
- defer k.resumeTimeLocked()
+ k.pauseTimeLocked(ctx)
+ defer k.resumeTimeLocked(ctx)
// Evict all evictable MemoryFile allocations.
k.mf.StartEvictions()
@@ -447,12 +448,12 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// Remove all epoll waiter objects from underlying wait queues.
// NOTE: for programs to resume execution in future snapshot scenarios,
// we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters()
+ k.tasks.unregisterEpollWaiters(ctx)
// Clear the dirent cache before saving because Dirents must be Loaded in a
// particular order (parents before children), and Loading dirents from a cache
// breaks that order.
- if err := k.flushMountSourceRefs(); err != nil {
+ if err := k.flushMountSourceRefs(ctx); err != nil {
return err
}
@@ -505,7 +506,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
-func (k *Kernel) flushMountSourceRefs() error {
+func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
// Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
@@ -521,7 +522,7 @@ func (k *Kernel) flushMountSourceRefs() error {
// There may be some open FDs whose filesystems have been unmounted. We
// must flush those as well.
- return k.tasks.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error {
+ return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
file.Dirent.Inode.MountSource.FlushDirentRefs()
return nil
})
@@ -531,7 +532,7 @@ func (k *Kernel) flushMountSourceRefs() error {
// each task.
//
// Precondition: Must be called with the kernel paused.
-func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) {
+func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.FileDescription) error) (err error) {
// TODO(gvisor.dev/issue/1663): Add save support for VFS2.
if VFS2Enabled {
return nil
@@ -544,7 +545,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error)
if t.fdTable == nil {
continue
}
- t.fdTable.forEach(func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) {
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) {
if lastErr := f(file, fileVFS2); lastErr != nil && err == nil {
err = lastErr
}
@@ -555,7 +556,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error)
func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
// TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error {
+ return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
}
@@ -602,7 +603,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
return nil
}
-func (ts *TaskSet) unregisterEpollWaiters() {
+func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
// TODO(gvisor.dev/issue/1663): Add save support for VFS2.
if VFS2Enabled {
return
@@ -623,7 +624,7 @@ func (ts *TaskSet) unregisterEpollWaiters() {
if _, ok := processed[t.fdTable]; ok {
continue
}
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
e.UnregisterEpollWaiters()
}
@@ -900,7 +901,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
root := args.MountNamespaceVFS2.Root()
// The call to newFSContext below will take a reference on root, so we
// don't need to hold this one.
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Grab the working directory.
wd := root // Default.
@@ -918,7 +919,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
if err != nil {
return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
}
- defer wd.DecRef()
+ defer wd.DecRef(ctx)
}
opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd)
fsContext = NewFSContextVFS2(root, wd, args.Umask)
@@ -933,7 +934,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
root := mntns.Root()
// The call to newFSContext below will take a reference on root, so we
// don't need to hold this one.
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Grab the working directory.
remainingTraversals := args.MaxSymlinkTraversals
@@ -944,7 +945,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
if err != nil {
return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
}
- defer wd.DecRef()
+ defer wd.DecRef(ctx)
}
opener = fsbridge.NewFSLookup(mntns, root, wd)
fsContext = newFSContext(root, wd, args.Umask)
@@ -1054,7 +1055,7 @@ func (k *Kernel) Start() error {
// If k was created by LoadKernelFrom, timers were stopped during
// Kernel.SaveTo and need to be resumed. If k was created by NewKernel,
// this is a no-op.
- k.resumeTimeLocked()
+ k.resumeTimeLocked(k.SupervisorContext())
// Start task goroutines.
k.tasks.mu.RLock()
defer k.tasks.mu.RUnlock()
@@ -1068,7 +1069,7 @@ func (k *Kernel) Start() error {
//
// Preconditions: Any task goroutines running in k must be stopped. k.extMu
// must be locked.
-func (k *Kernel) pauseTimeLocked() {
+func (k *Kernel) pauseTimeLocked(ctx context.Context) {
// k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before
// Kernel.Start().
if k.cpuClockTicker != nil {
@@ -1090,7 +1091,7 @@ func (k *Kernel) pauseTimeLocked() {
// This means we'll iterate FDTables shared by multiple tasks repeatedly,
// but ktime.Timer.Pause is idempotent so this is harmless.
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
if VFS2Enabled {
if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok {
tfd.PauseTimer()
@@ -1112,7 +1113,7 @@ func (k *Kernel) pauseTimeLocked() {
//
// Preconditions: Any task goroutines running in k must be stopped. k.extMu
// must be locked.
-func (k *Kernel) resumeTimeLocked() {
+func (k *Kernel) resumeTimeLocked(ctx context.Context) {
if k.cpuClockTicker != nil {
k.cpuClockTicker.Resume()
}
@@ -1126,7 +1127,7 @@ func (k *Kernel) resumeTimeLocked() {
}
}
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
if VFS2Enabled {
if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok {
tfd.ResumeTimer()
@@ -1511,7 +1512,7 @@ type SocketEntry struct {
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *SocketEntry) WeakRefGone() {
+func (s *SocketEntry) WeakRefGone(context.Context) {
s.k.extMu.Lock()
s.k.sockets.Remove(s)
s.k.extMu.Unlock()
@@ -1600,7 +1601,7 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return vfs.VirtualDentry{}
}
mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
// Root() takes a reference on the root dirent for us.
return mntns.Root()
case vfs.CtxMountNamespace:
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 4b688c627..6497dc4ba 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -93,7 +93,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
if !waitFor(&i.mu, &i.wWakeup, ctx) {
- r.DecRef()
+ r.DecRef(ctx)
return nil, syserror.ErrInterrupted
}
}
@@ -111,12 +111,12 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
// On a nonblocking, write-only open, the open fails with ENXIO if the
// read side isn't open yet.
if flags.NonBlocking {
- w.DecRef()
+ w.DecRef(ctx)
return nil, syserror.ENXIO
}
if !waitFor(&i.mu, &i.rWakeup, ctx) {
- w.DecRef()
+ w.DecRef(ctx)
return nil, syserror.ErrInterrupted
}
}
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index ab75a87ff..ce0db5583 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -167,7 +167,7 @@ func TestClosedReaderBlocksWriteOpen(t *testing.T) {
f := NewInodeOperations(ctx, perms, newNamedPipe(t))
rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil)
- rFile.DecRef()
+ rFile.DecRef(ctx)
wDone := make(chan struct{})
// This open for write should block because the reader is now gone.
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 79645d7d2..297e8f28f 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -152,7 +152,7 @@ func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.
d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino))
// The p.Open calls below will each take a reference on the Dirent. We
// must drop the one we already have.
- defer d.DecRef()
+ defer d.DecRef(ctx)
return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true})
}
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
index bda739dbe..fe97e9800 100644
--- a/pkg/sentry/kernel/pipe/pipe_test.go
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -27,8 +27,8 @@ import (
func TestPipeRW(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, 65536, 4096)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
msg := []byte("here's some bytes")
wantN := int64(len(msg))
@@ -47,8 +47,8 @@ func TestPipeRW(t *testing.T) {
func TestPipeReadBlock(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, 65536, 4096)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1)))
if n != 0 || err != syserror.ErrWouldBlock {
@@ -62,8 +62,8 @@ func TestPipeWriteBlock(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
msg := make([]byte, capacity+1)
n, err := w.Writev(ctx, usermem.BytesIOSequence(msg))
@@ -77,8 +77,8 @@ func TestPipeWriteUntilEnd(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
msg := []byte("here's some bytes")
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index aacf28da2..6d58b682f 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -33,7 +33,7 @@ import (
// the old fs architecture.
// Release cleans up the pipe's state.
-func (p *Pipe) Release() {
+func (p *Pipe) Release(context.Context) {
p.rClose()
p.wClose()
diff --git a/pkg/sentry/kernel/pipe/reader.go b/pkg/sentry/kernel/pipe/reader.go
index 7724b4452..ac18785c0 100644
--- a/pkg/sentry/kernel/pipe/reader.go
+++ b/pkg/sentry/kernel/pipe/reader.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -29,7 +30,7 @@ type Reader struct {
// Release implements fs.FileOperations.Release.
//
// This overrides ReaderWriter.Release.
-func (r *Reader) Release() {
+func (r *Reader) Release(context.Context) {
r.Pipe.rClose()
// Wake up writers.
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 45d4c5fc1..28f998e45 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -101,7 +101,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
// If this pipe is being opened as blocking and there's no
// writer, we have to wait for a writer to open the other end.
if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
- fd.DecRef()
+ fd.DecRef(ctx)
return nil, syserror.EINTR
}
@@ -112,12 +112,12 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
// Non-blocking, write-only opens fail with ENXIO when the read
// side isn't open yet.
if statusFlags&linux.O_NONBLOCK != 0 {
- fd.DecRef()
+ fd.DecRef(ctx)
return nil, syserror.ENXIO
}
// Wait for a reader to open the other end.
if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
- fd.DecRef()
+ fd.DecRef(ctx)
return nil, syserror.EINTR
}
}
@@ -169,7 +169,7 @@ type VFSPipeFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *VFSPipeFD) Release() {
+func (fd *VFSPipeFD) Release(context.Context) {
var event waiter.EventMask
if fd.vfsfd.IsReadable() {
fd.pipe.rClose()
diff --git a/pkg/sentry/kernel/pipe/writer.go b/pkg/sentry/kernel/pipe/writer.go
index 5bc6aa931..ef4b70ca3 100644
--- a/pkg/sentry/kernel/pipe/writer.go
+++ b/pkg/sentry/kernel/pipe/writer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -29,7 +30,7 @@ type Writer struct {
// Release implements fs.FileOperations.Release.
//
// This overrides ReaderWriter.Release.
-func (w *Writer) Release() {
+func (w *Writer) Release(context.Context) {
w.Pipe.wClose()
// Wake up readers.
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 0e19286de..5c4c622c2 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
@@ -70,7 +71,7 @@ func (s *Session) incRef() {
//
// Precondition: callers must hold TaskSet.mu for writing.
func (s *Session) decRef() {
- s.refs.DecRefWithDestructor(func() {
+ s.refs.DecRefWithDestructor(nil, func(context.Context) {
// Remove translations from the leader.
for ns := s.leader.pidns; ns != nil; ns = ns.parent {
id := ns.sids[s]
@@ -162,7 +163,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) {
}
alive := true
- pg.refs.DecRefWithDestructor(func() {
+ pg.refs.DecRefWithDestructor(nil, func(context.Context) {
alive = false // don't bother with handleOrphan.
// Remove translations from the originator.
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 55b4c2cdb..13ec7afe0 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -431,8 +431,8 @@ func (s *Shm) InodeID() uint64 {
// DecRef overrides refs.RefCount.DecRef with a destructor.
//
// Precondition: Caller must not hold s.mu.
-func (s *Shm) DecRef() {
- s.DecRefWithDestructor(s.destroy)
+func (s *Shm) DecRef(ctx context.Context) {
+ s.DecRefWithDestructor(ctx, s.destroy)
}
// Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm
@@ -642,7 +642,7 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
return nil
}
-func (s *Shm) destroy() {
+func (s *Shm) destroy(context.Context) {
s.mfp.MemoryFile().DecRef(s.fr)
s.registry.remove(s)
}
@@ -651,7 +651,7 @@ func (s *Shm) destroy() {
// destroyed once it has no references. MarkDestroyed may be called multiple
// times, and is safe to call after a segment has already been destroyed. See
// shmctl(IPC_RMID).
-func (s *Shm) MarkDestroyed() {
+func (s *Shm) MarkDestroyed(ctx context.Context) {
s.registry.dissociateKey(s)
s.mu.Lock()
@@ -663,7 +663,7 @@ func (s *Shm) MarkDestroyed() {
//
// N.B. This cannot be the final DecRef, as the caller also
// holds a reference.
- s.DecRef()
+ s.DecRef(ctx)
return
}
}
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 8243bb93e..b07e1c1bd 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -76,7 +76,7 @@ func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
}
// Release implements fs.FileOperations.Release.
-func (s *SignalOperations) Release() {}
+func (s *SignalOperations) Release(context.Context) {}
// Mask returns the signal mask.
func (s *SignalOperations) Mask() linux.SignalSet {
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index c4db05bd8..5aee699e7 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -730,17 +730,17 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
func (t *Task) IsChrooted() bool {
if VFS2Enabled {
realRoot := t.mountNamespaceVFS2.Root()
- defer realRoot.DecRef()
+ defer realRoot.DecRef(t)
root := t.fsContext.RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
return root != realRoot
}
realRoot := t.tg.mounts.Root()
- defer realRoot.DecRef()
+ defer realRoot.DecRef(t)
root := t.fsContext.RootDirectory()
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(t)
}
return root != realRoot
}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index e1ecca99e..fe6ba6041 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -237,7 +237,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
var fdTable *FDTable
if opts.NewFiles {
- fdTable = t.fdTable.Fork()
+ fdTable = t.fdTable.Fork(t)
} else {
fdTable = t.fdTable
fdTable.IncRef()
@@ -294,7 +294,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
- tg.release()
+ tg.release(t)
}
return 0, nil, err
}
@@ -510,7 +510,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
var oldFDTable *FDTable
if opts.NewFiles {
oldFDTable = t.fdTable
- t.fdTable = oldFDTable.Fork()
+ t.fdTable = oldFDTable.Fork(t)
}
var oldFSContext *FSContext
if opts.NewFSContext {
@@ -519,10 +519,10 @@ func (t *Task) Unshare(opts *SharingOptions) error {
}
t.mu.Unlock()
if oldFDTable != nil {
- oldFDTable.DecRef()
+ oldFDTable.DecRef(t)
}
if oldFSContext != nil {
- oldFSContext.DecRef()
+ oldFSContext.DecRef(t)
}
return nil
}
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 7803b98d0..47c28b8ff 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -199,11 +199,11 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tg.pidns.owner.mu.Unlock()
oldFDTable := t.fdTable
- t.fdTable = t.fdTable.Fork()
- oldFDTable.DecRef()
+ t.fdTable = t.fdTable.Fork(t)
+ oldFDTable.DecRef(t)
// Remove FDs with the CloseOnExec flag set.
- t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
+ t.fdTable.RemoveIf(t, func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
return flags.CloseOnExec
})
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 231ac548a..c165d6cb1 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -269,12 +269,12 @@ func (*runExitMain) execute(t *Task) taskRunState {
// Releasing the MM unblocks a blocked CLONE_VFORK parent.
t.unstopVforkParent()
- t.fsContext.DecRef()
- t.fdTable.DecRef()
+ t.fsContext.DecRef(t)
+ t.fdTable.DecRef(t)
t.mu.Lock()
if t.mountNamespaceVFS2 != nil {
- t.mountNamespaceVFS2.DecRef()
+ t.mountNamespaceVFS2.DecRef(t)
t.mountNamespaceVFS2 = nil
}
t.mu.Unlock()
@@ -282,7 +282,7 @@ func (*runExitMain) execute(t *Task) taskRunState {
// If this is the last task to exit from the thread group, release the
// thread group's resources.
if lastExiter {
- t.tg.release()
+ t.tg.release(t)
}
// Detach tracees.
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index eeccaa197..ab86ceedc 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -203,6 +203,6 @@ func (t *Task) traceExecEvent(tc *TaskContext) {
trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
return
}
- defer file.DecRef()
+ defer file.DecRef(t)
trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 8485fb4b6..64c1e120a 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -102,10 +102,10 @@ func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) {
t, err := ts.newTask(cfg)
if err != nil {
cfg.TaskContext.release()
- cfg.FSContext.DecRef()
- cfg.FDTable.DecRef()
+ cfg.FSContext.DecRef(t)
+ cfg.FDTable.DecRef(t)
if cfg.MountNamespaceVFS2 != nil {
- cfg.MountNamespaceVFS2.DecRef()
+ cfg.MountNamespaceVFS2.DecRef(t)
}
return nil, err
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 4dfd2c990..0b34c0099 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -308,7 +308,7 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet {
}
// release releases the thread group's resources.
-func (tg *ThreadGroup) release() {
+func (tg *ThreadGroup) release(t *Task) {
// Timers must be destroyed without holding the TaskSet or signal mutexes
// since timers send signals with Timer.mu locked.
tg.itimerRealTimer.Destroy()
@@ -325,7 +325,7 @@ func (tg *ThreadGroup) release() {
it.DestroyTimer()
}
if tg.mounts != nil {
- tg.mounts.DecRef()
+ tg.mounts.DecRef(t)
}
}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index ddeaff3db..20dd1cc21 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -281,7 +281,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
}
defer func() {
if mopts.MappingIdentity != nil {
- mopts.MappingIdentity.DecRef()
+ mopts.MappingIdentity.DecRef(ctx)
}
}()
if err := f.ConfigureMMap(ctx, &mopts); err != nil {
@@ -663,7 +663,7 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error
ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
return loadedELF{}, nil, err
}
- defer intFile.DecRef()
+ defer intFile.DecRef(ctx)
interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin)
if err != nil {
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 986c7fb4d..8d6802ea3 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -154,7 +154,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
return loadedELF{}, nil, nil, nil, err
}
// Ensure file is release in case the code loops or errors out.
- defer args.File.DecRef()
+ defer args.File.DecRef(ctx)
} else {
if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil {
return loadedELF{}, nil, nil, nil, err
@@ -223,7 +223,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
// Load the VDSO.
vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
@@ -292,7 +292,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetEnvvStart(sl.EnvvStart)
m.SetEnvvEnd(sl.EnvvEnd)
m.SetAuxv(auxv)
- m.SetExecutable(file)
+ m.SetExecutable(ctx, file)
symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
if err != nil {
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index c188f6c29..59c92c7e8 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -238,7 +238,7 @@ type MappingIdentity interface {
IncRef()
// DecRef decrements the MappingIdentity's reference count.
- DecRef()
+ DecRef(ctx context.Context)
// MappedName returns the application-visible name shown in
// /proc/[pid]/maps.
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 1999ec706..16fea53c4 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -258,8 +258,8 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) {
}
// DecRef implements refs.RefCounter.DecRef.
-func (m *aioMappable) DecRef() {
- m.AtomicRefCount.DecRefWithDestructor(func() {
+func (m *aioMappable) DecRef(ctx context.Context) {
+ m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) {
m.mfp.MemoryFile().DecRef(m.fr)
})
}
@@ -367,7 +367,7 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
if err != nil {
return 0, err
}
- defer m.DecRef()
+ defer m.DecRef(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
Length: aioRingBufferSize,
MappingIdentity: m,
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index aac56679b..4d7773f8b 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -258,7 +258,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) {
mm.executable = nil
mm.metadataMu.Unlock()
if exe != nil {
- exe.DecRef()
+ exe.DecRef(ctx)
}
mm.activeMu.Lock()
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index 28e5057f7..0cfd60f6c 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -15,6 +15,7 @@
package mm
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/usermem"
@@ -147,7 +148,7 @@ func (mm *MemoryManager) Executable() fsbridge.File {
// SetExecutable sets the executable.
//
// This takes a reference on d.
-func (mm *MemoryManager) SetExecutable(file fsbridge.File) {
+func (mm *MemoryManager) SetExecutable(ctx context.Context, file fsbridge.File) {
mm.metadataMu.Lock()
// Grab a new reference.
@@ -164,7 +165,7 @@ func (mm *MemoryManager) SetExecutable(file fsbridge.File) {
// Do this without holding the lock, since it may wind up doing some
// I/O to sync the dirent, etc.
if orig != nil {
- orig.DecRef()
+ orig.DecRef(ctx)
}
}
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index 0e142fb11..4cdb52eb6 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -50,8 +50,8 @@ func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.F
}
// DecRef implements refs.RefCounter.DecRef.
-func (m *SpecialMappable) DecRef() {
- m.AtomicRefCount.DecRefWithDestructor(func() {
+func (m *SpecialMappable) DecRef(ctx context.Context) {
+ m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) {
m.mfp.MemoryFile().DecRef(m.fr)
})
}
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 3f496aa9f..e74d4e1c1 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -101,7 +101,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme
if err != nil {
return 0, err
}
- defer m.DecRef()
+ defer m.DecRef(ctx)
opts.MappingIdentity = m
opts.Mappable = m
}
@@ -1191,7 +1191,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length ui
mr := vseg.mappableRangeOf(vseg.Range().Intersect(ar))
mm.mappingMu.RUnlock()
err := id.Msync(ctx, mr)
- id.DecRef()
+ id.DecRef(ctx)
if err != nil {
return err
}
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index 16d8207e9..bd751d696 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -377,7 +377,7 @@ func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRa
vma.mappable.RemoveMapping(ctx, mm, vmaAR, vma.off, vma.canWriteMappableLocked())
}
if vma.id != nil {
- vma.id.DecRef()
+ vma.id.DecRef(ctx)
}
mm.usageAS -= uint64(vmaAR.Length())
if vma.isPrivateDataLocked() {
@@ -446,7 +446,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa
}
if vma2.id != nil {
- vma2.id.DecRef()
+ vma2.id.DecRef(context.Background())
}
return vma1, true
}
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 8b439a078..70ccf77a7 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -68,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
for _, fd := range fds {
file := t.GetFile(fd)
if file == nil {
- files.Release()
+ files.Release(t)
return nil, syserror.EBADF
}
files = append(files, file)
@@ -100,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (fs *RightsFiles) Release() {
+func (fs *RightsFiles) Release(ctx context.Context) {
for _, f := range *fs {
- f.DecRef()
+ f.DecRef(ctx)
}
*fs = nil
}
@@ -115,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
CloseOnExec: cloexec,
})
- files[0].DecRef()
+ files[0].DecRef(t)
files = files[1:]
if err != nil {
t.Warningf("Error inserting FD: %v", err)
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index fd08179be..d9621968c 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -46,7 +46,7 @@ func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) {
for _, fd := range fds {
file := t.GetFileVFS2(fd)
if file == nil {
- files.Release()
+ files.Release(t)
return nil, syserror.EBADF
}
files = append(files, file)
@@ -78,9 +78,9 @@ func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (fs *RightsFilesVFS2) Release() {
+func (fs *RightsFilesVFS2) Release(ctx context.Context) {
for _, f := range *fs {
- f.DecRef()
+ f.DecRef(ctx)
}
*fs = nil
}
@@ -93,7 +93,7 @@ func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int)
fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{
CloseOnExec: cloexec,
})
- files[0].DecRef()
+ files[0].DecRef(t)
files = files[1:]
if err != nil {
t.Warningf("Error inserting FD: %v", err)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 532a1ea5d..242e6bf76 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -100,12 +100,12 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc
return nil, syserr.FromError(err)
}
dirent := socket.NewDirent(ctx, socketDevice)
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil
}
// Release implements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release() {
+func (s *socketOpsCommon) Release(context.Context) {
fdnotifier.RemoveFD(int32(s.fd))
syscall.Close(s.fd)
}
@@ -269,7 +269,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
syscall.Close(fd)
return 0, nil, 0, err
}
- defer f.DecRef()
+ defer f.DecRef(t)
kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{
CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
@@ -281,7 +281,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
syscall.Close(fd)
return 0, nil, 0, err
}
- defer f.DecRef()
+ defer f.DecRef(t)
kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 0d45e5053..31e374833 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -97,7 +97,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int
}
d := socket.NewDirent(t, netlinkSocketDevice)
- defer d.DecRef()
+ defer d.DecRef(t)
return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 98ca7add0..68a9b9a96 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -140,14 +140,14 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke
// Bind the endpoint for good measure so we can connect to it. The
// bound address will never be exposed.
if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
// Create a connection from which the kernel can write messages.
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
if err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
@@ -164,9 +164,9 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke
}
// Release implements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release() {
- s.connection.Release()
- s.ep.Close()
+func (s *socketOpsCommon) Release(ctx context.Context) {
+ s.connection.Release(ctx)
+ s.ep.Close(ctx)
if s.bound {
s.ports.Release(s.protocol.Protocol(), s.portID)
@@ -621,7 +621,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys
if len(bufs) > 0 {
// RecvMsg never receives the address, so we don't need to send
// one.
- _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{})
// If the buffer is full, we simply drop messages, just like
// Linux.
if err != nil && err != syserr.ErrWouldBlock {
@@ -648,7 +648,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys
// Add the dump_done_errno payload.
m.Put(int64(0))
- _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
return err
}
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index dbcd8b49a..a38d25da9 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -57,14 +57,14 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
// Bind the endpoint for good measure so we can connect to it. The
// bound address will never be exposed.
if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
// Create a connection from which the kernel can write messages.
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
if err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 31a168f7e..e4846bc0b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -330,7 +330,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
}
dirent := socket.NewDirent(t, netstackDevice)
- defer dirent.DecRef()
+ defer dirent.DecRef(t)
return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{
socketOpsCommon: socketOpsCommon{
Queue: queue,
@@ -479,7 +479,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
// Release implements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release() {
+func (s *socketOpsCommon) Release(context.Context) {
s.Endpoint.Close()
}
@@ -854,7 +854,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
if err != nil {
return 0, nil, 0, err
}
- defer ns.DecRef()
+ defer ns.DecRef(t)
if flags&linux.SOCK_NONBLOCK != 0 {
flags := ns.Flags()
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index a9025b0ec..3335e7430 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -169,7 +169,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
if err != nil {
return 0, nil, 0, err
}
- defer ns.DecRef()
+ defer ns.DecRef(t)
if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil {
return 0, nil, 0, syserr.FromError(err)
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index d112757fb..04b259d27 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -46,8 +46,8 @@ type ControlMessages struct {
}
// Release releases Unix domain socket credentials and rights.
-func (c *ControlMessages) Release() {
- c.Unix.Release()
+func (c *ControlMessages) Release(ctx context.Context) {
+ c.Unix.Release(ctx)
}
// Socket is an interface combining fs.FileOperations and SocketOps,
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index a1e49cc57..c67b602f0 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -211,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool {
// The socket will be a fresh state after a call to close and may be reused.
// That is, close may be used to "unbind" or "disconnect" the socket in error
// paths.
-func (e *connectionedEndpoint) Close() {
+func (e *connectionedEndpoint) Close(ctx context.Context) {
e.Lock()
var c ConnectedEndpoint
var r Receiver
@@ -233,7 +233,7 @@ func (e *connectionedEndpoint) Close() {
case e.Listening():
close(e.acceptedChan)
for n := range e.acceptedChan {
- n.Close()
+ n.Close(ctx)
}
e.acceptedChan = nil
e.path = ""
@@ -241,11 +241,11 @@ func (e *connectionedEndpoint) Close() {
e.Unlock()
if c != nil {
c.CloseNotify()
- c.Release()
+ c.Release(ctx)
}
if r != nil {
r.CloseNotify()
- r.Release()
+ r.Release(ctx)
}
}
@@ -340,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
return nil
default:
// Busy; return ECONNREFUSED per spec.
- ne.Close()
+ ne.Close(ctx)
e.Unlock()
ce.Unlock()
return syserr.ErrConnectionRefused
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 4b06d63ac..70ee8f9b8 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool {
// Close puts the endpoint in a closed state and frees all resources associated
// with it.
-func (e *connectionlessEndpoint) Close() {
+func (e *connectionlessEndpoint) Close(ctx context.Context) {
e.Lock()
if e.connected != nil {
- e.connected.Release()
+ e.connected.Release(ctx)
e.connected = nil
}
@@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() {
e.Unlock()
r.CloseNotify()
- r.Release()
+ r.Release(ctx)
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
@@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C
if err != nil {
return 0, syserr.ErrInvalidEndpointState
}
- defer connected.Release()
+ defer connected.Release(ctx)
e.Lock()
- n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
e.Unlock()
if notify {
@@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi
e.Lock()
if e.connected != nil {
- e.connected.Release()
+ e.connected.Release(ctx)
}
e.connected = connected
e.Unlock()
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index d8f3ad63d..ef6043e19 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -15,6 +15,7 @@
package transport
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
@@ -57,10 +58,10 @@ func (q *queue) Close() {
// Both the read and write queues must be notified after resetting:
// q.ReaderQueue.Notify(waiter.EventIn)
// q.WriterQueue.Notify(waiter.EventOut)
-func (q *queue) Reset() {
+func (q *queue) Reset(ctx context.Context) {
q.mu.Lock()
for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
- cur.Release()
+ cur.Release(ctx)
}
q.dataList.Reset()
q.used = 0
@@ -68,8 +69,8 @@ func (q *queue) Reset() {
}
// DecRef implements RefCounter.DecRef with destructor q.Reset.
-func (q *queue) DecRef() {
- q.DecRefWithDestructor(q.Reset)
+func (q *queue) DecRef(ctx context.Context) {
+ q.DecRefWithDestructor(ctx, q.Reset)
// We don't need to notify after resetting because no one cares about
// this queue after all references have been dropped.
}
@@ -111,7 +112,7 @@ func (q *queue) IsWritable() bool {
//
// If notify is true, ReaderQueue.Notify must be called:
// q.ReaderQueue.Notify(waiter.EventIn)
-func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) {
+func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) {
q.mu.Lock()
if q.closed {
@@ -124,7 +125,7 @@ func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress
}
if discardEmpty && l == 0 {
q.mu.Unlock()
- c.Release()
+ c.Release(ctx)
return 0, false, nil
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2f1b127df..475d7177e 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -37,7 +37,7 @@ type RightsControlMessage interface {
Clone() RightsControlMessage
// Release releases any resources owned by the RightsControlMessage.
- Release()
+ Release(ctx context.Context)
}
// A CredentialsControlMessage is a control message containing Unix credentials.
@@ -74,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages {
}
// Release releases both the credentials and the rights.
-func (c *ControlMessages) Release() {
+func (c *ControlMessages) Release(ctx context.Context) {
if c.Rights != nil {
- c.Rights.Release()
+ c.Rights.Release(ctx)
}
*c = ControlMessages{}
}
@@ -90,7 +90,7 @@ type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
- Close()
+ Close(ctx context.Context)
// RecvMsg reads data and a control message from the endpoint. This method
// does not block if there is no data pending.
@@ -252,7 +252,7 @@ type BoundEndpoint interface {
// Release releases any resources held by the BoundEndpoint. It must be
// called before dropping all references to a BoundEndpoint returned by a
// function.
- Release()
+ Release(ctx context.Context)
}
// message represents a message passed over a Unix domain socket.
@@ -281,8 +281,8 @@ func (m *message) Length() int64 {
}
// Release releases any resources held by the message.
-func (m *message) Release() {
- m.Control.Release()
+func (m *message) Release(ctx context.Context) {
+ m.Control.Release(ctx)
}
// Peek returns a copy of the message.
@@ -304,7 +304,7 @@ type Receiver interface {
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+ Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -333,7 +333,7 @@ type Receiver interface {
// Release releases any resources owned by the Receiver. It should be
// called before droping all references to a Receiver.
- Release()
+ Release(ctx context.Context)
}
// queueReceiver implements Receiver for datagram sockets.
@@ -344,7 +344,7 @@ type queueReceiver struct {
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
var m *message
var notify bool
var err *syserr.Error
@@ -398,8 +398,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 {
}
// Release implements Receiver.Release.
-func (q *queueReceiver) Release() {
- q.readQueue.DecRef()
+func (q *queueReceiver) Release(ctx context.Context) {
+ q.readQueue.DecRef(ctx)
}
// streamQueueReceiver implements Receiver for stream sockets.
@@ -456,7 +456,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
}
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -502,7 +502,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int,
var cmTruncated bool
if c.Rights != nil && numRights == 0 {
- c.Rights.Release()
+ c.Rights.Release(ctx)
c.Rights = nil
cmTruncated = true
}
@@ -557,7 +557,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int,
// Consume rights.
if numRights == 0 {
cmTruncated = true
- q.control.Rights.Release()
+ q.control.Rights.Release(ctx)
} else {
c.Rights = q.control.Rights
haveRights = true
@@ -582,7 +582,7 @@ type ConnectedEndpoint interface {
//
// syserr.ErrWouldBlock can be returned along with a partial write if
// the caller should block to send the rest of the data.
- Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
+ Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
// SendNotify notifies the ConnectedEndpoint of a successful Send. This
// must not be called while holding any endpoint locks.
@@ -616,7 +616,7 @@ type ConnectedEndpoint interface {
// Release releases any resources owned by the ConnectedEndpoint. It should
// be called before droping all references to a ConnectedEndpoint.
- Release()
+ Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
// the peer socket.
@@ -654,7 +654,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
}
// Send implements ConnectedEndpoint.Send.
-func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
discardEmpty := false
truncate := false
if e.endpoint.Type() == linux.SOCK_STREAM {
@@ -669,7 +669,7 @@ func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.Fu
truncate = true
}
- return e.writeQueue.Enqueue(data, c, from, discardEmpty, truncate)
+ return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate)
}
// SendNotify implements ConnectedEndpoint.SendNotify.
@@ -707,8 +707,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 {
}
// Release implements ConnectedEndpoint.Release.
-func (e *connectedEndpoint) Release() {
- e.writeQueue.DecRef()
+func (e *connectedEndpoint) Release(ctx context.Context) {
+ e.writeQueue.DecRef(ctx)
}
// CloseUnread implements ConnectedEndpoint.CloseUnread.
@@ -798,7 +798,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n
return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
}
- recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek)
e.Unlock()
if err != nil {
return 0, 0, ControlMessages{}, false, err
@@ -827,7 +827,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
return 0, syserr.ErrAlreadyConnected
}
- n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
e.Unlock()
if notify {
@@ -1001,6 +1001,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// Release implements BoundEndpoint.Release.
-func (*baseEndpoint) Release() {
+func (*baseEndpoint) Release(context.Context) {
// Binding a baseEndpoint doesn't take a reference.
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 0482d33cf..2b8454edb 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -62,7 +62,7 @@ type SocketOperations struct {
// New creates a new unix socket.
func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true})
}
@@ -97,17 +97,17 @@ type socketOpsCommon struct {
}
// DecRef implements RefCounter.DecRef.
-func (s *socketOpsCommon) DecRef() {
- s.DecRefWithDestructor(func() {
- s.ep.Close()
+func (s *socketOpsCommon) DecRef(ctx context.Context) {
+ s.DecRefWithDestructor(ctx, func(context.Context) {
+ s.ep.Close(ctx)
})
}
// Release implemements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release() {
+func (s *socketOpsCommon) Release(ctx context.Context) {
// Release only decrements a reference on s because s may be referenced in
// the abstract socket namespace.
- s.DecRef()
+ s.DecRef(ctx)
}
func (s *socketOpsCommon) isPacket() bool {
@@ -234,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
ns := New(t, ep, s.stype)
- defer ns.DecRef()
+ defer ns.DecRef(t)
if flags&linux.SOCK_NONBLOCK != 0 {
flags := ns.Flags()
@@ -284,7 +284,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
@@ -294,7 +294,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
var name string
cwd := t.FSContext().WorkingDirectory()
- defer cwd.DecRef()
+ defer cwd.DecRef(t)
// Is there no slash at all?
if !strings.Contains(p, "/") {
@@ -302,7 +302,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
name = p
} else {
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
// Find the last path component, we know that something follows
// that final slash, otherwise extractPath() would have failed.
lastSlash := strings.LastIndex(p, "/")
@@ -318,7 +318,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// No path available.
return syserr.ErrNoSuchFile
}
- defer d.DecRef()
+ defer d.DecRef(t)
name = p[lastSlash+1:]
}
@@ -332,7 +332,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if err != nil {
return syserr.ErrPortInUse
}
- childDir.DecRef()
+ childDir.DecRef(t)
}
return nil
@@ -378,9 +378,9 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
FollowFinalSymlink: true,
}
ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
- root.DecRef()
+ root.DecRef(t)
if relPath {
- start.DecRef()
+ start.DecRef(t)
}
if e != nil {
return nil, syserr.FromError(e)
@@ -393,15 +393,15 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
cwd := t.FSContext().WorkingDirectory()
remainingTraversals := uint(fs.DefaultTraversalLimit)
d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals)
- cwd.DecRef()
- root.DecRef()
+ cwd.DecRef(t)
+ root.DecRef(t)
if e != nil {
return nil, syserr.FromError(e)
}
// Extract the endpoint if one is there.
ep := d.Inode.BoundEndpoint(path)
- d.DecRef()
+ d.DecRef(t)
if ep == nil {
// No socket!
return nil, syserr.ErrConnectionRefused
@@ -415,7 +415,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
if err != nil {
return err
}
- defer ep.Release()
+ defer ep.Release(t)
// Connect the server endpoint.
err = s.ep.Connect(t, ep)
@@ -473,7 +473,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
if err != nil {
return 0, err
}
- defer ep.Release()
+ defer ep.Release(t)
w.To = ep
if ep.Passcred() && w.Control.Credentials == nil {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 05c16fcfe..dfa25241a 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -136,7 +136,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
if err != nil {
return 0, nil, 0, err
}
- defer ns.DecRef()
+ defer ns.DecRef(t)
if flags&linux.SOCK_NONBLOCK != 0 {
ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
@@ -183,19 +183,19 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
} else {
path := fspath.Parse(p)
root := t.FSContext().RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
start := root
relPath := !path.Absolute
if relPath {
start = t.FSContext().WorkingDirectoryVFS2()
- defer start.DecRef()
+ defer start.DecRef(t)
}
pop := vfs.PathOperation{
Root: root,
@@ -333,7 +333,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int)
f, err := NewSockfsFile(t, ep, stype)
if err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
return f, nil
@@ -357,14 +357,14 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*
ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
s1, err := NewSockfsFile(t, ep1, stype)
if err != nil {
- ep1.Close()
- ep2.Close()
+ ep1.Close(t)
+ ep2.Close(t)
return nil, nil, err
}
s2, err := NewSockfsFile(t, ep2, stype)
if err != nil {
- s1.DecRef()
- ep2.Close()
+ s1.DecRef(t)
+ ep2.Close(t)
return nil, nil, err
}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 68ca537c8..87b239730 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -147,14 +147,14 @@ func fd(t *kernel.Task, fd int32) string {
root := t.FSContext().RootDirectory()
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(t)
}
if fd == linux.AT_FDCWD {
wd := t.FSContext().WorkingDirectory()
var name string
if wd != nil {
- defer wd.DecRef()
+ defer wd.DecRef(t)
name, _ = wd.FullName(root)
} else {
name = "(unknown cwd)"
@@ -167,7 +167,7 @@ func fd(t *kernel.Task, fd int32) string {
// Cast FD to uint64 to avoid printing negative hex.
return fmt.Sprintf("%#x (bad FD)", uint64(fd))
}
- defer file.DecRef()
+ defer file.DecRef(t)
name, _ := file.Dirent.FullName(root)
return fmt.Sprintf("%#x %s", fd, name)
@@ -175,12 +175,12 @@ func fd(t *kernel.Task, fd int32) string {
func fdVFS2(t *kernel.Task, fd int32) string {
root := t.FSContext().RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
vfsObj := root.Mount().Filesystem().VirtualFilesystem()
if fd == linux.AT_FDCWD {
wd := t.FSContext().WorkingDirectoryVFS2()
- defer wd.DecRef()
+ defer wd.DecRef(t)
name, _ := vfsObj.PathnameWithDeleted(t, root, wd)
return fmt.Sprintf("AT_FDCWD %s", name)
@@ -191,7 +191,7 @@ func fdVFS2(t *kernel.Task, fd int32) string {
// Cast FD to uint64 to avoid printing negative hex.
return fmt.Sprintf("%#x (bad FD)", uint64(fd))
}
- defer file.DecRef()
+ defer file.DecRef(t)
name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry())
return fmt.Sprintf("%#x %s", fd, name)
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
index d9fb808c0..d23a0068a 100644
--- a/pkg/sentry/syscalls/epoll.go
+++ b/pkg/sentry/syscalls/epoll.go
@@ -28,7 +28,7 @@ import (
// CreateEpoll implements the epoll_create(2) linux syscall.
func CreateEpoll(t *kernel.Task, closeOnExec bool) (int32, error) {
file := epoll.NewEventPoll(t)
- defer file.DecRef()
+ defer file.DecRef(t)
fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: closeOnExec,
@@ -47,14 +47,14 @@ func AddEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, mask
if epollfile == nil {
return syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Get the target file id.
file := t.GetFile(fd)
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
@@ -73,14 +73,14 @@ func UpdateEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, m
if epollfile == nil {
return syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Get the target file id.
file := t.GetFile(fd)
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
@@ -99,14 +99,14 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
if epollfile == nil {
return syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Get the target file id.
file := t.GetFile(fd)
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
@@ -115,7 +115,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
}
// Try to remove the entry.
- return e.RemoveEntry(epoll.FileIdentifier{file, fd})
+ return e.RemoveEntry(t, epoll.FileIdentifier{file, fd})
}
// WaitEpoll implements the epoll_wait(2) linux syscall.
@@ -125,7 +125,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve
if epollfile == nil {
return nil, syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index ba2557c52..e9d64dec5 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -247,7 +247,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu
ev.Result = -int64(kernel.ExtractErrno(err, 0))
}
- file.DecRef()
+ file.DecRef(ctx)
// Queue the result for delivery.
actx.FinishRequest(ev)
@@ -257,7 +257,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu
// wake up.
if eventFile != nil {
eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
- eventFile.DecRef()
+ eventFile.DecRef(ctx)
}
}
}
@@ -269,7 +269,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
// File not found.
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Was there an eventFD? Extract it.
var eventFile *fs.File
@@ -279,7 +279,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
// Bad FD.
return syserror.EBADF
}
- defer eventFile.DecRef()
+ defer eventFile.DecRef(t)
// Check that it is an eventfd.
if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok {
diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go
index ed3413ca6..3b4f879e4 100644
--- a/pkg/sentry/syscalls/linux/sys_eventfd.go
+++ b/pkg/sentry/syscalls/linux/sys_eventfd.go
@@ -37,7 +37,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
event.SetFlags(fs.SettableFileFlags{
NonBlocking: flags&linux.EFD_NONBLOCK != 0,
})
- defer event.DecRef()
+ defer event.DecRef(t)
fd, err := t.NewFDFrom(0, event, kernel.FDFlags{
CloseOnExec: flags&linux.EFD_CLOEXEC != 0,
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 8cf6401e7..1bc9b184e 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -40,7 +40,7 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent,
// Common case: we are accessing a file in the root.
root := t.FSContext().RootDirectory()
err := fn(root, root, name, linux.MaxSymlinkTraversals)
- root.DecRef()
+ root.DecRef(t)
return err
} else if dir == "." && dirFD == linux.AT_FDCWD {
// Common case: we are accessing a file relative to the current
@@ -48,8 +48,8 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent,
wd := t.FSContext().WorkingDirectory()
root := t.FSContext().RootDirectory()
err := fn(root, wd, name, linux.MaxSymlinkTraversals)
- wd.DecRef()
- root.DecRef()
+ wd.DecRef(t)
+ root.DecRef(t)
return err
}
@@ -97,19 +97,19 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro
} else {
d, err = t.MountNamespace().FindLink(t, root, rel, path, &remainingTraversals)
}
- root.DecRef()
+ root.DecRef(t)
if wd != nil {
- wd.DecRef()
+ wd.DecRef(t)
}
if f != nil {
- f.DecRef()
+ f.DecRef(t)
}
if err != nil {
return err
}
err = fn(root, d, remainingTraversals)
- d.DecRef()
+ d.DecRef(t)
return err
}
@@ -186,7 +186,7 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
if err != nil {
return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Success.
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
@@ -242,7 +242,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode
if err != nil {
return err
}
- file.DecRef()
+ file.DecRef(t)
return nil
case linux.ModeNamedPipe:
@@ -332,7 +332,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
if err != nil {
break
}
- defer found.DecRef()
+ defer found.DecRef(t)
// We found something (possibly a symlink). If the
// O_EXCL flag was passed, then we can immediately
@@ -357,7 +357,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
resolved, err = found.Inode.Getlink(t)
if err == nil {
// No more resolution necessary.
- defer resolved.DecRef()
+ defer resolved.DecRef(t)
break
}
if err != fs.ErrResolveViaReadlink {
@@ -384,7 +384,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
if err != nil {
break
}
- defer newParent.DecRef()
+ defer newParent.DecRef(t)
// Repeat the process with the parent and name of the
// symlink target.
@@ -416,7 +416,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
if err != nil {
return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
- defer newFile.DecRef()
+ defer newFile.DecRef(t)
case syserror.ENOENT:
// File does not exist. Proceed with creation.
@@ -432,7 +432,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
// No luck, bail.
return err
}
- defer newFile.DecRef()
+ defer newFile.DecRef(t)
found = newFile.Dirent
default:
return err
@@ -596,7 +596,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Shared flags between file and socket.
switch request {
@@ -671,9 +671,9 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
addr := args[0].Pointer()
size := args[1].SizeT()
cwd := t.FSContext().WorkingDirectory()
- defer cwd.DecRef()
+ defer cwd.DecRef(t)
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
// Get our fullname from the root and preprend unreachable if the root was
// unreachable from our current dirent this is the same behavior as on linux.
@@ -722,7 +722,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return err
}
- t.FSContext().SetRootDirectory(d)
+ t.FSContext().SetRootDirectory(t, d)
return nil
})
}
@@ -747,7 +747,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return err
}
- t.FSContext().SetWorkingDirectory(d)
+ t.FSContext().SetWorkingDirectory(t, d)
return nil
})
}
@@ -760,7 +760,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Is it a directory?
if !fs.IsDir(file.Dirent.Inode.StableAttr) {
@@ -772,7 +772,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
- t.FSContext().SetWorkingDirectory(file.Dirent)
+ t.FSContext().SetWorkingDirectory(t, file.Dirent)
return 0, nil, nil
}
@@ -791,7 +791,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.Flush(t)
return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file)
@@ -805,7 +805,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{})
if err != nil {
@@ -826,7 +826,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if oldFile == nil {
return 0, nil, syserror.EBADF
}
- defer oldFile.DecRef()
+ defer oldFile.DecRef(t)
return uintptr(newfd), nil, nil
}
@@ -850,7 +850,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if oldFile == nil {
return 0, nil, syserror.EBADF
}
- defer oldFile.DecRef()
+ defer oldFile.DecRef(t)
err := t.NewFDAt(newfd, oldFile, kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0})
if err != nil {
@@ -925,7 +925,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
switch cmd {
case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
@@ -1132,7 +1132,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// If the FD refers to a pipe or FIFO, return error.
if fs.IsPipe(file.Dirent.Inode.StableAttr) {
@@ -1171,7 +1171,7 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode
switch err {
case nil:
// The directory existed.
- defer f.DecRef()
+ defer f.DecRef(t)
return syserror.EEXIST
case syserror.EACCES:
// Permission denied while walking to the directory.
@@ -1349,7 +1349,7 @@ func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32
if target == nil {
return syserror.EBADF
}
- defer target.DecRef()
+ defer target.DecRef(t)
if err := mayLinkAt(t, target.Dirent.Inode); err != nil {
return err
}
@@ -1602,7 +1602,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Reject truncation if the file flags do not permit this operation.
// This is different from truncate(2) above.
@@ -1730,7 +1730,7 @@ func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bo
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return chown(t, file.Dirent, uid, gid)
}
@@ -1768,7 +1768,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, chown(t, file.Dirent, uid, gid)
}
@@ -1833,7 +1833,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, chmod(t, file.Dirent, mode)
}
@@ -1893,10 +1893,10 @@ func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, reso
if f == nil {
return syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
return setTimestamp(root, f.Dirent, linux.MaxSymlinkTraversals)
}
@@ -2088,7 +2088,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
if offset < 0 || length <= 0 {
return 0, nil, syserror.EINVAL
@@ -2141,7 +2141,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// flock(2): EBADF fd is not an open file descriptor.
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
nonblocking := operation&linux.LOCK_NB != 0
operation &^= linux.LOCK_NB
@@ -2224,8 +2224,8 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
- defer dirent.DecRef()
- defer file.DecRef()
+ defer dirent.DecRef(t)
+ defer file.DecRef(t)
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: cloExec,
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index f04d78856..9d1b2edb1 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -73,7 +73,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo
err = t.BlockWithDeadline(w.C, true, ktime.FromTimespec(ts))
}
- t.Futex().WaitComplete(w)
+ t.Futex().WaitComplete(w, t)
return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
@@ -95,7 +95,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add
}
remaining, err := t.BlockWithTimeout(w.C, !forever, duration)
- t.Futex().WaitComplete(w)
+ t.Futex().WaitComplete(w, t)
if err == nil {
return 0, nil
}
@@ -148,7 +148,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A
timer.Destroy()
}
- t.Futex().WaitComplete(w)
+ t.Futex().WaitComplete(w, t)
return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index b126fecc0..f5699e55d 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -68,7 +68,7 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir
if dir == nil {
return 0, syserror.EBADF
}
- defer dir.DecRef()
+ defer dir.DecRef(t)
w := &usermem.IOReadWriter{
Ctx: t,
diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go
index b2c7b3444..cf47bb9dd 100644
--- a/pkg/sentry/syscalls/linux/sys_inotify.go
+++ b/pkg/sentry/syscalls/linux/sys_inotify.go
@@ -40,7 +40,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
NonBlocking: flags&linux.IN_NONBLOCK != 0,
}
n := fs.NewFile(t, dirent, fileFlags, fs.NewInotify(t))
- defer n.DecRef()
+ defer n.DecRef(t)
fd, err := t.NewFDFrom(0, n, kernel.FDFlags{
CloseOnExec: flags&linux.IN_CLOEXEC != 0,
@@ -71,7 +71,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) {
ino, ok := file.FileOperations.(*fs.Inotify)
if !ok {
// Not an inotify fd.
- file.DecRef()
+ file.DecRef(t)
return nil, nil, syserror.EINVAL
}
@@ -98,7 +98,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -128,6 +128,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
- return 0, nil, ino.RmWatch(wd)
+ defer file.DecRef(t)
+ return 0, nil, ino.RmWatch(t, wd)
}
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 3f7691eae..1c38f8f4f 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -33,7 +33,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
var sw fs.SeekWhence
switch whence {
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 91694d374..72786b032 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -75,7 +75,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
}
defer func() {
if opts.MappingIdentity != nil {
- opts.MappingIdentity.DecRef()
+ opts.MappingIdentity.DecRef(t)
}
}()
@@ -85,7 +85,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
flags := file.Flags()
// mmap unconditionally requires that the FD is readable.
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
index eb5ff48f5..bd0633564 100644
--- a/pkg/sentry/syscalls/linux/sys_mount.go
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -115,7 +115,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}); err != nil {
// Something went wrong. Drop our ref on rootInode before
// returning the error.
- rootInode.DecRef()
+ rootInode.DecRef(t)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 43c510930..3149e4aad 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -34,10 +34,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
r.SetFlags(linuxToFlags(flags).Settable())
- defer r.DecRef()
+ defer r.DecRef(t)
w.SetFlags(linuxToFlags(flags).Settable())
- defer w.DecRef()
+ defer w.DecRef(t)
fds, err := t.NewFDs(0, []*fs.File{r, w}, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -49,7 +49,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if _, err := t.CopyOut(addr, fds); err != nil {
for _, fd := range fds {
if file, _ := t.FDTable().Remove(fd); file != nil {
- file.DecRef()
+ file.DecRef(t)
}
}
return 0, err
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index f0198141c..3435bdf77 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -70,7 +70,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan
}
if ch == nil {
- defer file.DecRef()
+ defer file.DecRef(t)
} else {
state.file = file
state.waiter, _ = waiter.NewChannelEntry(ch)
@@ -82,11 +82,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan
}
// releaseState releases all the pollState in "state".
-func releaseState(state []pollState) {
+func releaseState(t *kernel.Task, state []pollState) {
for i := range state {
if state[i].file != nil {
state[i].file.EventUnregister(&state[i].waiter)
- state[i].file.DecRef()
+ state[i].file.DecRef(t)
}
}
}
@@ -107,7 +107,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// result, we stop registering for events but still go through all files
// to get their ready masks.
state := make([]pollState, len(pfd))
- defer releaseState(state)
+ defer releaseState(t, state)
n := uintptr(0)
for i := range pfd {
initReadiness(t, &pfd[i], &state[i], ch)
@@ -266,7 +266,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
if file == nil {
return 0, syserror.EBADF
}
- file.DecRef()
+ file.DecRef(t)
var mask int16
if (rV & m) != 0 {
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index f92bf8096..64a725296 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -128,7 +128,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// They trying to set exe to a non-file?
if !fs.IsFile(file.Dirent.Inode.StableAttr) {
@@ -136,7 +136,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
// Set the underlying executable.
- t.MemoryManager().SetExecutable(fsbridge.NewFSFile(file))
+ t.MemoryManager().SetExecutable(t, fsbridge.NewFSFile(file))
case linux.PR_SET_MM_AUXV,
linux.PR_SET_MM_START_CODE,
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 071b4bacc..3bbc3fa4b 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -48,7 +48,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.Flags().Read {
@@ -84,7 +84,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.Flags().Read {
@@ -118,7 +118,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
@@ -164,7 +164,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.Flags().Read {
@@ -195,7 +195,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < 0 {
@@ -244,7 +244,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < -1 {
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index 4a8bc24a2..f0ae8fa8e 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -39,7 +39,7 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- defer segment.DecRef()
+ defer segment.DecRef(t)
return uintptr(segment.ID), nil, nil
}
@@ -66,7 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, syserror.EINVAL
}
- defer segment.DecRef()
+ defer segment.DecRef(t)
opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
@@ -108,7 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
- defer segment.DecRef()
+ defer segment.DecRef(t)
stat, err := segment.IPCStat(t)
if err == nil {
@@ -132,7 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
- defer segment.DecRef()
+ defer segment.DecRef(t)
switch cmd {
case linux.IPC_SET:
@@ -145,7 +145,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
case linux.IPC_RMID:
- segment.MarkDestroyed()
+ segment.MarkDestroyed(t)
return 0, nil, nil
case linux.SHM_LOCK, linux.SHM_UNLOCK:
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index d2b0012ae..20cb1a5cb 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -536,7 +536,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Is this a signalfd?
if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
@@ -553,7 +553,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Set appropriate flags.
file.SetFlags(fs.SettableFileFlags{
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 414fce8e3..fec1c1974 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -200,7 +200,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
s.SetFlags(fs.SettableFileFlags{
NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
})
- defer s.DecRef()
+ defer s.DecRef(t)
fd, err := t.NewFDFrom(0, s, kernel.FDFlags{
CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
@@ -235,8 +235,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
s1.SetFlags(fileFlags)
s2.SetFlags(fileFlags)
- defer s1.DecRef()
- defer s2.DecRef()
+ defer s1.DecRef(t)
+ defer s2.DecRef(t)
// Create the FDs for the sockets.
fds, err := t.NewFDs(0, []*fs.File{s1, s2}, kernel.FDFlags{
@@ -250,7 +250,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if _, err := t.CopyOut(socks, fds); err != nil {
for _, fd := range fds {
if file, _ := t.FDTable().Remove(fd); file != nil {
- file.DecRef()
+ file.DecRef(t)
}
}
return 0, nil, err
@@ -270,7 +270,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -301,7 +301,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -360,7 +360,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -387,7 +387,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -416,7 +416,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -447,7 +447,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -529,7 +529,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -567,7 +567,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -595,7 +595,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -628,7 +628,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -681,7 +681,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -775,7 +775,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
- cms.Release()
+ cms.Release(t)
}
if int(msg.Flags) != mflags {
@@ -795,7 +795,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
- defer cms.Release()
+ defer cms.Release(t)
controlData := make([]byte, 0, msg.ControlLen)
controlData = control.PackControlMessages(t, cms, controlData)
@@ -851,7 +851,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -880,7 +880,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
}
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
- cm.Release()
+ cm.Release(t)
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
@@ -924,7 +924,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -962,7 +962,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -1066,7 +1066,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
if err != nil {
- controlMessages.Release()
+ controlMessages.Release(t)
}
return uintptr(n), err
}
@@ -1084,7 +1084,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 77c78889d..b8846a10a 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -101,7 +101,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
if !inFile.Flags().Read {
return 0, nil, syserror.EBADF
@@ -111,7 +111,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
if !outFile.Flags().Write {
return 0, nil, syserror.EBADF
@@ -192,13 +192,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
inFile := t.GetFile(inFD)
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
// The operation is non-blocking if anything is non-blocking.
//
@@ -300,13 +300,13 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
inFile := t.GetFile(inFD)
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
// All files must be pipes.
if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) {
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 46ebf27a2..a5826f2dd 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -58,7 +58,7 @@ func Fstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, fstat(t, file, statAddr)
}
@@ -100,7 +100,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, fstat(t, file, statAddr)
}
@@ -158,7 +158,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
uattr, err := file.UnstableAttr(t)
if err != nil {
return 0, nil, err
@@ -249,7 +249,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, statfsImpl(t, file.Dirent, statfsAddr)
}
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 5ad465ae3..f2c0e5069 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -39,7 +39,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Use "sync-the-world" for now, it's guaranteed that fd is at least
// on the root filesystem.
@@ -54,7 +54,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll)
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
@@ -70,7 +70,7 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData)
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
@@ -103,7 +103,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// SYNC_FILE_RANGE_WAIT_BEFORE waits upon write-out of all pages in the
// specified range that have already been submitted to the device
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 00915fdde..2d16e4933 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -117,7 +117,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
var wd *fs.Dirent
var executable fsbridge.File
@@ -133,7 +133,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
closeOnExec = fdFlags.CloseOnExec
if atEmptyPath && len(pathname) == 0 {
@@ -155,7 +155,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
}
}
if wd != nil {
- defer wd.DecRef()
+ defer wd.DecRef(t)
}
// Load the new TaskContext.
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
index cf49b43db..34b03e4ee 100644
--- a/pkg/sentry/syscalls/linux/sys_timerfd.go
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -43,7 +43,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.EINVAL
}
f := timerfd.NewFile(t, c)
- defer f.DecRef()
+ defer f.DecRef(t)
f.SetFlags(fs.SettableFileFlags{
NonBlocking: flags&linux.TFD_NONBLOCK != 0,
})
@@ -73,7 +73,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
tf, ok := f.FileOperations.(*timerfd.TimerOperations)
if !ok {
@@ -107,7 +107,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
tf, ok := f.FileOperations.(*timerfd.TimerOperations)
if !ok {
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 6ec0de96e..485526e28 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -48,7 +48,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is writable.
if !file.Flags().Write {
@@ -85,7 +85,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
@@ -131,7 +131,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is writable.
if !file.Flags().Write {
@@ -162,7 +162,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < 0 {
@@ -215,7 +215,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < -1 {
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index c24946160..97474fd3c 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -49,7 +49,7 @@ func FGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
n, err := getXattr(t, f.Dirent, nameAddr, valueAddr, size)
if err != nil {
@@ -153,7 +153,7 @@ func FSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
return 0, nil, setXattr(t, f.Dirent, nameAddr, valueAddr, uint64(size), flags)
}
@@ -270,7 +270,7 @@ func FListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
n, err := listXattr(t, f.Dirent, listAddr, size)
if err != nil {
@@ -384,7 +384,7 @@ func FRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
return 0, nil, removeXattr(t, f.Dirent, nameAddr)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index e5cdefc50..399b4f60c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -88,7 +88,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if fd == nil {
return syserror.EBADF
}
- defer fd.DecRef()
+ defer fd.DecRef(t)
// Was there an eventFD? Extract it.
var eventFD *vfs.FileDescription
@@ -97,7 +97,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if eventFD == nil {
return syserror.EBADF
}
- defer eventFD.DecRef()
+ defer eventFD.DecRef(t)
// Check that it is an eventfd.
if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok {
@@ -169,7 +169,7 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use
ev.Result = -int64(kernel.ExtractErrno(err, 0))
}
- fd.DecRef()
+ fd.DecRef(ctx)
// Queue the result for delivery.
aioCtx.FinishRequest(ev)
@@ -179,7 +179,7 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use
// wake up.
if eventFD != nil {
eventFD.Impl().(*eventfd.EventFileDescription).Signal(1)
- eventFD.DecRef()
+ eventFD.DecRef(ctx)
}
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index 34c90ae3e..c62f03509 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -37,11 +37,11 @@ func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
return 0, nil, syserror.EINVAL
}
- file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ file, err := t.Kernel().VFS().NewEpollInstanceFD(t)
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
@@ -62,11 +62,11 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ file, err := t.Kernel().VFS().NewEpollInstanceFD(t)
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
if err != nil {
@@ -86,7 +86,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if epfile == nil {
return 0, nil, syserror.EBADF
}
- defer epfile.DecRef()
+ defer epfile.DecRef(t)
ep, ok := epfile.Impl().(*vfs.EpollInstance)
if !ok {
return 0, nil, syserror.EINVAL
@@ -95,7 +95,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
if epfile == file {
return 0, nil, syserror.EINVAL
}
@@ -135,7 +135,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if epfile == nil {
return 0, nil, syserror.EBADF
}
- defer epfile.DecRef()
+ defer epfile.DecRef(t)
ep, ok := epfile.Impl().(*vfs.EpollInstance)
if !ok {
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
index aff1a2070..807f909da 100644
--- a/pkg/sentry/syscalls/linux/vfs2/eventfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
@@ -38,11 +38,11 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
fileFlags |= linux.O_NONBLOCK
}
semMode := flags&linux.EFD_SEMAPHORE != 0
- eventfd, err := eventfd.New(vfsObj, initVal, semMode, fileFlags)
+ eventfd, err := eventfd.New(t, vfsObj, initVal, semMode, fileFlags)
if err != nil {
return 0, nil, err
}
- defer eventfd.DecRef()
+ defer eventfd.DecRef(t)
fd, err := t.NewFDFromVFS2(0, eventfd, kernel.FDFlags{
CloseOnExec: flags&linux.EFD_CLOEXEC != 0,
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
index aef0078a8..066ee0863 100644
--- a/pkg/sentry/syscalls/linux/vfs2/execve.go
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -71,7 +71,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user
}
root := t.FSContext().RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
var executable fsbridge.File
closeOnExec := false
if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
@@ -90,7 +90,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user
}
start := dirfile.VirtualDentry()
start.IncRef()
- dirfile.DecRef()
+ dirfile.DecRef(t)
closeOnExec = dirfileFlags.CloseOnExec
file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
Root: root,
@@ -101,19 +101,19 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user
Flags: linux.O_RDONLY,
FileExec: true,
})
- start.DecRef()
+ start.DecRef(t)
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
executable = fsbridge.NewVFSFile(file)
}
// Load the new TaskContext.
mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
- defer mntns.DecRef()
+ defer mntns.DecRef(t)
wd := t.FSContext().WorkingDirectoryVFS2()
- defer wd.DecRef()
+ defer wd.DecRef(t)
remainingTraversals := uint(linux.MaxSymlinkTraversals)
loadArgs := loader.LoadArgs{
Opener: fsbridge.NewVFSLookup(mntns, root, wd),
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 67f191551..72ca916a0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -38,7 +38,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.OnClose(t)
return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
@@ -52,7 +52,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
if err != nil {
@@ -72,7 +72,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- file.DecRef()
+ file.DecRef(t)
return uintptr(newfd), nil, nil
}
@@ -101,7 +101,7 @@ func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -121,7 +121,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
switch cmd {
case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
@@ -332,7 +332,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// If the FD refers to a pipe or FIFO, return error.
if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe {
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index b6d2ddd65..01e0f9010 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -56,7 +56,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd i
if err != nil {
return err
}
- defer oldtpop.Release()
+ defer oldtpop.Release(t)
newpath, err := copyInPath(t, newpathAddr)
if err != nil {
@@ -66,7 +66,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd i
if err != nil {
return err
}
- defer newtpop.Release()
+ defer newtpop.Release(t)
return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
}
@@ -95,7 +95,7 @@ func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
})
@@ -127,7 +127,7 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
// "Zero file type is equivalent to type S_IFREG." - mknod(2)
if mode.FileType() == 0 {
@@ -174,7 +174,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
Flags: flags | linux.O_LARGEFILE,
@@ -183,7 +183,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -227,7 +227,7 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd
if err != nil {
return err
}
- defer oldtpop.Release()
+ defer oldtpop.Release(t)
newpath, err := copyInPath(t, newpathAddr)
if err != nil {
@@ -237,7 +237,7 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd
if err != nil {
return err
}
- defer newtpop.Release()
+ defer newtpop.Release(t)
return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
Flags: flags,
@@ -259,7 +259,7 @@ func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
}
@@ -278,7 +278,7 @@ func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
}
@@ -329,6 +329,6 @@ func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpath
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
index 317409a18..a7d4d2a36 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fscontext.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -31,8 +31,8 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
root := t.FSContext().RootDirectoryVFS2()
wd := t.FSContext().WorkingDirectoryVFS2()
s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
- root.DecRef()
- wd.DecRef()
+ root.DecRef(t)
+ wd.DecRef(t)
if err != nil {
return 0, nil, err
}
@@ -67,7 +67,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
CheckSearchable: true,
@@ -75,8 +75,8 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- t.FSContext().SetWorkingDirectoryVFS2(vd)
- vd.DecRef()
+ t.FSContext().SetWorkingDirectoryVFS2(t, vd)
+ vd.DecRef(t)
return 0, nil, nil
}
@@ -88,7 +88,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
CheckSearchable: true,
@@ -96,8 +96,8 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- t.FSContext().SetWorkingDirectoryVFS2(vd)
- vd.DecRef()
+ t.FSContext().SetWorkingDirectoryVFS2(t, vd)
+ vd.DecRef(t)
return 0, nil, nil
}
@@ -117,7 +117,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
CheckSearchable: true,
@@ -125,7 +125,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- t.FSContext().SetRootDirectoryVFS2(vd)
- vd.DecRef()
+ t.FSContext().SetRootDirectoryVFS2(t, vd)
+ vd.DecRef(t)
return 0, nil, nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index c7c7bf7ce..5517595b5 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -44,7 +44,7 @@ func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (ui
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
cb := getGetdentsCallback(t, addr, size, isGetdents64)
err := file.IterDirents(t, cb)
diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go
index 5d98134a5..11753d8e5 100644
--- a/pkg/sentry/syscalls/linux/vfs2/inotify.go
+++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go
@@ -35,7 +35,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if err != nil {
return 0, nil, err
}
- defer ino.DecRef()
+ defer ino.DecRef(t)
fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{
CloseOnExec: flags&linux.IN_CLOEXEC != 0,
@@ -66,7 +66,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription,
ino, ok := f.Impl().(*vfs.Inotify)
if !ok {
// Not an inotify fd.
- f.DecRef()
+ f.DecRef(t)
return nil, nil, syserror.EINVAL
}
@@ -96,7 +96,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern
if err != nil {
return 0, nil, err
}
- defer f.DecRef()
+ defer f.DecRef(t)
path, err := copyInPath(t, addr)
if err != nil {
@@ -109,12 +109,12 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{})
if err != nil {
return 0, nil, err
}
- defer d.DecRef()
+ defer d.DecRef(t)
fd, err = ino.AddWatch(d.Dentry(), mask)
if err != nil {
@@ -132,6 +132,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if err != nil {
return 0, nil, err
}
- defer f.DecRef()
- return 0, nil, ino.RmWatch(wd)
+ defer f.DecRef(t)
+ return 0, nil, ino.RmWatch(t, wd)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index fd6ab94b2..38778a388 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -29,7 +29,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Handle ioctls that apply to all FDs.
switch args[1].Int() {
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
index bf19028c4..b910b5a74 100644
--- a/pkg/sentry/syscalls/linux/vfs2/lock.go
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -32,7 +32,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// flock(2): EBADF fd is not an open file descriptor.
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
nonblocking := operation&linux.LOCK_NB != 0
operation &^= linux.LOCK_NB
diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go
index bbe248d17..519583e4e 100644
--- a/pkg/sentry/syscalls/linux/vfs2/memfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go
@@ -47,7 +47,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
shmMount := t.Kernel().ShmMount()
- file, err := tmpfs.NewMemfd(shmMount, t.Credentials(), allowSeals, memfdPrefix+name)
+ file, err := tmpfs.NewMemfd(t, t.Credentials(), shmMount, allowSeals, memfdPrefix+name)
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
index 60a43f0a0..dc05c2994 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mmap.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -61,7 +61,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
}
defer func() {
if opts.MappingIdentity != nil {
- opts.MappingIdentity.DecRef()
+ opts.MappingIdentity.DecRef(t)
}
}()
@@ -71,7 +71,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// mmap unconditionally requires that the FD is readable.
if !file.IsReadable() {
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index ea337de7c..4bd5c7ca2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -108,7 +108,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- defer target.Release()
+ defer target.Release(t)
return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts)
}
@@ -140,7 +140,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
opts := vfs.UmountOptions{
Flags: uint32(flags),
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
index 97da6c647..90a511d9a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/path.go
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -42,7 +42,7 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA
haveStartRef := false
if !path.Absolute {
if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
- root.DecRef()
+ root.DecRef(t)
return taskPathOperation{}, syserror.ENOENT
}
if dirfd == linux.AT_FDCWD {
@@ -51,13 +51,13 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA
} else {
dirfile := t.GetFileVFS2(dirfd)
if dirfile == nil {
- root.DecRef()
+ root.DecRef(t)
return taskPathOperation{}, syserror.EBADF
}
start = dirfile.VirtualDentry()
start.IncRef()
haveStartRef = true
- dirfile.DecRef()
+ dirfile.DecRef(t)
}
}
return taskPathOperation{
@@ -71,10 +71,10 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA
}, nil
}
-func (tpop *taskPathOperation) Release() {
- tpop.pop.Root.DecRef()
+func (tpop *taskPathOperation) Release(t *kernel.Task) {
+ tpop.pop.Root.DecRef(t)
if tpop.haveStartRef {
- tpop.pop.Start.DecRef()
+ tpop.pop.Start.DecRef(t)
tpop.haveStartRef = false
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index 4a01e4209..9b4848d9e 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -42,8 +42,8 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
return syserror.EINVAL
}
r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(t)
+ defer w.DecRef(t)
fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -54,7 +54,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
if _, err := t.CopyOut(addr, fds); err != nil {
for _, fd := range fds {
if _, file := t.FDTable().Remove(fd); file != nil {
- file.DecRef()
+ file.DecRef(t)
}
}
return err
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index ff1b25d7b..7b9d5e18a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -73,7 +73,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan
}
if ch == nil {
- defer file.DecRef()
+ defer file.DecRef(t)
} else {
state.file = file
state.waiter, _ = waiter.NewChannelEntry(ch)
@@ -85,11 +85,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan
}
// releaseState releases all the pollState in "state".
-func releaseState(state []pollState) {
+func releaseState(t *kernel.Task, state []pollState) {
for i := range state {
if state[i].file != nil {
state[i].file.EventUnregister(&state[i].waiter)
- state[i].file.DecRef()
+ state[i].file.DecRef(t)
}
}
}
@@ -110,7 +110,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// result, we stop registering for events but still go through all files
// to get their ready masks.
state := make([]pollState, len(pfd))
- defer releaseState(state)
+ defer releaseState(t, state)
n := uintptr(0)
for i := range pfd {
initReadiness(t, &pfd[i], &state[i], ch)
@@ -269,7 +269,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
if file == nil {
return 0, syserror.EBADF
}
- file.DecRef()
+ file.DecRef(t)
var mask int16
if (rV & m) != 0 {
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index cd25597a7..a905dae0a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -44,7 +44,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the size is legitimate.
si := int(size)
@@ -75,7 +75,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Get the destination of the read.
dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
@@ -94,7 +94,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
n, err := file.Read(t, dst, opts)
if err != syserror.ErrWouldBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return n, err
}
@@ -102,7 +102,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return n, err
}
@@ -135,7 +135,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
file.EventUnregister(&w)
if total > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return total, err
}
@@ -151,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
@@ -188,7 +188,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < 0 {
@@ -226,7 +226,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < -1 {
@@ -258,7 +258,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
n, err := file.PRead(t, dst, offset, opts)
if err != syserror.ErrWouldBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return n, err
}
@@ -266,7 +266,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return n, err
}
@@ -299,7 +299,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
file.EventUnregister(&w)
if total > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return total, err
}
@@ -314,7 +314,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the size is legitimate.
si := int(size)
@@ -345,7 +345,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Get the source of the write.
src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
@@ -364,7 +364,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
n, err := file.Write(t, src, opts)
if err != syserror.ErrWouldBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
}
return n, err
}
@@ -372,7 +372,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
}
return n, err
}
@@ -405,7 +405,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
file.EventUnregister(&w)
if total > 0 {
- file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
}
return total, err
}
@@ -421,7 +421,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
@@ -458,7 +458,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < 0 {
@@ -496,7 +496,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < -1 {
@@ -528,7 +528,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
n, err := file.PWrite(t, src, offset, opts)
if err != syserror.ErrWouldBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
}
return n, err
}
@@ -536,7 +536,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
if n > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return n, err
}
@@ -569,7 +569,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
file.EventUnregister(&w)
if total > 0 {
- file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
}
return total, err
}
@@ -601,7 +601,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
newoff, err := file.Seek(t, offset, whence)
return uintptr(newoff), nil, err
@@ -617,7 +617,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.IsReadable() {
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 25cdb7a55..5e6eb13ba 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -66,7 +66,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
@@ -151,7 +151,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
var opts vfs.SetStatOptions
if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
@@ -197,7 +197,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
if !file.IsWritable() {
return 0, nil, syserror.EINVAL
@@ -224,7 +224,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
if !file.IsWritable() {
return 0, nil, syserror.EBADF
@@ -258,7 +258,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
return 0, nil, nil
}
@@ -438,7 +438,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op
func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
root := t.FSContext().RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
start := root
if !path.Absolute {
if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
@@ -446,7 +446,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
- defer start.DecRef()
+ defer start.DecRef(t)
} else {
dirfile := t.GetFileVFS2(dirfd)
if dirfile == nil {
@@ -457,13 +457,13 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa
// VirtualFilesystem.SetStatAt(), since the former may be able
// to use opened file state to expedite the SetStat.
err := dirfile.SetStat(t, *opts)
- dirfile.DecRef()
+ dirfile.DecRef(t)
return err
}
start = dirfile.VirtualDentry()
start.IncRef()
- defer start.DecRef()
- dirfile.DecRef()
+ defer start.DecRef(t)
+ dirfile.DecRef(t)
}
}
return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go
index 623992f6f..b89f34cdb 100644
--- a/pkg/sentry/syscalls/linux/vfs2/signal.go
+++ b/pkg/sentry/syscalls/linux/vfs2/signal.go
@@ -45,7 +45,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Is this a signalfd?
if sfd, ok := file.Impl().(*signalfd.SignalFileDescription); ok {
@@ -68,7 +68,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Create a new descriptor.
fd, err = t.NewFDFromVFS2(0, file, kernel.FDFlags{
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 8096a8f9c..4a68c64f3 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -196,7 +196,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if e != nil {
return 0, nil, e.ToError()
}
- defer s.DecRef()
+ defer s.DecRef(t)
if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil {
return 0, nil, err
@@ -230,8 +230,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
// Adding to the FD table will cause an extra reference to be acquired.
- defer s1.DecRef()
- defer s2.DecRef()
+ defer s1.DecRef(t)
+ defer s2.DecRef(t)
nonblocking := uint32(stype & linux.SOCK_NONBLOCK)
if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
@@ -253,7 +253,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if _, err := t.CopyOut(addr, fds); err != nil {
for _, fd := range fds {
if _, file := t.FDTable().Remove(fd); file != nil {
- file.DecRef()
+ file.DecRef(t)
}
}
return 0, nil, err
@@ -273,7 +273,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -304,7 +304,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -363,7 +363,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -390,7 +390,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -419,7 +419,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -450,7 +450,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -532,7 +532,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -570,7 +570,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -598,7 +598,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -631,7 +631,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -684,7 +684,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -778,7 +778,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
- cms.Release()
+ cms.Release(t)
}
if int(msg.Flags) != mflags {
@@ -798,7 +798,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
- defer cms.Release()
+ defer cms.Release(t)
controlData := make([]byte, 0, msg.ControlLen)
controlData = control.PackControlMessages(t, cms, controlData)
@@ -854,7 +854,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -883,7 +883,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
}
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
- cm.Release()
+ cm.Release(t)
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
@@ -927,7 +927,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -965,7 +965,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
@@ -1069,7 +1069,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
if err != nil {
- controlMessages.Release()
+ controlMessages.Release(t)
}
return uintptr(n), err
}
@@ -1087,7 +1087,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 63ab11f8c..16f59fce9 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -53,12 +53,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
outFile := t.GetFileVFS2(outFD)
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
// Check that both files support the required directionality.
if !inFile.IsReadable() || !outFile.IsWritable() {
@@ -175,7 +175,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// On Linux, inotify behavior is not very consistent with splice(2). We try
// our best to emulate Linux for very basic calls to splice, where for some
// reason, events are generated for output files, but not input files.
- outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
return uintptr(n), nil, nil
}
@@ -203,12 +203,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
outFile := t.GetFileVFS2(outFD)
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
// Check that both files support the required directionality.
if !inFile.IsReadable() || !outFile.IsWritable() {
@@ -251,7 +251,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if n == 0 {
return 0, nil, err
}
- outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
return uintptr(n), nil, nil
}
@@ -266,7 +266,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
if !inFile.IsReadable() {
return 0, nil, syserror.EBADF
}
@@ -275,7 +275,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
if !outFile.IsWritable() {
return 0, nil, syserror.EBADF
}
@@ -419,8 +419,8 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
- inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
- outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
return uintptr(n), nil, nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index bb1d5cac4..0f5d5189c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -65,7 +65,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
root := t.FSContext().RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
start := root
if !path.Absolute {
if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
@@ -73,7 +73,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
- defer start.DecRef()
+ defer start.DecRef(t)
} else {
dirfile := t.GetFileVFS2(dirfd)
if dirfile == nil {
@@ -85,7 +85,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
// former may be able to use opened file state to expedite the
// Stat.
statx, err := dirfile.Stat(t, opts)
- dirfile.DecRef()
+ dirfile.DecRef(t)
if err != nil {
return err
}
@@ -96,8 +96,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
start = dirfile.VirtualDentry()
start.IncRef()
- defer start.DecRef()
- dirfile.DecRef()
+ defer start.DecRef(t)
+ dirfile.DecRef(t)
}
}
@@ -132,7 +132,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
statx, err := file.Stat(t, vfs.StatOptions{
Mask: linux.STATX_BASIC_STATS,
@@ -177,7 +177,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
root := t.FSContext().RootDirectoryVFS2()
- defer root.DecRef()
+ defer root.DecRef(t)
start := root
if !path.Absolute {
if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
@@ -185,7 +185,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
- defer start.DecRef()
+ defer start.DecRef(t)
} else {
dirfile := t.GetFileVFS2(dirfd)
if dirfile == nil {
@@ -197,7 +197,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// former may be able to use opened file state to expedite the
// Stat.
statx, err := dirfile.Stat(t, opts)
- dirfile.DecRef()
+ dirfile.DecRef(t)
if err != nil {
return 0, nil, err
}
@@ -207,8 +207,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
start = dirfile.VirtualDentry()
start.IncRef()
- defer start.DecRef()
- dirfile.DecRef()
+ defer start.DecRef(t)
+ dirfile.DecRef(t)
}
}
@@ -282,7 +282,7 @@ func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) err
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
// access(2) and faccessat(2) check permissions using real
// UID/GID, not effective UID/GID.
@@ -328,7 +328,7 @@ func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, siz
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
if err != nil {
@@ -358,7 +358,7 @@ func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
if err != nil {
@@ -377,7 +377,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
index 0d0ebf46a..a6491ac37 100644
--- a/pkg/sentry/syscalls/linux/vfs2/sync.go
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -34,7 +34,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, file.SyncFS(t)
}
@@ -47,7 +47,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, file.Sync(t)
}
@@ -77,7 +77,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support
// is a full-file sync, i.e. fsync(2). As a result, there are severe
diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
index 5ac79bc09..7a26890ef 100644
--- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
@@ -50,11 +50,11 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.EINVAL
}
vfsObj := t.Kernel().VFS()
- file, err := timerfd.New(vfsObj, clock, fileFlags)
+ file, err := timerfd.New(t, vfsObj, clock, fileFlags)
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
})
@@ -79,7 +79,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
if !ok {
@@ -113,7 +113,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
if !ok {
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
index af455d5c1..ef99246ed 100644
--- a/pkg/sentry/syscalls/linux/vfs2/xattr.go
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -49,7 +49,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
if err != nil {
@@ -72,7 +72,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
names, err := file.Listxattr(t, uint64(size))
if err != nil {
@@ -109,7 +109,7 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
if err != nil {
return 0, nil, err
}
- defer tpop.Release()
+ defer tpop.Release(t)
name, err := copyInXattrName(t, nameAddr)
if err != nil {
@@ -141,7 +141,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
name, err := copyInXattrName(t, nameAddr)
if err != nil {
@@ -188,7 +188,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
name, err := copyInXattrName(t, nameAddr)
if err != nil {
@@ -222,7 +222,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
name, err := copyInXattrName(t, nameAddr)
if err != nil {
@@ -262,7 +262,7 @@ func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSy
if err != nil {
return err
}
- defer tpop.Release()
+ defer tpop.Release(t)
name, err := copyInXattrName(t, nameAddr)
if err != nil {
@@ -281,7 +281,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
name, err := copyInXattrName(t, nameAddr)
if err != nil {
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 641e3e502..5a0e3e6b5 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -82,7 +82,7 @@ type anonDentry struct {
}
// Release implements FilesystemImpl.Release.
-func (fs *anonFilesystem) Release() {
+func (fs *anonFilesystem) Release(ctx context.Context) {
}
// Sync implements FilesystemImpl.Sync.
@@ -294,7 +294,7 @@ func (d *anonDentry) TryIncRef() bool {
}
// DecRef implements DentryImpl.DecRef.
-func (d *anonDentry) DecRef() {
+func (d *anonDentry) DecRef(ctx context.Context) {
// no-op
}
@@ -303,7 +303,7 @@ func (d *anonDentry) DecRef() {
// Although Linux technically supports inotify on pseudo filesystems (inotify
// is implemented at the vfs layer), it is not particularly useful. It is left
// unimplemented until someone actually needs it.
-func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {}
+func (d *anonDentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) {}
// Watches implements DentryImpl.Watches.
func (d *anonDentry) Watches() *Watches {
@@ -311,4 +311,4 @@ func (d *anonDentry) Watches() *Watches {
}
// OnZeroWatches implements Dentry.OnZeroWatches.
-func (d *anonDentry) OnZeroWatches() {}
+func (d *anonDentry) OnZeroWatches(context.Context) {}
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index cea3e6955..bc7ea93ea 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -17,6 +17,7 @@ package vfs
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -102,7 +103,7 @@ type DentryImpl interface {
TryIncRef() bool
// DecRef decrements the Dentry's reference count.
- DecRef()
+ DecRef(ctx context.Context)
// InotifyWithParent notifies all watches on the targets represented by this
// dentry and its parent. The parent's watches are notified first, followed
@@ -113,7 +114,7 @@ type DentryImpl interface {
//
// Note that the events may not actually propagate up to the user, depending
// on the event masks.
- InotifyWithParent(events, cookie uint32, et EventType)
+ InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType)
// Watches returns the set of inotify watches for the file corresponding to
// the Dentry. Dentries that are hard links to the same underlying file
@@ -135,7 +136,7 @@ type DentryImpl interface {
// The caller does not need to hold a reference on the dentry. OnZeroWatches
// may acquire inotify locks, so to prevent deadlock, no inotify locks should
// be held by the caller.
- OnZeroWatches()
+ OnZeroWatches(ctx context.Context)
}
// IncRef increments d's reference count.
@@ -150,8 +151,8 @@ func (d *Dentry) TryIncRef() bool {
}
// DecRef decrements d's reference count.
-func (d *Dentry) DecRef() {
- d.impl.DecRef()
+func (d *Dentry) DecRef(ctx context.Context) {
+ d.impl.DecRef(ctx)
}
// IsDead returns true if d has been deleted or invalidated by its owning
@@ -168,8 +169,8 @@ func (d *Dentry) isMounted() bool {
// InotifyWithParent notifies all watches on the targets represented by d and
// its parent of events.
-func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) {
- d.impl.InotifyWithParent(events, cookie, et)
+func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) {
+ d.impl.InotifyWithParent(ctx, events, cookie, et)
}
// Watches returns the set of inotify watches associated with d.
@@ -182,8 +183,8 @@ func (d *Dentry) Watches() *Watches {
// OnZeroWatches performs cleanup tasks whenever the number of watches on a
// dentry drops to zero.
-func (d *Dentry) OnZeroWatches() {
- d.impl.OnZeroWatches()
+func (d *Dentry) OnZeroWatches(ctx context.Context) {
+ d.impl.OnZeroWatches(ctx)
}
// The following functions are exported so that filesystem implementations can
@@ -214,11 +215,11 @@ func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) {
// CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion
// succeeds.
-func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) {
+func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) {
d.dead = true
d.mu.Unlock()
if d.isMounted() {
- vfs.forgetDeadMountpoint(d)
+ vfs.forgetDeadMountpoint(ctx, d)
}
}
@@ -226,12 +227,12 @@ func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) {
// did for reasons outside of VFS' control (e.g. d represents the local state
// of a file on a remote filesystem on which the file has already been
// deleted).
-func (vfs *VirtualFilesystem) InvalidateDentry(d *Dentry) {
+func (vfs *VirtualFilesystem) InvalidateDentry(ctx context.Context, d *Dentry) {
d.mu.Lock()
d.dead = true
d.mu.Unlock()
if d.isMounted() {
- vfs.forgetDeadMountpoint(d)
+ vfs.forgetDeadMountpoint(ctx, d)
}
}
@@ -278,13 +279,13 @@ func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
// that was replaced by from.
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
-func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, to *Dentry) {
+func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) {
from.mu.Unlock()
if to != nil {
to.dead = true
to.mu.Unlock()
if to.isMounted() {
- vfs.forgetDeadMountpoint(to)
+ vfs.forgetDeadMountpoint(ctx, to)
}
}
}
@@ -303,7 +304,7 @@ func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) {
//
// forgetDeadMountpoint is analogous to Linux's
// fs/namespace.c:__detach_mounts().
-func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) {
+func (vfs *VirtualFilesystem) forgetDeadMountpoint(ctx context.Context, d *Dentry) {
var (
vdsToDecRef []VirtualDentry
mountsToDecRef []*Mount
@@ -316,9 +317,9 @@ func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) {
vfs.mounts.seq.EndWrite()
vfs.mountMu.Unlock()
for _, vd := range vdsToDecRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
for _, mnt := range mountsToDecRef {
- mnt.DecRef()
+ mnt.DecRef(ctx)
}
}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 5b009b928..1b5af9f73 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -93,9 +93,9 @@ type epollInterest struct {
// NewEpollInstanceFD returns a FileDescription representing a new epoll
// instance. A reference is taken on the returned FileDescription.
-func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) {
+func (vfs *VirtualFilesystem) NewEpollInstanceFD(ctx context.Context) (*FileDescription, error) {
vd := vfs.NewAnonVirtualDentry("[eventpoll]")
- defer vd.DecRef()
+ defer vd.DecRef(ctx)
ep := &EpollInstance{
interest: make(map[epollInterestKey]*epollInterest),
}
@@ -110,7 +110,7 @@ func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) {
}
// Release implements FileDescriptionImpl.Release.
-func (ep *EpollInstance) Release() {
+func (ep *EpollInstance) Release(ctx context.Context) {
// Unregister all polled fds.
ep.interestMu.Lock()
defer ep.interestMu.Unlock()
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 93861fb4a..576ab3920 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -171,7 +171,7 @@ func (fd *FileDescription) TryIncRef() bool {
}
// DecRef decrements fd's reference count.
-func (fd *FileDescription) DecRef() {
+func (fd *FileDescription) DecRef(ctx context.Context) {
if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 {
// Unregister fd from all epoll instances.
fd.epollMu.Lock()
@@ -196,11 +196,11 @@ func (fd *FileDescription) DecRef() {
}
// Release implementation resources.
- fd.impl.Release()
+ fd.impl.Release(ctx)
if fd.writable {
fd.vd.mount.EndWrite()
}
- fd.vd.DecRef()
+ fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
// TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
@@ -335,7 +335,7 @@ func (fd *FileDescription) Impl() FileDescriptionImpl {
type FileDescriptionImpl interface {
// Release is called when the associated FileDescription reaches zero
// references.
- Release()
+ Release(ctx context.Context)
// OnClose is called when a file descriptor representing the
// FileDescription is closed. Note that returning a non-nil error does not
@@ -526,7 +526,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St
Start: fd.vd,
})
stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return stat, err
}
return fd.impl.Stat(ctx, opts)
@@ -541,7 +541,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return err
}
return fd.impl.SetStat(ctx, opts)
@@ -557,7 +557,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
Start: fd.vd,
})
statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return statfs, err
}
return fd.impl.StatFS(ctx)
@@ -674,7 +674,7 @@ func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string
Start: fd.vd,
})
names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return names, err
}
names, err := fd.impl.Listxattr(ctx, size)
@@ -703,7 +703,7 @@ func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions)
Start: fd.vd,
})
val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return val, err
}
return fd.impl.Getxattr(ctx, *opts)
@@ -719,7 +719,7 @@ func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions)
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return err
}
return fd.impl.Setxattr(ctx, *opts)
@@ -735,7 +735,7 @@ func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name)
- vfsObj.putResolvingPath(rp)
+ vfsObj.putResolvingPath(ctx, rp)
return err
}
return fd.impl.Removexattr(ctx, name)
@@ -752,7 +752,7 @@ func (fd *FileDescription) MappedName(ctx context.Context) string {
vfsroot := RootFromContext(ctx)
s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd)
if vfsroot.Ok() {
- vfsroot.DecRef()
+ vfsroot.DecRef(ctx)
}
return s
}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 3b7e1c273..1cd607c0a 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -80,9 +80,9 @@ type testFD struct {
data DynamicBytesSource
}
-func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription {
+func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription {
vd := vfsObj.NewAnonVirtualDentry("genCountFD")
- defer vd.DecRef()
+ defer vd.DecRef(ctx)
var fd testFD
fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
fd.DynamicBytesFileDescriptionImpl.SetDataSource(data)
@@ -90,7 +90,7 @@ func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesS
}
// Release implements FileDescriptionImpl.Release.
-func (fd *testFD) Release() {
+func (fd *testFD) Release(context.Context) {
}
// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags.
@@ -109,11 +109,11 @@ func TestGenCountFD(t *testing.T) {
ctx := contexttest.Context(t)
vfsObj := &VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
- fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{})
- defer fd.DecRef()
+ fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &genCount{})
+ defer fd.DecRef(ctx)
// The first read causes Generate to be called to fill the FD's buffer.
buf := make([]byte, 2)
@@ -167,11 +167,11 @@ func TestWritable(t *testing.T) {
ctx := contexttest.Context(t)
vfsObj := &VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
+ if err := vfsObj.Init(ctx); err != nil {
t.Fatalf("VFS init: %v", err)
}
- fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"})
- defer fd.DecRef()
+ fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &storeData{data: "init"})
+ defer fd.DecRef(ctx)
buf := make([]byte, 10)
ioseq := usermem.BytesIOSequence(buf)
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 6bb9ca180..df3758fd1 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -100,12 +100,12 @@ func (fs *Filesystem) TryIncRef() bool {
}
// DecRef decrements fs' reference count.
-func (fs *Filesystem) DecRef() {
+func (fs *Filesystem) DecRef(ctx context.Context) {
if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 {
fs.vfs.filesystemsMu.Lock()
delete(fs.vfs.filesystems, fs)
fs.vfs.filesystemsMu.Unlock()
- fs.impl.Release()
+ fs.impl.Release(ctx)
} else if refs < 0 {
panic("Filesystem.decRef() called without holding a reference")
}
@@ -149,7 +149,7 @@ func (fs *Filesystem) DecRef() {
type FilesystemImpl interface {
// Release is called when the associated Filesystem reaches zero
// references.
- Release()
+ Release(ctx context.Context)
// Sync "causes all pending modifications to filesystem metadata and cached
// file data to be written to the underlying [filesystem]", as by syncfs(2).
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 167b731ac..aff220a61 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -100,7 +100,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32)
id := uniqueid.GlobalFromContext(ctx)
vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id))
- defer vd.DecRef()
+ defer vd.DecRef(ctx)
fd := &Inotify{
id: id,
scratch: make([]byte, inotifyEventBaseSize),
@@ -118,7 +118,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32)
// Release implements FileDescriptionImpl.Release. Release removes all
// watches and frees all resources for an inotify instance.
-func (i *Inotify) Release() {
+func (i *Inotify) Release(ctx context.Context) {
var ds []*Dentry
// We need to hold i.mu to avoid a race with concurrent calls to
@@ -144,7 +144,7 @@ func (i *Inotify) Release() {
i.mu.Unlock()
for _, d := range ds {
- d.OnZeroWatches()
+ d.OnZeroWatches(ctx)
}
}
@@ -350,7 +350,7 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) {
// RmWatch looks up an inotify watch for the given 'wd' and configures the
// target to stop sending events to this inotify instance.
-func (i *Inotify) RmWatch(wd int32) error {
+func (i *Inotify) RmWatch(ctx context.Context, wd int32) error {
i.mu.Lock()
// Find the watch we were asked to removed.
@@ -374,7 +374,7 @@ func (i *Inotify) RmWatch(wd int32) error {
i.mu.Unlock()
if remaining == 0 {
- w.target.OnZeroWatches()
+ w.target.OnZeroWatches(ctx)
}
// Generate the event for the removal.
@@ -462,7 +462,7 @@ func (w *Watches) Remove(id uint64) {
// Notify queues a new event with watches in this set. Watches with
// IN_EXCL_UNLINK are skipped if the event is coming from a child that has been
// unlinked.
-func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) {
+func (w *Watches) Notify(ctx context.Context, name string, events, cookie uint32, et EventType, unlinked bool) {
var hasExpired bool
w.mu.RLock()
for _, watch := range w.ws {
@@ -476,13 +476,13 @@ func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlin
w.mu.RUnlock()
if hasExpired {
- w.cleanupExpiredWatches()
+ w.cleanupExpiredWatches(ctx)
}
}
// This function is relatively expensive and should only be called where there
// are expired watches.
-func (w *Watches) cleanupExpiredWatches() {
+func (w *Watches) cleanupExpiredWatches(ctx context.Context) {
// Because of lock ordering, we cannot acquire Inotify.mu for each watch
// owner while holding w.mu. As a result, store expired watches locally
// before removing.
@@ -495,15 +495,15 @@ func (w *Watches) cleanupExpiredWatches() {
}
w.mu.RUnlock()
for _, watch := range toRemove {
- watch.owner.RmWatch(watch.wd)
+ watch.owner.RmWatch(ctx, watch.wd)
}
}
// HandleDeletion is called when the watch target is destroyed. Clear the
// watch set, detach watches from the inotify instances they belong to, and
// generate the appropriate events.
-func (w *Watches) HandleDeletion() {
- w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */)
+func (w *Watches) HandleDeletion(ctx context.Context) {
+ w.Notify(ctx, "", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */)
// As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for
// the owner of each watch being deleted. Instead, atomically store the
@@ -744,12 +744,12 @@ func InotifyEventFromStatMask(mask uint32) uint32 {
// InotifyRemoveChild sends the appriopriate notifications to the watch sets of
// the child being removed and its parent. Note that unlike most pairs of
// parent/child notifications, the child is notified first in this case.
-func InotifyRemoveChild(self, parent *Watches, name string) {
+func InotifyRemoveChild(ctx context.Context, self, parent *Watches, name string) {
if self != nil {
- self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */)
+ self.Notify(ctx, "", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */)
}
if parent != nil {
- parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */)
+ parent.Notify(ctx, name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */)
}
}
@@ -762,13 +762,13 @@ func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches,
}
cookie := uniqueid.InotifyCookie(ctx)
if oldParent != nil {
- oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */)
+ oldParent.Notify(ctx, oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */)
}
if newParent != nil {
- newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */)
+ newParent.Notify(ctx, newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */)
}
// Somewhat surprisingly, self move events do not have a cookie.
if renamed != nil {
- renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */)
+ renamed.Notify(ctx, "", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */)
}
}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 32f901bd8..d1d29d0cd 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -200,8 +200,8 @@ func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth
if err != nil {
return nil, err
}
- defer root.DecRef()
- defer fs.DecRef()
+ defer root.DecRef(ctx)
+ defer fs.DecRef(ctx)
return vfs.NewDisconnectedMount(fs, root, opts)
}
@@ -221,7 +221,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
if vd.dentry.dead {
vd.dentry.mu.Unlock()
vfs.mountMu.Unlock()
- vd.DecRef()
+ vd.DecRef(ctx)
return syserror.ENOENT
}
// vd might have been mounted over between vfs.GetDentryAt() and
@@ -243,7 +243,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
// This can't fail since we're holding vfs.mountMu.
nextmnt.root.IncRef()
vd.dentry.mu.Unlock()
- vd.DecRef()
+ vd.DecRef(ctx)
vd = VirtualDentry{
mount: nextmnt,
dentry: nextmnt.root,
@@ -268,7 +268,7 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
if err != nil {
return err
}
- defer mnt.DecRef()
+ defer mnt.DecRef(ctx)
if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil {
return err
}
@@ -293,13 +293,13 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
if err != nil {
return err
}
- defer vd.DecRef()
+ defer vd.DecRef(ctx)
if vd.dentry != vd.mount.root {
return syserror.EINVAL
}
vfs.mountMu.Lock()
if mntns := MountNamespaceFromContext(ctx); mntns != nil {
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
if mntns != vd.mount.ns {
vfs.mountMu.Unlock()
return syserror.EINVAL
@@ -335,10 +335,10 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
vfs.mounts.seq.EndWrite()
vfs.mountMu.Unlock()
for _, vd := range vdsToDecRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
for _, mnt := range mountsToDecRef {
- mnt.DecRef()
+ mnt.DecRef(ctx)
}
return nil
}
@@ -479,7 +479,7 @@ func (mnt *Mount) IncRef() {
}
// DecRef decrements mnt's reference count.
-func (mnt *Mount) DecRef() {
+func (mnt *Mount) DecRef(ctx context.Context) {
refs := atomic.AddInt64(&mnt.refs, -1)
if refs&^math.MinInt64 == 0 { // mask out MSB
var vd VirtualDentry
@@ -490,10 +490,10 @@ func (mnt *Mount) DecRef() {
mnt.vfs.mounts.seq.EndWrite()
mnt.vfs.mountMu.Unlock()
}
- mnt.root.DecRef()
- mnt.fs.DecRef()
+ mnt.root.DecRef(ctx)
+ mnt.fs.DecRef(ctx)
if vd.Ok() {
- vd.DecRef()
+ vd.DecRef(ctx)
}
}
}
@@ -506,7 +506,7 @@ func (mntns *MountNamespace) IncRef() {
}
// DecRef decrements mntns' reference count.
-func (mntns *MountNamespace) DecRef() {
+func (mntns *MountNamespace) DecRef(ctx context.Context) {
vfs := mntns.root.fs.VirtualFilesystem()
if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 {
vfs.mountMu.Lock()
@@ -517,10 +517,10 @@ func (mntns *MountNamespace) DecRef() {
vfs.mounts.seq.EndWrite()
vfs.mountMu.Unlock()
for _, vd := range vdsToDecRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
for _, mnt := range mountsToDecRef {
- mnt.DecRef()
+ mnt.DecRef(ctx)
}
} else if refs < 0 {
panic("MountNamespace.DecRef() called without holding a reference")
@@ -534,7 +534,7 @@ func (mntns *MountNamespace) DecRef() {
// getMountAt is analogous to Linux's fs/namei.c:follow_mount().
//
// Preconditions: References are held on mnt and d.
-func (vfs *VirtualFilesystem) getMountAt(mnt *Mount, d *Dentry) *Mount {
+func (vfs *VirtualFilesystem) getMountAt(ctx context.Context, mnt *Mount, d *Dentry) *Mount {
// The first mount is special-cased:
//
// - The caller is assumed to have checked d.isMounted() already. (This
@@ -565,7 +565,7 @@ retryFirst:
// Raced with umount.
continue
}
- mnt.DecRef()
+ mnt.DecRef(ctx)
mnt = next
d = next.root
}
@@ -578,7 +578,7 @@ retryFirst:
//
// Preconditions: References are held on mnt and root. vfsroot is not (mnt,
// mnt.root).
-func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry {
+func (vfs *VirtualFilesystem) getMountpointAt(ctx context.Context, mnt *Mount, vfsroot VirtualDentry) VirtualDentry {
// The first mount is special-cased:
//
// - The caller must have already checked mnt against vfsroot.
@@ -602,12 +602,12 @@ retryFirst:
if !point.TryIncRef() {
// Since Mount holds a reference on Mount.key.point, this can only
// happen due to a racing change to Mount.key.
- parent.DecRef()
+ parent.DecRef(ctx)
goto retryFirst
}
if !vfs.mounts.seq.ReadOk(epoch) {
- point.DecRef()
- parent.DecRef()
+ point.DecRef(ctx)
+ parent.DecRef(ctx)
goto retryFirst
}
mnt = parent
@@ -635,16 +635,16 @@ retryFirst:
if !point.TryIncRef() {
// Since Mount holds a reference on Mount.key.point, this can
// only happen due to a racing change to Mount.key.
- parent.DecRef()
+ parent.DecRef(ctx)
goto retryNotFirst
}
if !vfs.mounts.seq.ReadOk(epoch) {
- point.DecRef()
- parent.DecRef()
+ point.DecRef(ctx)
+ parent.DecRef(ctx)
goto retryNotFirst
}
- d.DecRef()
- mnt.DecRef()
+ d.DecRef(ctx)
+ mnt.DecRef(ctx)
mnt = parent
d = point
}
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
index cd78d66bc..e4da15009 100644
--- a/pkg/sentry/vfs/pathname.go
+++ b/pkg/sentry/vfs/pathname.go
@@ -47,7 +47,7 @@ func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot,
haveRef := false
defer func() {
if haveRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
}()
@@ -64,12 +64,12 @@ loop:
// of FilesystemImpl.PrependPath() may return nil instead.
break loop
}
- nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot)
if !nextVD.Ok() {
break loop
}
if haveRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
vd = nextVD
haveRef = true
@@ -101,7 +101,7 @@ func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd
haveRef := false
defer func() {
if haveRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
}()
loop:
@@ -112,12 +112,12 @@ loop:
if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
break loop
}
- nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot)
if !nextVD.Ok() {
return "", nil
}
if haveRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
vd = nextVD
haveRef = true
@@ -145,7 +145,7 @@ func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd
haveRef := false
defer func() {
if haveRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
}()
unreachable := false
@@ -157,13 +157,13 @@ loop:
if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
break loop
}
- nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot)
if !nextVD.Ok() {
unreachable = true
break loop
}
if haveRef {
- vd.DecRef()
+ vd.DecRef(ctx)
}
vd = nextVD
haveRef = true
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 9d047ff88..3304372d9 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sync"
@@ -136,31 +137,31 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat
return rp
}
-func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) {
+func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) {
rp.root = VirtualDentry{}
- rp.decRefStartAndMount()
+ rp.decRefStartAndMount(ctx)
rp.mount = nil
rp.start = nil
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
resolvingPathPool.Put(rp)
}
-func (rp *ResolvingPath) decRefStartAndMount() {
+func (rp *ResolvingPath) decRefStartAndMount(ctx context.Context) {
if rp.flags&rpflagsHaveStartRef != 0 {
- rp.start.DecRef()
+ rp.start.DecRef(ctx)
}
if rp.flags&rpflagsHaveMountRef != 0 {
- rp.mount.DecRef()
+ rp.mount.DecRef(ctx)
}
}
-func (rp *ResolvingPath) releaseErrorState() {
+func (rp *ResolvingPath) releaseErrorState(ctx context.Context) {
if rp.nextStart != nil {
- rp.nextStart.DecRef()
+ rp.nextStart.DecRef(ctx)
rp.nextStart = nil
}
if rp.nextMount != nil {
- rp.nextMount.DecRef()
+ rp.nextMount.DecRef(ctx)
rp.nextMount = nil
}
}
@@ -236,13 +237,13 @@ func (rp *ResolvingPath) Advance() {
// Restart resets the stream of path components represented by rp to its state
// on entry to the current FilesystemImpl method.
-func (rp *ResolvingPath) Restart() {
+func (rp *ResolvingPath) Restart(ctx context.Context) {
rp.pit = rp.origParts[rp.numOrigParts-1]
rp.mustBeDir = rp.mustBeDirOrig
rp.symlinks = rp.symlinksOrig
rp.curPart = rp.numOrigParts - 1
copy(rp.parts[:], rp.origParts[:rp.numOrigParts])
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
}
func (rp *ResolvingPath) relpathCommit() {
@@ -260,13 +261,13 @@ func (rp *ResolvingPath) relpathCommit() {
// Mount, CheckRoot returns (unspecified, non-nil error). Otherwise, path
// resolution should resolve d's parent normally, and CheckRoot returns (false,
// nil).
-func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) {
+func (rp *ResolvingPath) CheckRoot(ctx context.Context, d *Dentry) (bool, error) {
if d == rp.root.dentry && rp.mount == rp.root.mount {
// At contextual VFS root (due to e.g. chroot(2)).
return true, nil
} else if d == rp.mount.root {
// At mount root ...
- vd := rp.vfs.getMountpointAt(rp.mount, rp.root)
+ vd := rp.vfs.getMountpointAt(ctx, rp.mount, rp.root)
if vd.Ok() {
// ... of non-root mount.
rp.nextMount = vd.mount
@@ -283,11 +284,11 @@ func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) {
// to d. If d is a mount point, such that path resolution should switch to
// another Mount, CheckMount returns a non-nil error. Otherwise, CheckMount
// returns nil.
-func (rp *ResolvingPath) CheckMount(d *Dentry) error {
+func (rp *ResolvingPath) CheckMount(ctx context.Context, d *Dentry) error {
if !d.isMounted() {
return nil
}
- if mnt := rp.vfs.getMountAt(rp.mount, d); mnt != nil {
+ if mnt := rp.vfs.getMountAt(ctx, rp.mount, d); mnt != nil {
rp.nextMount = mnt
return resolveMountPointError{}
}
@@ -389,11 +390,11 @@ func (rp *ResolvingPath) HandleJump(target VirtualDentry) error {
return resolveMountRootOrJumpError{}
}
-func (rp *ResolvingPath) handleError(err error) bool {
+func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
switch err.(type) {
case resolveMountRootOrJumpError:
// Switch to the new Mount. We hold references on the Mount and Dentry.
- rp.decRefStartAndMount()
+ rp.decRefStartAndMount(ctx)
rp.mount = rp.nextMount
rp.start = rp.nextStart
rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef
@@ -412,7 +413,7 @@ func (rp *ResolvingPath) handleError(err error) bool {
case resolveMountPointError:
// Switch to the new Mount. We hold a reference on the Mount, but
// borrow the reference on the mount root from the Mount.
- rp.decRefStartAndMount()
+ rp.decRefStartAndMount(ctx)
rp.mount = rp.nextMount
rp.start = rp.nextMount.root
rp.flags = rp.flags&^rpflagsHaveStartRef | rpflagsHaveMountRef
@@ -423,12 +424,12 @@ func (rp *ResolvingPath) handleError(err error) bool {
// path.
rp.relpathCommit()
// Restart path resolution on the new Mount.
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
return true
case resolveAbsSymlinkError:
// Switch to the new Mount. References are borrowed from rp.root.
- rp.decRefStartAndMount()
+ rp.decRefStartAndMount(ctx)
rp.mount = rp.root.mount
rp.start = rp.root.dentry
rp.flags &^= rpflagsHaveMountRef | rpflagsHaveStartRef
@@ -440,7 +441,7 @@ func (rp *ResolvingPath) handleError(err error) bool {
// path, including the symlink target we just prepended.
rp.relpathCommit()
// Restart path resolution on the new Mount.
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
return true
default:
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 522e27475..9c2420683 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -122,7 +122,7 @@ type VirtualFilesystem struct {
}
// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
-func (vfs *VirtualFilesystem) Init() error {
+func (vfs *VirtualFilesystem) Init(ctx context.Context) error {
if vfs.mountpoints != nil {
panic("VFS already initialized")
}
@@ -145,7 +145,7 @@ func (vfs *VirtualFilesystem) Init() error {
devMinor: anonfsDevMinor,
}
anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs)
- defer anonfs.vfsfs.DecRef()
+ defer anonfs.vfsfs.DecRef(ctx)
anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
if err != nil {
// We should not be passing any MountOptions that would cause
@@ -192,11 +192,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -214,11 +214,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede
dentry: d,
}
rp.mount.IncRef()
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return vd, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return VirtualDentry{}, err
}
}
@@ -236,7 +236,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
rp.mount.IncRef()
name := rp.Component()
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return parentVD, name, nil
}
if checkInvariants {
@@ -244,8 +244,8 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return VirtualDentry{}, "", err
}
}
@@ -260,14 +260,14 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
}
if !newpop.Path.Begin.Ok() {
- oldVD.DecRef()
+ oldVD.DecRef(ctx)
if newpop.Path.Absolute {
return syserror.EEXIST
}
return syserror.ENOENT
}
if newpop.FollowFinalSymlink {
- oldVD.DecRef()
+ oldVD.DecRef(ctx)
ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink")
return syserror.EINVAL
}
@@ -276,8 +276,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
for {
err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
if err == nil {
- vfs.putResolvingPath(rp)
- oldVD.DecRef()
+ vfs.putResolvingPath(ctx, rp)
+ oldVD.DecRef(ctx)
return nil
}
if checkInvariants {
@@ -285,9 +285,9 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- oldVD.DecRef()
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ oldVD.DecRef(ctx)
return err
}
}
@@ -313,7 +313,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
if checkInvariants {
@@ -321,8 +321,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -346,7 +346,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
if checkInvariants {
@@ -354,8 +354,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -408,31 +408,31 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
for {
fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
if opts.FileExec {
if fd.Mount().Flags.NoExec {
- fd.DecRef()
+ fd.DecRef(ctx)
return nil, syserror.EACCES
}
// Only a regular file can be executed.
stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
if err != nil {
- fd.DecRef()
+ fd.DecRef(ctx)
return nil, err
}
if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
- fd.DecRef()
+ fd.DecRef(ctx)
return nil, syserror.EACCES
}
}
- fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent)
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_OPEN, 0, PathEvent)
return fd, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return nil, err
}
}
@@ -444,11 +444,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden
for {
target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return target, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return "", err
}
}
@@ -472,19 +472,19 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
return err
}
if oldName == "." || oldName == ".." {
- oldParentVD.DecRef()
+ oldParentVD.DecRef(ctx)
return syserror.EBUSY
}
if !newpop.Path.Begin.Ok() {
- oldParentVD.DecRef()
+ oldParentVD.DecRef(ctx)
if newpop.Path.Absolute {
return syserror.EBUSY
}
return syserror.ENOENT
}
if newpop.FollowFinalSymlink {
- oldParentVD.DecRef()
+ oldParentVD.DecRef(ctx)
ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink")
return syserror.EINVAL
}
@@ -497,8 +497,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
if err == nil {
- vfs.putResolvingPath(rp)
- oldParentVD.DecRef()
+ vfs.putResolvingPath(ctx, rp)
+ oldParentVD.DecRef(ctx)
return nil
}
if checkInvariants {
@@ -506,9 +506,9 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- oldParentVD.DecRef()
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ oldParentVD.DecRef(ctx)
return err
}
}
@@ -531,7 +531,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.RmdirAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
if checkInvariants {
@@ -539,8 +539,8 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -552,11 +552,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent
for {
err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -568,11 +568,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential
for {
stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return stat, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return linux.Statx{}, err
}
}
@@ -585,11 +585,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti
for {
statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return statfs, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return linux.Statfs{}, err
}
}
@@ -612,7 +612,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
for {
err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
if checkInvariants {
@@ -620,8 +620,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -644,7 +644,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
if checkInvariants {
@@ -652,8 +652,8 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -671,7 +671,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
for {
bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return bep, nil
}
if checkInvariants {
@@ -679,8 +679,8 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return nil, err
}
}
@@ -693,7 +693,7 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede
for {
names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return names, nil
}
if err == syserror.ENOTSUP {
@@ -701,11 +701,11 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
// subsystem to return security extended attributes, which by
// default don't exist.
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return nil, err
}
}
@@ -718,11 +718,11 @@ func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Creden
for {
val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return val, nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return "", err
}
}
@@ -735,11 +735,11 @@ func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Creden
for {
err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -751,11 +751,11 @@ func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Cre
for {
err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name)
if err == nil {
- vfs.putResolvingPath(rp)
+ vfs.putResolvingPath(ctx, rp)
return nil
}
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
return err
}
}
@@ -777,7 +777,7 @@ func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
retErr = err
}
- fs.DecRef()
+ fs.DecRef(ctx)
}
return retErr
}
@@ -831,9 +831,9 @@ func (vd VirtualDentry) IncRef() {
// DecRef decrements the reference counts on the Mount and Dentry represented
// by vd.
-func (vd VirtualDentry) DecRef() {
- vd.dentry.DecRef()
- vd.mount.DecRef()
+func (vd VirtualDentry) DecRef(ctx context.Context) {
+ vd.dentry.DecRef(ctx)
+ vd.mount.DecRef(ctx)
}
// Mount returns the Mount associated with vd. It does not take a reference on
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e0db6cf54..6c137f693 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -12,6 +12,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/refs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 04ae58e59..22b0a12bd 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -64,14 +65,14 @@ func (d *Device) beforeSave() {
}
// Release implements fs.FileOperations.Release.
-func (d *Device) Release() {
+func (d *Device) Release(ctx context.Context) {
d.mu.Lock()
defer d.mu.Unlock()
// Decrease refcount if there is an endpoint associated with this file.
if d.endpoint != nil {
d.endpoint.RemoveNotify(d.notifyHandle)
- d.endpoint.DecRef()
+ d.endpoint.DecRef(ctx)
d.endpoint = nil
}
}
@@ -341,8 +342,8 @@ type tunEndpoint struct {
}
// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
-func (e *tunEndpoint) DecRef() {
- e.DecRefWithDestructor(func() {
+func (e *tunEndpoint) DecRef(ctx context.Context) {
+ e.DecRefWithDestructor(ctx, func(context.Context) {
e.stack.RemoveNIC(e.nicID)
})
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 59639ba19..9dd5b0184 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -640,7 +640,7 @@ func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Confi
func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error {
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, m := range c.mounts {
log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options)
@@ -868,7 +868,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns
if err != nil {
return fmt.Errorf("can't find mount destination %q: %v", m.Destination, err)
}
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
if err := mns.Mount(ctx, dirent, inode); err != nil {
return fmt.Errorf("mount %q error: %v", m.Destination, err)
}
@@ -889,12 +889,12 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun
if err != nil {
return fmt.Errorf("can't find mount destination %q: %v", mount.Destination, err)
}
- defer target.DecRef()
+ defer target.DecRef(ctx)
// Take a ref on the inode that is about to be (re)-mounted.
source.root.IncRef()
if err := mns.Mount(ctx, target, source.root); err != nil {
- source.root.DecRef()
+ source.root.DecRef(ctx)
return fmt.Errorf("bind mount %q error: %v", mount.Destination, err)
}
@@ -997,12 +997,12 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.M
switch err {
case nil:
// Found '/tmp' in filesystem, check if it's empty.
- defer tmp.DecRef()
+ defer tmp.DecRef(ctx)
f, err := tmp.Inode.GetFile(ctx, tmp, fs.FileFlags{Read: true, Directory: true})
if err != nil {
return err
}
- defer f.DecRef()
+ defer f.DecRef(ctx)
serializer := &fs.CollectEntriesSerializer{}
if err := f.Readdir(ctx, serializer); err != nil {
return err
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 9cd9c5909..e0d077f5a 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -346,7 +346,7 @@ func New(args Args) (*Loader, error) {
if err != nil {
return nil, fmt.Errorf("failed to create hostfs filesystem: %v", err)
}
- defer hostFilesystem.DecRef()
+ defer hostFilesystem.DecRef(k.SupervisorContext())
hostMount, err := k.VFS().NewDisconnectedMount(hostFilesystem, nil, &vfs.MountOptions{})
if err != nil {
return nil, fmt.Errorf("failed to create hostfs mount: %v", err)
@@ -755,7 +755,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
return nil, fmt.Errorf("creating process: %v", err)
}
// CreateProcess takes a reference on FDTable if successful.
- info.procArgs.FDTable.DecRef()
+ info.procArgs.FDTable.DecRef(ctx)
// Set the foreground process group on the TTY to the global init process
// group, since that is what we are about to start running.
@@ -890,22 +890,20 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
// Add the HOME environment variable if it is not already set.
if kernel.VFS2Enabled {
- defer args.MountNamespaceVFS2.DecRef()
-
root := args.MountNamespaceVFS2.Root()
- defer root.DecRef()
ctx := vfs.WithRoot(l.k.SupervisorContext(), root)
+ defer args.MountNamespaceVFS2.DecRef(ctx)
+ defer root.DecRef(ctx)
envv, err := user.MaybeAddExecUserHomeVFS2(ctx, args.MountNamespaceVFS2, args.KUID, args.Envv)
if err != nil {
return 0, err
}
args.Envv = envv
} else {
- defer args.MountNamespace.DecRef()
-
root := args.MountNamespace.Root()
- defer root.DecRef()
ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ defer args.MountNamespace.DecRef(ctx)
+ defer root.DecRef(ctx)
envv, err := user.MaybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
if err != nil {
return 0, err
@@ -1263,7 +1261,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
fdTable := k.NewFDTable()
ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, console, stdioFDs)
if err != nil {
- fdTable.DecRef()
+ fdTable.DecRef(ctx)
return nil, nil, nil, err
}
return fdTable, ttyFile, ttyFileVFS2, nil
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 8e6fe57e1..aa3fdf96c 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -450,13 +450,13 @@ func TestCreateMountNamespace(t *testing.T) {
}
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, p := range tc.expectedPaths {
maxTraversals := uint(0)
if d, err := mns.FindInode(ctx, root, root, p, &maxTraversals); err != nil {
t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err)
} else {
- d.DecRef()
+ d.DecRef(ctx)
}
}
})
@@ -491,7 +491,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) {
}
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, p := range tc.expectedPaths {
target := &vfs.PathOperation{
Root: root,
@@ -502,7 +502,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) {
if d, err := l.k.VFS().GetDentryAt(ctx, l.root.procArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil {
t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err)
} else {
- d.DecRef()
+ d.DecRef(ctx)
}
}
})
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index cfe2d36aa..252ca07e3 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -103,7 +103,7 @@ func registerFilesystems(k *kernel.Kernel) error {
if err != nil {
return fmt.Errorf("creating devtmpfs accessor: %w", err)
}
- defer a.Release()
+ defer a.Release(ctx)
if err := a.UserspaceInit(ctx); err != nil {
return fmt.Errorf("initializing userspace: %w", err)
@@ -252,7 +252,7 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) {
func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) error {
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
target := &vfs.PathOperation{
Root: root,
Start: root,
@@ -387,7 +387,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds
}
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
pop := vfs.PathOperation{
Root: root,
Start: root,
@@ -481,10 +481,10 @@ func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *Co
if err != nil {
return err
}
- defer newMnt.DecRef()
+ defer newMnt.DecRef(ctx)
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
if err := c.makeSyntheticMount(ctx, mount.Destination, root, creds); err != nil {
return err
}
--
cgit v1.2.3