From 226e7d32cb855e69b3bf7a28791a17235074e49a Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 13 Sep 2021 19:40:40 -0700 Subject: Accept packets destined to bound address ...if bound to an address. We previously checked the source of a packet instead of the destination of a packet when bound to an address. PiperOrigin-RevId: 396497647 --- pkg/tcpip/transport/raw/endpoint.go | 9 ++- test/syscalls/linux/raw_socket.cc | 118 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 5 deletions(-) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3bf6c0a8f..4a5858bdd 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -562,8 +562,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - remoteAddr := pkt.Network().SourceAddress() - if e.bound { // If bound to a NIC, only accept data for that NIC. if e.BindNICID != 0 && e.BindNICID != pkt.NICID { @@ -572,16 +570,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != remoteAddr { + if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() { e.rcvMu.Unlock() e.mu.RUnlock() return } } + srcAddr := pkt.Network().SourceAddress() // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress() != remoteAddr { + if e.connected && e.route.RemoteAddress() != srcAddr { e.rcvMu.Unlock() e.mu.RUnlock() return @@ -593,7 +592,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: remoteAddr, + Addr: srcAddr, }, } diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 66f0e6ca4..f0eb7cc4a 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -939,6 +939,124 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP), ::testing::Values(AF_INET, AF_INET6))); +void TestRawSocketMaybeBindReceive(bool do_bind) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability())); + + constexpr char payload[] = "abcdefgh"; + + const sockaddr_in addr = { + .sin_family = AF_INET, + .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}, + }; + + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, SOL_UDP)); + sockaddr_in udp_sock_bind_addr = addr; + socklen_t udp_sock_bind_addr_len = sizeof(udp_sock_bind_addr); + ASSERT_THAT(bind(udp_sock.get(), + reinterpret_cast(&udp_sock_bind_addr), + sizeof(udp_sock_bind_addr)), + SyscallSucceeds()); + ASSERT_THAT(getsockname(udp_sock.get(), + reinterpret_cast(&udp_sock_bind_addr), + &udp_sock_bind_addr_len), + SyscallSucceeds()); + ASSERT_EQ(udp_sock_bind_addr_len, sizeof(udp_sock_bind_addr)); + + FileDescriptor raw_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, SOL_UDP)); + + auto test_recv = [&](const char* scope, uint32_t expected_destination) { + SCOPED_TRACE(scope); + + constexpr int kInfinitePollTimeout = -1; + pollfd pfd = { + .fd = raw_sock.get(), + .events = POLLIN, + }; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfinitePollTimeout), + SyscallSucceedsWithValue(1)); + + struct ipv4_udp_packet { + iphdr ip; + udphdr udp; + char data[sizeof(payload)]; + + // Used to make sure only the required space is used. + char unused_space; + } ABSL_ATTRIBUTE_PACKED; + constexpr size_t kExpectedIPPacketSize = + offsetof(ipv4_udp_packet, unused_space); + + // Receive the whole IPv4 packet on the raw socket. + ipv4_udp_packet read_raw_packet; + sockaddr_in peer; + socklen_t peerlen = sizeof(peer); + ASSERT_EQ( + recvfrom(raw_sock.get(), reinterpret_cast(&read_raw_packet), + sizeof(read_raw_packet), 0 /* flags */, + reinterpret_cast(&peer), &peerlen), + static_cast(kExpectedIPPacketSize)) + << strerror(errno); + ASSERT_EQ(peerlen, sizeof(peer)); + EXPECT_EQ(read_raw_packet.ip.version, static_cast(IPVERSION)); + // IHL holds the number of header bytes in 4 byte units. + EXPECT_EQ(read_raw_packet.ip.ihl, sizeof(read_raw_packet.ip) / 4); + EXPECT_EQ(ntohs(read_raw_packet.ip.tot_len), kExpectedIPPacketSize); + EXPECT_EQ(ntohs(read_raw_packet.ip.frag_off) & IP_OFFMASK, 0); + EXPECT_EQ(read_raw_packet.ip.protocol, SOL_UDP); + EXPECT_EQ(ntohl(read_raw_packet.ip.saddr), INADDR_LOOPBACK); + EXPECT_EQ(ntohl(read_raw_packet.ip.daddr), expected_destination); + EXPECT_EQ(read_raw_packet.udp.source, udp_sock_bind_addr.sin_port); + EXPECT_EQ(read_raw_packet.udp.dest, udp_sock_bind_addr.sin_port); + EXPECT_EQ(ntohs(read_raw_packet.udp.len), + kExpectedIPPacketSize - sizeof(read_raw_packet.ip)); + for (size_t i = 0; i < sizeof(payload); i++) { + EXPECT_EQ(read_raw_packet.data[i], payload[i]) + << "byte mismatch @ idx=" << i; + } + EXPECT_EQ(peer.sin_family, AF_INET); + EXPECT_EQ(peer.sin_port, 0); + EXPECT_EQ(ntohl(peer.sin_addr.s_addr), INADDR_LOOPBACK); + }; + + if (do_bind) { + ASSERT_THAT(bind(raw_sock.get(), reinterpret_cast(&addr), + sizeof(addr)), + SyscallSucceeds()); + } + + constexpr int kSendToFlags = 0; + sockaddr_in different_addr = udp_sock_bind_addr; + different_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK + 1); + ASSERT_THAT(sendto(udp_sock.get(), payload, sizeof(payload), kSendToFlags, + reinterpret_cast(&different_addr), + sizeof(different_addr)), + SyscallSucceedsWithValue(sizeof(payload))); + if (!do_bind) { + ASSERT_NO_FATAL_FAILURE( + test_recv("different_addr", ntohl(different_addr.sin_addr.s_addr))); + } + ASSERT_THAT(sendto(udp_sock.get(), payload, sizeof(payload), kSendToFlags, + reinterpret_cast(&udp_sock_bind_addr), + sizeof(udp_sock_bind_addr)), + SyscallSucceedsWithValue(sizeof(payload))); + ASSERT_NO_FATAL_FAILURE( + test_recv("addr", ntohl(udp_sock_bind_addr.sin_addr.s_addr))); +} + +TEST(RawSocketTest, UnboundReceive) { + // Test that a raw socket receives packets destined to any address if it is + // not bound to an address. + ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(false /* do_bind */)); +} + +TEST(RawSocketTest, BindReceive) { + // Test that a raw socket only receives packets destined to the address it is + // bound to. + ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(true /* do_bind */)); +} + } // namespace } // namespace testing -- cgit v1.2.3