summaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
authorAndrei Vagin <avagin@google.com>2018-12-28 11:26:01 -0800
committerShentubot <shentubot@google.com>2018-12-28 11:27:14 -0800
commit652d068119052b0b3bc4a0808a4400a22380a30b (patch)
treef5a617063151ffb9563ebbcd3189611e854952db /test
parenta3217b71723a93abb7a2aca535408ab84d81ac2f (diff)
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port. Incoming packets are distributed to sockets using a hash based on source and destination addresses. This means that all packets from one sender will be received by the same server socket. PiperOrigin-RevId: 227153413 Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
Diffstat (limited to 'test')
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc289
-rw-r--r--test/syscalls/syscall_test_runner.go1
3 files changed, 294 insertions, 0 deletions
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index ae33d14da..f0e61e083 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2163,9 +2163,13 @@ cc_binary(
":socket_test_util",
"//test/util:file_descriptor",
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:test_main",
"//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 17a46e149..0893be5a7 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -17,17 +17,24 @@
#include <string.h>
#include <sys/socket.h>
+#include <atomic>
+#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -227,6 +234,238 @@ INSTANTIATE_TEST_CASE_P(
TestParam{V6Loopback(), V6Loopback()}),
DescribeTestParam);
+using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>;
+
+TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // Create the listening socket.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) continue;
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::unique_ptr<ScopedThread> listen_thread[kThreadCount];
+ int accept_counts[kThreadCount] = {};
+ // TODO: figure how to not disable S/R for the whole test.
+ // We need to take into account that this test executes a lot of system
+ // calls from many threads.
+ DisableSave ds;
+
+ for (int i = 0; i < kThreadCount; i++) {
+ listen_thread[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, i, &connects_received]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ if (connects_received >= kConnectAttempts) {
+ // Another thread have shutdown our read side causing the
+ // accept to fail.
+ break;
+ }
+ ASSERT_NO_ERRNO(fd);
+ break;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ int data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (int j = 0; j < kThreadCount; j++) {
+ shutdown(listener_fds[j].get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ ScopedThread connecting_thread([&connector, &conn_addr]() {
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ });
+
+ // Join threads to be sure that all connections have been counted
+ connecting_thread.Join();
+ for (int i = 0; i < kThreadCount; i++) {
+ listen_thread[i]->Join();
+ }
+ // Check that connections are distributed fairly between listening sockets
+ for (int i = 0; i < kThreadCount; i++)
+ EXPECT_THAT(accept_counts[i],
+ EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
+}
+
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // Create the listening socket.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) continue;
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::unique_ptr<ScopedThread> receiver_thread[kThreadCount];
+ int packets_per_socket[kThreadCount] = {};
+ // TODO: figure how to not disable S/R for the whole test.
+ DisableSave ds; // Too expensive.
+
+ for (int i = 0; i < kThreadCount; i++) {
+ receiver_thread[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, i, &packets_received]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shutdown our read side causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread,
+ // otherwise the main thread can send more than can fit into receive
+ // queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (int j = 0; j < kThreadCount; j++)
+ shutdown(listener_fds[j].get(), SHUT_RDWR);
+ });
+ }
+
+ ScopedThread main_thread([&connector, &conn_addr]() {
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+ });
+
+ main_thread.Join();
+
+ // Join threads to be sure that all connections have been counted
+ for (int i = 0; i < kThreadCount; i++) {
+ receiver_thread[i]->Join();
+ }
+ // Check that packets are distributed fairly between listening sockets.
+ for (int i = 0; i < kThreadCount; i++)
+ EXPECT_THAT(packets_per_socket[i],
+ EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ All, SocketInetReusePortTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
struct ProtocolTestParam {
std::string description;
int type;
@@ -806,6 +1045,56 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
}
}
+TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
+ auto const& param = GetParam();
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage addr = test_addr.addr;
+
+ for (int i = 0; i < 2; i++) {
+ const int portreuse1 = i % 2;
+ auto s1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd1 = s1.get();
+ socklen_t addrlen = test_addr.addr_len;
+
+ EXPECT_THAT(
+ setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(fd1, 1), SyscallSucceeds());
+ }
+
+ // j is less than 4 to check that the port reuse logic works correctly after
+ // closing bound sockets.
+ for (int j = 0; j < 4; j++) {
+ const int portreuse2 = j % 2;
+ auto s2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd2 = s2.get();
+
+ EXPECT_THAT(
+ setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)),
+ SyscallSucceeds());
+
+ LOG(INFO) << portreuse1 << " " << portreuse2;
+ int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
+
+ // Verify that two sockets can be bound to the same port only if
+ // SO_REUSEPORT is set for both of them.
+ if (!portreuse1 || !portreuse2)
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE));
+ else
+ ASSERT_THAT(ret, SyscallSucceeds());
+ }
+ }
+}
+
INSTANTIATE_TEST_CASE_P(AllFamlies, SocketMultiProtocolInetLoopbackTest,
::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM},
ProtocolTestParam{"UDP", SOCK_DGRAM}),
diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go
index 9ee0361ee..ec048f10f 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/syscalls/syscall_test_runner.go
@@ -118,6 +118,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Mark the root as writeable, as some tests attempt to
// write to the rootfs, and expect EACCES, not EROFS.
spec.Root.Readonly = false
+ spec.Mounts = nil
// Set environment variable that indicates we are
// running in gVisor and with the given platform.