summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go24
-rw-r--r--test/syscalls/linux/udp_socket.cc160
2 files changed, 139 insertions, 45 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 91f89a781..52a5df691 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -241,12 +241,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto, err := e.checkV4Mapped(&addr, false)
- if err != nil {
- return stack.Route{}, 0, 0, err
- }
-
+func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.id.LocalAddress
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
if nicid == 0 {
@@ -260,9 +255,9 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
- return stack.Route{}, 0, 0, err
+ return stack.Route{}, 0, err
}
- return r, nicid, netProto, nil
+ return r, nicid, nil
}
// Write writes data to the endpoint's peer. This method does not block
@@ -336,7 +331,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- r, _, _, err := e.connectRoute(nicid, *to)
+ netProto, err := e.checkV4Mapped(to, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ r, _, err := e.connectRoute(nicid, *to, netProto)
if err != nil {
return 0, nil, err
}
@@ -740,6 +740,10 @@ func (e *endpoint) disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
if addr.Addr == "" {
return e.disconnect()
}
@@ -770,7 +774,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- r, nicid, netProto, err := e.connectRoute(nicid, addr)
+ r, nicid, err := e.connectRoute(nicid, addr, netProto)
if err != nil {
return err
}
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 1bb0307c4..6ffb65168 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -39,7 +39,7 @@ constexpr int TestPort = 40000;
// Fixture for tests parameterized by the address family to use (AF_INET and
// AF_INET6) when creating sockets.
-class UdpSocketTest : public ::testing::TestWithParam<int> {
+class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> {
protected:
// Creates two sockets that will be used by test cases.
void SetUp() override;
@@ -97,31 +97,32 @@ uint16_t* Port(struct sockaddr_storage* addr) {
}
void UdpSocketTest::SetUp() {
- ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP),
- SyscallSucceeds());
+ int type;
+ if (GetParam() == AddressFamily::kIpv4) {
+ type = AF_INET;
+ auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_);
+ addrlen_ = sizeof(*sin);
+ sin->sin_addr.s_addr = htonl(INADDR_ANY);
+ } else {
+ type = AF_INET6;
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_);
+ addrlen_ = sizeof(*sin6);
+ if (GetParam() == AddressFamily::kIpv6) {
+ sin6->sin6_addr = IN6ADDR_ANY_INIT;
+ } else {
+ TestAddress const& v4_mapped_any = V4MappedAny();
+ sin6->sin6_addr =
+ reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
+ ->sin6_addr;
+ }
+ }
+ ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
- ASSERT_THAT(t_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP),
- SyscallSucceeds());
+ ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_));
anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_);
- anyaddr_->sa_family = GetParam();
-
- // Initialize address-family-specific values.
- switch (GetParam()) {
- case AF_INET: {
- auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_);
- addrlen_ = sizeof(*sin);
- sin->sin_addr.s_addr = htonl(INADDR_ANY);
- break;
- }
- case AF_INET6: {
- auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_);
- addrlen_ = sizeof(*sin6);
- sin6->sin6_addr = in6addr_any;
- break;
- }
- }
+ anyaddr_->sa_family = type;
if (gvisor::testing::IsRunningOnGvisor()) {
for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) {
@@ -154,9 +155,9 @@ void UdpSocketTest::SetUp() {
memset(&addr_storage_[i], 0, sizeof(addr_storage_[i]));
addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]);
- addr_[i]->sa_family = GetParam();
+ addr_[i]->sa_family = type;
- switch (GetParam()) {
+ switch (type) {
case AF_INET: {
auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]);
sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
@@ -174,17 +175,20 @@ void UdpSocketTest::SetUp() {
}
TEST_P(UdpSocketTest, Creation) {
+ int type = AF_INET6;
+ if (GetParam() == AddressFamily::kIpv4) {
+ type = AF_INET;
+ }
+
int s_;
- ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP),
- SyscallSucceeds());
+ ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
EXPECT_THAT(close(s_), SyscallSucceeds());
- ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, 0), SyscallSucceeds());
+ ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds());
EXPECT_THAT(close(s_), SyscallSucceeds());
- ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_UDP),
- SyscallFails());
+ ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails());
}
TEST_P(UdpSocketTest, Getsockname) {
@@ -374,6 +378,92 @@ TEST_P(UdpSocketTest, Connect) {
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
}
+TEST_P(UdpSocketTest, ConnectAny) {
+ struct sockaddr_storage addr = {};
+
+ // Precondition check.
+ {
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ struct in6_addr any = IN6ADDR_ANY_INIT;
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0);
+ }
+
+ {
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ }
+
+ struct sockaddr_storage baddr = {};
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
+ addrlen = sizeof(*addr_in);
+ addr_in->sin_family = AF_INET;
+ addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
+ } else {
+ auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
+ addrlen = sizeof(*addr_in);
+ addr_in->sin6_family = AF_INET6;
+ if (GetParam() == AddressFamily::kIpv6) {
+ addr_in->sin6_addr = IN6ADDR_ANY_INIT;
+ } else {
+ TestAddress const& v4_mapped_any = V4MappedAny();
+ addr_in->sin6_addr =
+ reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
+ ->sin6_addr;
+ }
+ }
+
+ ASSERT_THAT(connect(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen),
+ SyscallSucceeds());
+ }
+
+ // TODO(b/138658473): gVisor doesn't return the correct local address after
+ // connecting to the any address.
+ SKIP_IF(IsRunningOnGvisor());
+
+ // Postcondition check.
+ {
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ struct in6_addr loopback;
+ if (GetParam() == AddressFamily::kIpv6) {
+ loopback = IN6ADDR_LOOPBACK_INIT;
+ } else {
+ TestAddress const& v4_mapped_loopback = V4MappedLoopback();
+ loopback = reinterpret_cast<const struct sockaddr_in6*>(
+ &v4_mapped_loopback.addr)
+ ->sin6_addr;
+ }
+
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
+ }
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ }
+}
+
TEST_P(UdpSocketTest, DisconnectAfterBind) {
ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
// Connect the socket.
@@ -402,19 +492,17 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
struct sockaddr_storage baddr = {};
socklen_t addrlen;
auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
- if (addr_[0]->sa_family == AF_INET) {
+ if (GetParam() == AddressFamily::kIpv4) {
auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
addr_in->sin_family = AF_INET;
addr_in->sin_port = port;
- inet_pton(AF_INET, "0.0.0.0",
- reinterpret_cast<void*>(&addr_in->sin_addr.s_addr));
+ addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
} else {
auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
addr_in->sin6_family = AF_INET6;
addr_in->sin6_port = port;
- inet_pton(AF_INET6,
- "::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr));
addr_in->sin6_scope_id = 0;
+ addr_in->sin6_addr = IN6ADDR_ANY_INIT;
}
ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_),
SyscallSucceeds());
@@ -1165,7 +1253,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
}
INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
- ::testing::Values(AF_INET, AF_INET6));
+ ::testing::Values(AddressFamily::kIpv4,
+ AddressFamily::kIpv6,
+ AddressFamily::kDualStack));
} // namespace