summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go9
-rw-r--r--test/syscalls/linux/raw_socket.cc118
2 files changed, 122 insertions, 5 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 3bf6c0a8f..4a5858bdd 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -562,8 +562,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- remoteAddr := pkt.Network().SourceAddress()
-
if e.bound {
// If bound to a NIC, only accept data for that NIC.
if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
@@ -572,16 +570,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
// If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != remoteAddr {
+ if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() {
e.rcvMu.Unlock()
e.mu.RUnlock()
return
}
}
+ srcAddr := pkt.Network().SourceAddress()
// If connected, only accept packets from the remote address we
// connected to.
- if e.connected && e.route.RemoteAddress() != remoteAddr {
+ if e.connected && e.route.RemoteAddress() != srcAddr {
e.rcvMu.Unlock()
e.mu.RUnlock()
return
@@ -593,7 +592,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
packet := &rawPacket{
senderAddr: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: remoteAddr,
+ Addr: srcAddr,
},
}
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 66f0e6ca4..f0eb7cc4a 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -939,6 +939,124 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
::testing::Values(AF_INET, AF_INET6)));
+void TestRawSocketMaybeBindReceive(bool do_bind) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));
+
+ constexpr char payload[] = "abcdefgh";
+
+ const sockaddr_in addr = {
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, SOL_UDP));
+ sockaddr_in udp_sock_bind_addr = addr;
+ socklen_t udp_sock_bind_addr_len = sizeof(udp_sock_bind_addr);
+ ASSERT_THAT(bind(udp_sock.get(),
+ reinterpret_cast<const sockaddr*>(&udp_sock_bind_addr),
+ sizeof(udp_sock_bind_addr)),
+ SyscallSucceeds());
+ ASSERT_THAT(getsockname(udp_sock.get(),
+ reinterpret_cast<sockaddr*>(&udp_sock_bind_addr),
+ &udp_sock_bind_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(udp_sock_bind_addr_len, sizeof(udp_sock_bind_addr));
+
+ FileDescriptor raw_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, SOL_UDP));
+
+ auto test_recv = [&](const char* scope, uint32_t expected_destination) {
+ SCOPED_TRACE(scope);
+
+ constexpr int kInfinitePollTimeout = -1;
+ pollfd pfd = {
+ .fd = raw_sock.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfinitePollTimeout),
+ SyscallSucceedsWithValue(1));
+
+ struct ipv4_udp_packet {
+ iphdr ip;
+ udphdr udp;
+ char data[sizeof(payload)];
+
+ // Used to make sure only the required space is used.
+ char unused_space;
+ } ABSL_ATTRIBUTE_PACKED;
+ constexpr size_t kExpectedIPPacketSize =
+ offsetof(ipv4_udp_packet, unused_space);
+
+ // Receive the whole IPv4 packet on the raw socket.
+ ipv4_udp_packet read_raw_packet;
+ sockaddr_in peer;
+ socklen_t peerlen = sizeof(peer);
+ ASSERT_EQ(
+ recvfrom(raw_sock.get(), reinterpret_cast<char*>(&read_raw_packet),
+ sizeof(read_raw_packet), 0 /* flags */,
+ reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ static_cast<ssize_t>(kExpectedIPPacketSize))
+ << strerror(errno);
+ ASSERT_EQ(peerlen, sizeof(peer));
+ EXPECT_EQ(read_raw_packet.ip.version, static_cast<unsigned int>(IPVERSION));
+ // IHL holds the number of header bytes in 4 byte units.
+ EXPECT_EQ(read_raw_packet.ip.ihl, sizeof(read_raw_packet.ip) / 4);
+ EXPECT_EQ(ntohs(read_raw_packet.ip.tot_len), kExpectedIPPacketSize);
+ EXPECT_EQ(ntohs(read_raw_packet.ip.frag_off) & IP_OFFMASK, 0);
+ EXPECT_EQ(read_raw_packet.ip.protocol, SOL_UDP);
+ EXPECT_EQ(ntohl(read_raw_packet.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(read_raw_packet.ip.daddr), expected_destination);
+ EXPECT_EQ(read_raw_packet.udp.source, udp_sock_bind_addr.sin_port);
+ EXPECT_EQ(read_raw_packet.udp.dest, udp_sock_bind_addr.sin_port);
+ EXPECT_EQ(ntohs(read_raw_packet.udp.len),
+ kExpectedIPPacketSize - sizeof(read_raw_packet.ip));
+ for (size_t i = 0; i < sizeof(payload); i++) {
+ EXPECT_EQ(read_raw_packet.data[i], payload[i])
+ << "byte mismatch @ idx=" << i;
+ }
+ EXPECT_EQ(peer.sin_family, AF_INET);
+ EXPECT_EQ(peer.sin_port, 0);
+ EXPECT_EQ(ntohl(peer.sin_addr.s_addr), INADDR_LOOPBACK);
+ };
+
+ if (do_bind) {
+ ASSERT_THAT(bind(raw_sock.get(), reinterpret_cast<const sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ }
+
+ constexpr int kSendToFlags = 0;
+ sockaddr_in different_addr = udp_sock_bind_addr;
+ different_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK + 1);
+ ASSERT_THAT(sendto(udp_sock.get(), payload, sizeof(payload), kSendToFlags,
+ reinterpret_cast<const sockaddr*>(&different_addr),
+ sizeof(different_addr)),
+ SyscallSucceedsWithValue(sizeof(payload)));
+ if (!do_bind) {
+ ASSERT_NO_FATAL_FAILURE(
+ test_recv("different_addr", ntohl(different_addr.sin_addr.s_addr)));
+ }
+ ASSERT_THAT(sendto(udp_sock.get(), payload, sizeof(payload), kSendToFlags,
+ reinterpret_cast<const sockaddr*>(&udp_sock_bind_addr),
+ sizeof(udp_sock_bind_addr)),
+ SyscallSucceedsWithValue(sizeof(payload)));
+ ASSERT_NO_FATAL_FAILURE(
+ test_recv("addr", ntohl(udp_sock_bind_addr.sin_addr.s_addr)));
+}
+
+TEST(RawSocketTest, UnboundReceive) {
+ // Test that a raw socket receives packets destined to any address if it is
+ // not bound to an address.
+ ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(false /* do_bind */));
+}
+
+TEST(RawSocketTest, BindReceive) {
+ // Test that a raw socket only receives packets destined to the address it is
+ // bound to.
+ ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(true /* do_bind */));
+}
+
} // namespace
} // namespace testing