diff options
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 32 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 1 | ||||
-rw-r--r-- | test/syscalls/linux/raw_socket.cc | 129 |
3 files changed, 160 insertions, 2 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 181b478d0..ce76774af 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -49,6 +49,7 @@ type rawPacket struct { receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + packetInfo tcpip.IPPacketInfo } // endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to @@ -208,6 +209,23 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult if opts.NeedRemoteAddr { res.RemoteAddr = pkt.senderAddr } + switch netProto := e.net.NetProto(); netProto { + case header.IPv4ProtocolNumber: + if e.ops.GetReceivePacketInfo() { + res.ControlMessages.HasIPPacketInfo = true + res.ControlMessages.PacketInfo = pkt.packetInfo + } + case header.IPv6ProtocolNumber: + if e.ops.GetIPv6ReceivePacketInfo() { + res.ControlMessages.HasIPv6PacketInfo = true + res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{ + NIC: pkt.packetInfo.NIC, + Addr: pkt.packetInfo.DestinationAddr, + } + } + default: + panic(fmt.Sprintf("unrecognized network protocol = %d", netProto)) + } n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { @@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return false } - srcAddr := pkt.Network().SourceAddress() + net := pkt.Network() + dstAddr := net.DestinationAddress() + srcAddr := net.SourceAddress() info := e.net.Info() switch state := e.net.State(); state { @@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } // If bound to an address, only accept data for that address. - if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() { + if info.BindAddr != "" && info.BindAddr != dstAddr { return false } default: @@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { NIC: pkt.NICID, Addr: srcAddr, }, + packetInfo: tcpip.IPPacketInfo{ + // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast + // address. LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + LocalAddr: dstAddr, + DestinationAddr: dstAddr, + NIC: pkt.NICID, + }, } // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 45f6304e5..e37d0402f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1980,6 +1980,7 @@ cc_binary( defines = select_system(), linkstatic = 1, deps = [ + ":ip_socket_test_util", ":unix_domain_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index ef1db47ee..ef176cbee 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -25,6 +25,7 @@ #include <algorithm> #include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" @@ -39,6 +40,9 @@ namespace testing { namespace { +using ::testing::IsNull; +using ::testing::NotNull; + // Fixture for tests parameterized by protocol. class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { protected: @@ -1057,6 +1061,131 @@ TEST(RawSocketTest, BindReceive) { ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(true /* do_bind */)); } +TEST(RawSocketTest, ReceiveIPPacketInfo) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability())); + + FileDescriptor raw = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + const sockaddr_in addr_ = { + .sin_family = AF_INET, + .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}, + }; + ASSERT_THAT( + bind(raw.get(), reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + // Register to receive IP packet info. + constexpr int one = 1; + ASSERT_THAT(setsockopt(raw.get(), IPPROTO_IP, IP_PKTINFO, &one, sizeof(one)), + SyscallSucceeds()); + + constexpr char send_buf[] = "malformed UDP"; + ASSERT_THAT(sendto(raw.get(), send_buf, sizeof(send_buf), 0 /* flags */, + reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + struct { + iphdr ip; + char data[sizeof(send_buf)]; + + // Extra space in the receive buffer should be unused. + char unused_space; + } ABSL_ATTRIBUTE_PACKED recv_buf; + iovec recv_iov = { + .iov_base = &recv_buf, + .iov_len = sizeof(recv_buf), + }; + in_pktinfo received_pktinfo; + char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))]; + msghdr recv_msg = { + .msg_iov = &recv_iov, + .msg_iovlen = 1, + .msg_control = recv_cmsg_buf, + .msg_controllen = CMSG_LEN(sizeof(received_pktinfo)), + }; + ASSERT_THAT(RetryEINTR(recvmsg)(raw.get(), &recv_msg, 0), + SyscallSucceedsWithValue(sizeof(iphdr) + sizeof(send_buf))); + EXPECT_EQ(memcmp(send_buf, &recv_buf.data, sizeof(send_buf)), 0); + EXPECT_EQ(recv_buf.ip.version, static_cast<unsigned int>(IPVERSION)); + // IHL holds the number of header bytes in 4 byte units. + EXPECT_EQ(recv_buf.ip.ihl, sizeof(iphdr) / 4); + EXPECT_EQ(ntohs(recv_buf.ip.tot_len), sizeof(iphdr) + sizeof(send_buf)); + EXPECT_EQ(recv_buf.ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ntohl(recv_buf.ip.saddr), INADDR_LOOPBACK); + EXPECT_EQ(ntohl(recv_buf.ip.daddr), INADDR_LOOPBACK); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg); + ASSERT_THAT(cmsg, NotNull()); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(received_pktinfo))); + EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO); + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, + ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex())); + EXPECT_EQ(ntohl(received_pktinfo.ipi_spec_dst.s_addr), INADDR_LOOPBACK); + EXPECT_EQ(ntohl(received_pktinfo.ipi_addr.s_addr), INADDR_LOOPBACK); + + EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull()); +} + +TEST(RawSocketTest, ReceiveIPv6PacketInfo) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability())); + + FileDescriptor raw = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_UDP)); + + const sockaddr_in6 addr_ = { + .sin6_family = AF_INET6, + .sin6_addr = in6addr_loopback, + }; + ASSERT_THAT( + bind(raw.get(), reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + // Register to receive IPv6 packet info. + constexpr int one = 1; + ASSERT_THAT( + setsockopt(raw.get(), IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one)), + SyscallSucceeds()); + + constexpr char send_buf[] = "malformed UDP"; + ASSERT_THAT(sendto(raw.get(), send_buf, sizeof(send_buf), 0 /* flags */, + reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + char recv_buf[sizeof(send_buf) + 1]; + iovec recv_iov = { + .iov_base = recv_buf, + .iov_len = sizeof(recv_buf), + }; + in6_pktinfo received_pktinfo; + char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))]; + msghdr recv_msg = { + .msg_iov = &recv_iov, + .msg_iovlen = 1, + .msg_control = recv_cmsg_buf, + .msg_controllen = CMSG_LEN(sizeof(received_pktinfo)), + }; + ASSERT_THAT(RetryEINTR(recvmsg)(raw.get(), &recv_msg, 0), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg); + ASSERT_THAT(cmsg, NotNull()); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(received_pktinfo))); + EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IPV6); + EXPECT_EQ(cmsg->cmsg_type, IPV6_PKTINFO); + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi6_ifindex, + ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex())); + ASSERT_EQ(memcmp(&received_pktinfo.ipi6_addr, &in6addr_loopback, + sizeof(in6addr_loopback)), + 0); + + EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull()); +} + } // namespace } // namespace testing |