From 570ca571805d6939c4c24b6a88660eefaf558ae7 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 1 Jul 2021 14:39:20 -0700 Subject: Fix bug with TCP bind w/ SO_REUSEADDR. In gVisor today its possible that when trying to bind a TCP socket w/ SO_REUSEADDR specified and requesting the kernel pick a port by setting port to zero can result in a previously bound port being returned. This behaviour is incorrect as the user is clearly requesting a free port. The behaviour is fine when the user explicity specifies a port. This change now checks if the user specified a port when making a port reservation for a TCP port and only returns unbound ports even if SO_REUSEADDR was specified. Fixes #6209 PiperOrigin-RevId: 382607638 --- pkg/tcpip/ports/BUILD | 1 + pkg/tcpip/ports/ports.go | 40 ++++++++---- test/syscalls/linux/socket_generic_stress.cc | 48 ++------------ .../linux/socket_inet_loopback_nogotsan.cc | 74 +++++++++++++++++++--- test/syscalls/linux/socket_test_util.cc | 39 ++++++++++++ test/syscalls/linux/socket_test_util.h | 4 ++ 6 files changed, 141 insertions(+), 65 deletions(-) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index b7f6d52ae..fe98a52af 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 854d6a6ba..fb8ef1ee2 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -122,7 +123,7 @@ type deviceToDest map[tcpip.NICID]destToCounter // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (dd deviceToDest) isAvailable(res Reservation) bool { +func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool { flagBits := res.Flags.Bits() if res.BindToDevice == 0 { intersection := FlagMask @@ -138,6 +139,9 @@ func (dd deviceToDest) isAvailable(res Reservation) bool { return false } } + if !portSpecified && res.Transport == header.TCPProtocolNumber { + return false + } return true } @@ -146,16 +150,26 @@ func (dd deviceToDest) isAvailable(res Reservation) bool { if dests, ok := dd[0]; ok { var count int intersection, count = dests.intersectionFlags(res) - if count > 0 && intersection&flagBits == 0 { - return false + if count > 0 { + if intersection&flagBits == 0 { + return false + } + if !portSpecified && res.Transport == header.TCPProtocolNumber { + return false + } } } if dests, ok := dd[res.BindToDevice]; ok { flags, count := dests.intersectionFlags(res) intersection &= flags - if count > 0 && intersection&flagBits == 0 { - return false + if count > 0 { + if intersection&flagBits == 0 { + return false + } + if !portSpecified && res.Transport == header.TCPProtocolNumber { + return false + } } } @@ -168,12 +182,12 @@ type addrToDevice map[tcpip.Address]deviceToDest // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (ad addrToDevice) isAvailable(res Reservation) bool { +func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool { if res.Addr == anyIPAddress { // If binding to the "any" address then check that there are no // conflicts with all addresses. for _, devices := range ad { - if !devices.isAvailable(res) { + if !devices.isAvailable(res, portSpecified) { return false } } @@ -182,14 +196,14 @@ func (ad addrToDevice) isAvailable(res Reservation) bool { // Check that there is no conflict with the "any" address. if devices, ok := ad[anyIPAddress]; ok { - if !devices.isAvailable(res) { + if !devices.isAvailable(res, portSpecified) { return false } } // Check that this is no conflict with the provided address. if devices, ok := ad[res.Addr]; ok { - if !devices.isAvailable(res) { + if !devices.isAvailable(res, portSpecified) { return false } } @@ -310,7 +324,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por // If a port is specified, just try to reserve it for all network // protocols. if res.Port != 0 { - if !pm.reserveSpecificPortLocked(res) { + if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) { return 0, &tcpip.ErrPortInUse{} } if testPort != nil { @@ -330,7 +344,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por // A port wasn't specified, so try to find one. return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) { res.Port = p - if !pm.reserveSpecificPortLocked(res) { + if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) { return false, nil } if testPort != nil { @@ -350,12 +364,12 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por // reserveSpecificPortLocked tries to reserve the given port on all given // protocols. -func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool { +func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool { // Make sure the port is available. for _, network := range res.Networks { desc := portDescriptor{network, res.Transport, res.Port} if addrs, ok := pm.allocatedPorts[desc]; ok { - if !addrs.isAvailable(res) { + if !addrs.isAvailable(res, portSpecified) { return false } } diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc index c35aa2183..778c32a8e 100644 --- a/test/syscalls/linux/socket_generic_stress.cc +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -37,49 +37,11 @@ namespace gvisor { namespace testing { -constexpr char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range"; - -PosixErrorOr NumPorts() { - int min = 0; - int max = 1 << 16; - - // Read the ephemeral range from /proc. - ASSIGN_OR_RETURN_ERRNO(std::string rangefile, GetContents(kRangeFile)); - const std::string err_msg = - absl::StrFormat("%s has invalid content: %s", kRangeFile, rangefile); - if (rangefile.back() != '\n') { - return PosixError(EINVAL, err_msg); - } - rangefile.pop_back(); - std::vector range = - absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); - if (range.size() < 2 || !absl::SimpleAtoi(range.front(), &min) || - !absl::SimpleAtoi(range.back(), &max)) { - return PosixError(EINVAL, err_msg); - } - - // If we can open as writable, limit the range. - if (!access(kRangeFile, W_OK)) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, - Open(kRangeFile, O_WRONLY | O_TRUNC, 0)); - max = min + 50; - const std::string small_range = absl::StrFormat("%d %d", min, max); - int n = write(fd.get(), small_range.c_str(), small_range.size()); - if (n < 0) { - return PosixError( - errno, - absl::StrFormat("write(%d [%s], \"%s\", %d)", fd.get(), kRangeFile, - small_range.c_str(), small_range.size())); - } - } - return max - min; -} - // Test fixture for tests that apply to pairs of connected sockets. using ConnectStressTest = SocketPairTest; TEST_P(ConnectStressTest, Reset) { - const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + const int nports = ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); for (int i = 0; i < nports * 2; i++) { const std::unique_ptr sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -103,7 +65,7 @@ TEST_P(ConnectStressTest, Reset) { // Tests that opening too many connections -- without closing them -- does lead // to port exhaustion. TEST_P(ConnectStressTest, TooManyOpen) { - const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + const int nports = ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); int err_num = 0; std::vector> sockets = std::vector>(nports); @@ -164,7 +126,7 @@ class PersistentListenerConnectStressTest : public SocketPairTest { }; TEST_P(PersistentListenerConnectStressTest, ShutdownCloseFirst) { - const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + const int nports = ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); for (int i = 0; i < nports * 2; i++) { std::unique_ptr sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep()); @@ -185,7 +147,7 @@ TEST_P(PersistentListenerConnectStressTest, ShutdownCloseFirst) { } TEST_P(PersistentListenerConnectStressTest, ShutdownCloseSecond) { - const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + const int nports = ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); for (int i = 0; i < nports * 2; i++) { const std::unique_ptr sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -206,7 +168,7 @@ TEST_P(PersistentListenerConnectStressTest, ShutdownCloseSecond) { } TEST_P(PersistentListenerConnectStressTest, Close) { - const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + const int nports = ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); for (int i = 0; i < nports * 2; i++) { std::unique_ptr sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep()); diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index b131213d4..cc2773af1 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -104,16 +104,25 @@ INSTANTIATE_TEST_SUITE_P(All, SocketInetLoopbackTest, using SocketMultiProtocolInetLoopbackTest = ::testing::TestWithParam; -TEST_P(SocketMultiProtocolInetLoopbackTest, BindAvoidsListeningPortsReuseAddr) { +TEST_P(SocketMultiProtocolInetLoopbackTest, + TCPBindAvoidsOtherBoundPortsReuseAddr) { ProtocolTestParam const& param = GetParam(); - // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP - // this is only permitted if there is no other listening socket. + // UDP sockets are allowed to bind/listen on an already bound port w/ + // SO_REUSEADDR even when requesting a port from the kernel. In case of TCP + // rebinding is only permitted when SO_REUSEADDR is set and an explicit port + // is specified. When a zero port is specified to the bind() call then an + // already bound port will not be picked. SKIP_IF(param.type != SOCK_STREAM); DisableSave ds; // Too many syscalls. // A map of port to file descriptor binding the port. - std::map listen_sockets; + std::map bound_sockets; + + // Reduce number of ephemeral ports if permitted to reduce running time of + // the test. + [[maybe_unused]] const int nports = + ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); // Exhaust all ephemeral ports. while (true) { @@ -139,12 +148,59 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, BindAvoidsListeningPortsReuseAddr) { SyscallSucceeds()); uint16_t port = reinterpret_cast(&bound_addr)->sin_port; - // Newly bound port should not already be in use by a listening socket. - ASSERT_EQ(listen_sockets.find(port), listen_sockets.end()); - auto fd = bound_fd.get(); - listen_sockets.insert(std::make_pair(port, std::move(bound_fd))); - ASSERT_THAT(listen(fd, SOMAXCONN), SyscallSucceeds()); + auto [iter, inserted] = bound_sockets.emplace(port, std::move(bound_fd)); + ASSERT_TRUE(inserted); + } +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + UDPBindMayBindOtherBoundPortsReuseAddr) { + ProtocolTestParam const& param = GetParam(); + // UDP sockets are allowed to bind/listen on an already bound port w/ + // SO_REUSEADDR even when requesting a port from the kernel. + SKIP_IF(param.type != SOCK_DGRAM); + + DisableSave ds; // Too many syscalls. + + // A map of port to file descriptor binding the port. + std::map bound_sockets; + + // Reduce number of ephemeral ports if permitted to reduce running time of + // the test. + [[maybe_unused]] const int nports = + ASSERT_NO_ERRNO_AND_VALUE(MaybeLimitEphemeralPorts()); + + // Exhaust all ephemeral ports. + bool duplicate_binding = false; + while (true) { + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), + SyscallSucceeds()); + uint16_t port = reinterpret_cast(&bound_addr)->sin_port; + + auto [iter, inserted] = bound_sockets.emplace(port, std::move(bound_fd)); + if (!inserted) { + duplicate_binding = true; + break; + } } + ASSERT_TRUE(duplicate_binding); } INSTANTIATE_TEST_SUITE_P(AllFamilies, SocketMultiProtocolInetLoopbackTest, diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 5e36472b4..1afb1ab50 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -24,6 +24,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/time/clock.h" #include "absl/types/optional.h" #include "test/util/file_descriptor.h" @@ -1067,5 +1068,43 @@ void SetupTimeWaitClose(const TestAddress* listener, absl::SleepFor(absl::Seconds(1)); } +constexpr char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range"; + +PosixErrorOr MaybeLimitEphemeralPorts() { + int min = 0; + int max = 1 << 16; + + // Read the ephemeral range from /proc. + ASSIGN_OR_RETURN_ERRNO(std::string rangefile, GetContents(kRangeFile)); + const std::string err_msg = + absl::StrFormat("%s has invalid content: %s", kRangeFile, rangefile); + if (rangefile.back() != '\n') { + return PosixError(EINVAL, err_msg); + } + rangefile.pop_back(); + std::vector range = + absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + if (range.size() < 2 || !absl::SimpleAtoi(range.front(), &min) || + !absl::SimpleAtoi(range.back(), &max)) { + return PosixError(EINVAL, err_msg); + } + + // If we can open as writable, limit the range. + if (!access(kRangeFile, W_OK)) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, + Open(kRangeFile, O_WRONLY | O_TRUNC, 0)); + max = min + 50; + const std::string small_range = absl::StrFormat("%d %d", min, max); + int n = write(fd.get(), small_range.c_str(), small_range.size()); + if (n < 0) { + return PosixError( + errno, + absl::StrFormat("write(%d [%s], \"%s\", %d)", fd.get(), kRangeFile, + small_range.c_str(), small_range.size())); + } + } + return max - min; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index df4c26f26..0e2be63cc 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -576,6 +576,10 @@ void SetupTimeWaitClose(const TestAddress* listener, bool accept_close, sockaddr_storage* listen_addr, sockaddr_storage* conn_bound_addr); +// MaybeLimitEphemeralPorts attempts to reduce the number of ephemeral ports and +// returns the number of ephemeral ports. +PosixErrorOr MaybeLimitEphemeralPorts(); + namespace internal { PosixErrorOr TryPortAvailable(int port, AddressFamily family, SocketType type, bool reuse_addr); -- cgit v1.2.3