diff options
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 7 | ||||
-rw-r--r-- | test/syscalls/linux/udp_socket.cc | 81 |
2 files changed, 82 insertions, 6 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index afd8f4d39..807df2bb5 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -938,11 +938,6 @@ func (e *endpoint) Disconnect() tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { - if addr.Port == 0 { - // We don't support connecting to port zero. - return &tcpip.ErrInvalidEndpointState{} - } - e.mu.Lock() defer e.mu.Unlock() @@ -1188,7 +1183,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.EndpointState() != StateConnected { + if e.EndpointState() != StateConnected || e.dstPort == 0 { return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 650f12350..50f589708 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -2061,11 +2061,92 @@ TEST_P(UdpSocketTest, SendToZeroPort) { SyscallSucceedsWithValue(sizeof(buf))); } +TEST_P(UdpSocketTest, ConnectToZeroPortUnbound) { + struct sockaddr_storage addr = InetLoopbackAddr(); + SetPort(&addr, 0); + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, ConnectToZeroPortBound) { + struct sockaddr_storage addr = InetLoopbackAddr(); + ASSERT_NO_ERRNO( + BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr))); + + SetPort(&addr, 0); + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); + socklen_t len = sizeof(sockaddr_storage); + ASSERT_THAT( + getsockname(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), &len), + SyscallSucceeds()); + ASSERT_EQ(len, addrlen_); +} + +TEST_P(UdpSocketTest, ConnectToZeroPortConnected) { + struct sockaddr_storage addr = InetLoopbackAddr(); + ASSERT_NO_ERRNO( + BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr))); + + // Connect to an address with non-zero port should succeed. + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); + sockaddr_storage peername; + socklen_t peerlen = sizeof(peername); + ASSERT_THAT( + getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername), + &peerlen), + SyscallSucceeds()); + ASSERT_EQ(peerlen, addrlen_); + ASSERT_EQ(memcmp(&peername, &addr, addrlen_), 0); + + // However connect() to an address with port 0 will make the following + // getpeername() fail. + SetPort(&addr, 0); + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); + ASSERT_THAT( + getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername), + &peerlen), + SyscallFailsWithErrno(ENOTCONN)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, AddressFamily::kDualStack)); +TEST(UdpInet6SocketTest, ConnectInet4Sockaddr) { + // glibc getaddrinfo expects the invariant expressed by this test to be held. + const sockaddr_in connect_sockaddr = { + .sin_family = AF_INET, .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}}; + auto sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_THAT( + connect(sock_.get(), + reinterpret_cast<const struct sockaddr*>(&connect_sockaddr), + sizeof(sockaddr_in)), + SyscallSucceeds()); + socklen_t len; + sockaddr_storage sockname; + ASSERT_THAT(getsockname(sock_.get(), + reinterpret_cast<struct sockaddr*>(&sockname), &len), + SyscallSucceeds()); + ASSERT_EQ(sockname.ss_family, AF_INET6); + ASSERT_EQ(len, sizeof(sockaddr_in6)); + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&sockname); + char addr_buf[INET6_ADDRSTRLEN]; + const char* addr; + ASSERT_NE(addr = inet_ntop(sockname.ss_family, &sockname, addr_buf, + sizeof(addr_buf)), + nullptr); + ASSERT_TRUE(IN6_IS_ADDR_V4MAPPED(sin6->sin6_addr.s6_addr)) << addr; +} + } // namespace } // namespace testing |