From 652d068119052b0b3bc4a0808a4400a22380a30b Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Fri, 28 Dec 2018 11:26:01 -0800 Subject: Implement SO_REUSEPORT for TCP and UDP sockets This option allows multiple sockets to be bound to the same port. Incoming packets are distributed to sockets using a hash based on source and destination addresses. This means that all packets from one sender will be received by the same server socket. PiperOrigin-RevId: 227153413 Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf --- test/syscalls/linux/BUILD | 4 + test/syscalls/linux/socket_inet_loopback.cc | 289 ++++++++++++++++++++++++++++ test/syscalls/syscall_test_runner.go | 1 + 3 files changed, 294 insertions(+) (limited to 'test') diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ae33d14da..f0e61e083 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2163,9 +2163,13 @@ cc_binary( ":socket_test_util", "//test/util:file_descriptor", "//test/util:posix_error", + "//test/util:save_util", "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 17a46e149..0893be5a7 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -17,17 +17,24 @@ #include #include +#include +#include #include #include #include #include +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing 
{ @@ -227,6 +234,238 @@ INSTANTIATE_TEST_CASE_P( TestParam{V6Loopback(), V6Loopback()}), DescribeTestParam); +using SocketInetReusePortTest = ::testing::TestWithParam; + +TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // Create the listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) continue; + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic connects_received = ATOMIC_VAR_INIT(0); + std::unique_ptr listen_thread[kThreadCount]; + int accept_counts[kThreadCount] = {}; + // TODO: figure out how to not disable S/R for the whole test. + // We need to take into account that this test executes a lot of system + // calls from many threads. 
+ DisableSave ds; + + for (int i = 0; i < kThreadCount; i++) { + listen_thread[i] = absl::make_unique( + [&listener_fds, &accept_counts, i, &connects_received]() { + do { + auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); + if (!fd.ok()) { + if (connects_received >= kConnectAttempts) { + // Another thread has shut down our read side, causing the + // accept to fail. + break; + } + ASSERT_NO_ERRNO(fd); + break; + } + // Receive some data from the socket to be sure that the connect() + // system call has completed on the other side. + int data; + EXPECT_THAT( + RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + accept_counts[i]++; + } while (++connects_received < kConnectAttempts); + + // Shut down all sockets to wake up other threads. + for (int j = 0; j < kThreadCount; j++) { + shutdown(listener_fds[j].get(), SHUT_RDWR); + } + }); + } + + ScopedThread connecting_thread([&connector, &conn_addr]() { + for (int i = 0; i < kConnectAttempts; i++) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(fd.get(), reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } + }); + + // Join threads to be sure that all connections have been counted. + connecting_thread.Join(); + for (int i = 0; i < kThreadCount; i++) { + listen_thread[i]->Join(); + } + // Check that connections are distributed fairly between listening sockets. + for (int i = 0; i < kThreadCount; i++) + EXPECT_THAT(accept_counts[i], + EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); +} + +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = 
+ listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // Create the listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) continue; + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic packets_received = ATOMIC_VAR_INIT(0); + std::unique_ptr receiver_thread[kThreadCount]; + int packets_per_socket[kThreadCount] = {}; + // TODO: figure out how to not disable S/R for the whole test. + DisableSave ds; // Too expensive. + + for (int i = 0; i < kThreadCount; i++) { + receiver_thread[i] = absl::make_unique( + [&listener_fds, &packets_per_socket, i, &packets_received]() { + do { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + + auto ret = RetryEINTR(recvfrom)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast(&addr), &addrlen); + + if (packets_received < kConnectAttempts) { + ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); + } + + if (ret != sizeof(data)) { + // Another thread may have shut down our read side, causing the + // recvfrom to fail. 
+ break; + } + + packets_received++; + packets_per_socket[i]++; + + // A response is required to synchronize with the main thread, + // otherwise the main thread can send more than can fit into receive + // queues. + EXPECT_THAT(RetryEINTR(sendto)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); + } while (packets_received < kConnectAttempts); + + // Shut down all sockets to wake up other threads. + for (int j = 0; j < kThreadCount; j++) + shutdown(listener_fds[j].get(), SHUT_RDWR); + }); + } + + ScopedThread main_thread([&connector, &conn_addr]() { + for (int i = 0; i < kConnectAttempts; i++) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + int data; + EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + } + }); + + main_thread.Join(); + + // Join threads to be sure that all packets have been counted. + for (int i = 0; i < kThreadCount; i++) { + receiver_thread[i]->Join(); + } + // Check that packets are distributed fairly between listening sockets. + for (int i = 0; i < kThreadCount; i++) + EXPECT_THAT(packets_per_socket[i], + EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); +} + +INSTANTIATE_TEST_CASE_P( + All, SocketInetReusePortTest, + ::testing::Values( + // Listeners bound to IPv4 addresses refuse connections using IPv6 + // addresses. + TestParam{V4Any(), V4Loopback()}, + TestParam{V4Loopback(), V4MappedLoopback()}, + + // Listeners bound to IN6ADDR_ANY accept all connections. + TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()}, + + // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 + // addresses. 
+ TestParam{V6Loopback(), V6Loopback()}), + DescribeTestParam); + struct ProtocolTestParam { std::string description; int type; @@ -806,6 +1045,56 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { } } +TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { + auto const& param = GetParam(); + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage addr = test_addr.addr; + + for (int i = 0; i < 2; i++) { + const int portreuse1 = i % 2; + auto s1 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + int fd1 = s1.get(); + socklen_t addrlen = test_addr.addr_len; + + EXPECT_THAT( + setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)), + SyscallSucceeds()); + + ASSERT_THAT(bind(fd1, reinterpret_cast(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(getsockname(fd1, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(fd1, 1), SyscallSucceeds()); + } + + // j is less than 4 to check that the port reuse logic works correctly after + // closing bound sockets. + for (int j = 0; j < 4; j++) { + const int portreuse2 = j % 2; + auto s2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + int fd2 = s2.get(); + + EXPECT_THAT( + setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)), + SyscallSucceeds()); + + LOG(INFO) << portreuse1 << " " << portreuse2; + int ret = bind(fd2, reinterpret_cast(&addr), addrlen); + + // Verify that two sockets can be bound to the same port only if + // SO_REUSEPORT is set for both of them. 
+ if (!portreuse1 || !portreuse2) + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE)); + else + ASSERT_THAT(ret, SyscallSucceeds()); + } + } +} + INSTANTIATE_TEST_CASE_P(AllFamlies, SocketMultiProtocolInetLoopbackTest, ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, ProtocolTestParam{"UDP", SOCK_DGRAM}), diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index 9ee0361ee..ec048f10f 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -118,6 +118,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. spec.Root.Readonly = false + spec.Mounts = nil // Set environment variable that indicates we are // running in gVisor and with the given platform. -- cgit v1.2.3