summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Ghanan Gowripalan <ghanan@google.com> 2021-10-11 11:44:42 -0700
committer: gVisor bot <gvisor-bot@google.com> 2021-10-11 11:46:54 -0700
commit: 4ea18a8a7b72f49734a2f89be1ff7a4be87017c7 (patch)
tree: 582fa713c16a62b0080e6bbd560d5a048742e5e0
parent: 09a42f9976403e6842a291b49ac2ab3319a5d02e (diff)
Support IP_PKTINFO and IPV6_RECVPKTINFO on raw sockets
Updates #1584, #3556. PiperOrigin-RevId: 402354066
-rw-r--r-- pkg/tcpip/transport/raw/endpoint.go 32
-rw-r--r-- test/syscalls/linux/BUILD 1
-rw-r--r-- test/syscalls/linux/raw_socket.cc 129
3 files changed, 160 insertions, 2 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 181b478d0..ce76774af 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -49,6 +49,7 @@ type rawPacket struct {
receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
}
// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
@@ -208,6 +209,23 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
if opts.NeedRemoteAddr {
res.RemoteAddr = pkt.senderAddr
}
+ switch netProto := e.net.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceivePacketInfo() {
+ res.ControlMessages.HasIPPacketInfo = true
+ res.ControlMessages.PacketInfo = pkt.packetInfo
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ res.ControlMessages.HasIPv6PacketInfo = true
+ res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: pkt.packetInfo.NIC,
+ Addr: pkt.packetInfo.DestinationAddr,
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", netProto))
+ }
n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
@@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return false
}
- srcAddr := pkt.Network().SourceAddress()
+ net := pkt.Network()
+ dstAddr := net.DestinationAddress()
+ srcAddr := net.SourceAddress()
info := e.net.Info()
switch state := e.net.State(); state {
@@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
// If bound to an address, only accept data for that address.
- if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ if info.BindAddr != "" && info.BindAddr != dstAddr {
return false
}
default:
@@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
NIC: pkt.NICID,
Addr: srcAddr,
},
+ packetInfo: tcpip.IPPacketInfo{
+ // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast
+ // address. LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ LocalAddr: dstAddr,
+ DestinationAddr: dstAddr,
+ NIC: pkt.NICID,
+ },
}
// Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 45f6304e5..e37d0402f 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1980,6 +1980,7 @@ cc_binary(
defines = select_system(),
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index ef1db47ee..ef176cbee 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -25,6 +25,7 @@
#include <algorithm>
#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -39,6 +40,9 @@ namespace testing {
namespace {
+using ::testing::IsNull;
+using ::testing::NotNull;
+
// Fixture for tests parameterized by protocol.
class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
@@ -1057,6 +1061,131 @@ TEST(RawSocketTest, BindReceive) {
ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(true /* do_bind */));
}
+TEST(RawSocketTest, ReceiveIPPacketInfo) {  // An IPv4 raw socket with IP_PKTINFO enabled must receive exactly one in_pktinfo control message.
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));  // Skip when HaveRawIPSocketCapability() reports false.
+
+ FileDescriptor raw =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ const sockaddr_in addr_ = {  // IPv4 loopback; used both as the bind address and as the send destination.
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
+ ASSERT_THAT(
+ bind(raw.get(), reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Ask for an IP_PKTINFO control message to accompany each received packet.
+ constexpr int one = 1;
+ ASSERT_THAT(setsockopt(raw.get(), IPPROTO_IP, IP_PKTINFO, &one, sizeof(one)),
+ SyscallSucceeds());
+
+ constexpr char send_buf[] = "malformed UDP";  // Arbitrary payload; sizeof() counts the trailing NUL.
+ ASSERT_THAT(sendto(raw.get(), send_buf, sizeof(send_buf), 0 /* flags */,  // Sent to the bound address, so the same socket reads it back below.
+ reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ struct {  // Receive layout: IPv4 header, then the payload, then one spare byte.
+ iphdr ip;
+ char data[sizeof(send_buf)];
+
+ // One spare byte: the recvmsg length check below expects it to stay unfilled.
+ char unused_space;
+ } ABSL_ATTRIBUTE_PACKED recv_buf;  // Packed so that no padding sits between the fields.
+ iovec recv_iov = {
+ .iov_base = &recv_buf,
+ .iov_len = sizeof(recv_buf),
+ };
+ in_pktinfo received_pktinfo;
+ char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))];  // Room for a single in_pktinfo control message.
+ msghdr recv_msg = {
+ .msg_iov = &recv_iov,
+ .msg_iovlen = 1,
+ .msg_control = recv_cmsg_buf,
+ .msg_controllen = CMSG_LEN(sizeof(received_pktinfo)),  // CMSG_LEN, not CMSG_SPACE: the length of one message without trailing padding.
+ };
+ ASSERT_THAT(RetryEINTR(recvmsg)(raw.get(), &recv_msg, 0),
+ SyscallSucceedsWithValue(sizeof(iphdr) + sizeof(send_buf)));  // The IPv4 header is delivered in front of the payload.
+ EXPECT_EQ(memcmp(send_buf, &recv_buf.data, sizeof(send_buf)), 0);
+ EXPECT_EQ(recv_buf.ip.version, static_cast<unsigned int>(IPVERSION));
+ // IHL holds the number of header bytes in 4 byte units.
+ EXPECT_EQ(recv_buf.ip.ihl, sizeof(iphdr) / 4);
+ EXPECT_EQ(ntohs(recv_buf.ip.tot_len), sizeof(iphdr) + sizeof(send_buf));
+ EXPECT_EQ(recv_buf.ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ntohl(recv_buf.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(recv_buf.ip.daddr), INADDR_LOOPBACK);
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg);
+ ASSERT_THAT(cmsg, NotNull());  // ASSERT rather than EXPECT: cmsg is dereferenced on the following lines.
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(received_pktinfo)));
+ EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP);
+ EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO);
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo));  // Copy out instead of casting the CMSG_DATA pointer.
+ EXPECT_EQ(received_pktinfo.ipi_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
+ EXPECT_EQ(ntohl(received_pktinfo.ipi_spec_dst.s_addr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(received_pktinfo.ipi_addr.s_addr), INADDR_LOOPBACK);
+
+ EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull());  // No further control message is expected.
+}
+
+TEST(RawSocketTest, ReceiveIPv6PacketInfo) {  // An IPv6 raw socket with IPV6_RECVPKTINFO enabled must receive exactly one in6_pktinfo control message.
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));  // Skip when HaveRawIPSocketCapability() reports false.
+
+ FileDescriptor raw =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_UDP));
+
+ const sockaddr_in6 addr_ = {  // IPv6 loopback; used both as the bind address and as the send destination.
+ .sin6_family = AF_INET6,
+ .sin6_addr = in6addr_loopback,
+ };
+ ASSERT_THAT(
+ bind(raw.get(), reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Ask for an IPv6 packet-info control message to accompany each received packet.
+ constexpr int one = 1;
+ ASSERT_THAT(
+ setsockopt(raw.get(), IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one)),
+ SyscallSucceeds());
+
+ constexpr char send_buf[] = "malformed UDP";  // Arbitrary payload; sizeof() counts the trailing NUL.
+ ASSERT_THAT(sendto(raw.get(), send_buf, sizeof(send_buf), 0 /* flags */,  // Sent to the bound address, so the same socket reads it back below.
+ reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ char recv_buf[sizeof(send_buf) + 1];  // One byte larger than the payload; the recvmsg length check below expects the extra byte to stay unfilled.
+ iovec recv_iov = {
+ .iov_base = recv_buf,
+ .iov_len = sizeof(recv_buf),
+ };
+ in6_pktinfo received_pktinfo;
+ char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))];  // Room for a single in6_pktinfo control message.
+ msghdr recv_msg = {
+ .msg_iov = &recv_iov,
+ .msg_iovlen = 1,
+ .msg_control = recv_cmsg_buf,
+ .msg_controllen = CMSG_LEN(sizeof(received_pktinfo)),  // CMSG_LEN, not CMSG_SPACE: the length of one message without trailing padding.
+ };
+ ASSERT_THAT(RetryEINTR(recvmsg)(raw.get(), &recv_msg, 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));  // Payload only: no IP header is expected in the data here.
+ EXPECT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg);
+ ASSERT_THAT(cmsg, NotNull());  // ASSERT rather than EXPECT: cmsg is dereferenced on the following lines.
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(received_pktinfo)));
+ EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IPV6);
+ EXPECT_EQ(cmsg->cmsg_type, IPV6_PKTINFO);  // The option set above is IPV6_RECVPKTINFO; the delivered message type is IPV6_PKTINFO.
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo));  // Copy out instead of casting the CMSG_DATA pointer.
+ EXPECT_EQ(received_pktinfo.ipi6_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
+ ASSERT_EQ(memcmp(&received_pktinfo.ipi6_addr, &in6addr_loopback,  // The reported destination address must be the loopback address.
+ sizeof(in6addr_loopback)),
+ 0);
+
+ EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull());  // No further control message is expected.
+}
+
} // namespace
} // namespace testing