diff options
Diffstat (limited to 'test/syscalls')
-rw-r--r-- | test/syscalls/linux/BUILD | 1 | ||||
-rw-r--r-- | test/syscalls/linux/packet_socket.cc | 171 |
2 files changed, 169 insertions, 3 deletions
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 85fa58970..5efb3e620 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1487,6 +1487,7 @@ cc_binary( srcs = ["packet_socket.cc"], linkstatic = 1, deps = [ + ":ip_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:socket_util", diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index c8d1e1d4a..43828a52e 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -14,10 +14,15 @@ #include <net/if.h> #include <netinet/if_ether.h> +#include <netpacket/packet.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/types.h> #include <limits> #include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/socket_util.h" @@ -27,10 +32,13 @@ namespace testing { namespace { +using ::testing::AnyOf; using ::testing::Combine; +using ::testing::Eq; using ::testing::Values; -class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { +class PacketSocketCreationTest + : public ::testing::TestWithParam<std::tuple<int, int>> { protected: void SetUp() override { if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) { @@ -42,18 +50,175 @@ class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { } }; -TEST_P(PacketSocketTest, Create) { +TEST_P(PacketSocketCreationTest, Create) { const auto [type, protocol] = GetParam(); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, type, htons(protocol))); EXPECT_GE(fd.get(), 0); } -INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, +INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketCreationTest, Combine(Values(SOCK_DGRAM, SOCK_RAW), Values(0, 1, 255, ETH_P_IP, ETH_P_IPV6, std::numeric_limits<uint16_t>::max()))); +class PacketSocketTest : public ::testing::TestWithParam<int> { + protected: + void SetUp() override { + if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) { + ASSERT_THAT(socket(AF_PACKET, GetParam(), 0), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP() << "Missing packet socket capability"; + } + + socket_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, GetParam(), 0)); + } + + FileDescriptor socket_; +}; + +TEST_P(PacketSocketTest, RebindProtocol) { + const bool kEthHdrIncluded = GetParam() == SOCK_RAW; + + sockaddr_in udp_bind_addr = { + .sin_family = AF_INET, + .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}, + }; + + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + { + // Bind the socket so that we have something to send packets to. + // + // If we didn't do this, the UDP packets we send will be responded to with + // ICMP Destination Port Unreachable errors. + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<const sockaddr*>(&udp_bind_addr), + sizeof(udp_bind_addr)), + SyscallSucceeds()); + socklen_t addrlen = sizeof(udp_bind_addr); + ASSERT_THAT( + getsockname(udp_sock.get(), reinterpret_cast<sockaddr*>(&udp_bind_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_THAT(addrlen, sizeof(udp_bind_addr)); + } + + const int loopback_index = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + + auto send_udp_message = [&](const uint64_t v) { + ASSERT_THAT( + sendto(udp_sock.get(), reinterpret_cast<const char*>(&v), sizeof(v), + 0 /* flags */, reinterpret_cast<const sockaddr*>(&udp_bind_addr), + sizeof(udp_bind_addr)), + SyscallSucceeds()); + }; + + auto bind_to_network_protocol = [&](uint16_t protocol) { + const sockaddr_ll packet_bind_addr = { + .sll_family = AF_PACKET, + .sll_protocol = htons(protocol), + .sll_ifindex = loopback_index, + }; + + ASSERT_THAT(bind(socket_.get(), + reinterpret_cast<const sockaddr*>(&packet_bind_addr), + sizeof(packet_bind_addr)), + SyscallSucceeds()); + }; + + auto test_recv = [&, this](const uint64_t v) { + constexpr int kInfiniteTimeout = -1; + pollfd pfd = { + .fd = socket_.get(), + .events = POLLIN, + }; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfiniteTimeout), + SyscallSucceedsWithValue(1)); + + struct { + ethhdr eth; + iphdr ip; + udphdr udp; + uint64_t payload; + char unused; + } ABSL_ATTRIBUTE_PACKED read_pkt; + sockaddr_ll src; + socklen_t src_len = sizeof(src); + + char* buf = reinterpret_cast<char*>(&read_pkt); + size_t buflen = sizeof(read_pkt); + size_t expected_read_len = sizeof(read_pkt) - sizeof(read_pkt.unused); + if (!kEthHdrIncluded) { + buf += sizeof(read_pkt.eth); + buflen -= sizeof(read_pkt.eth); + expected_read_len -= sizeof(read_pkt.eth); + } + + ASSERT_THAT(recvfrom(socket_.get(), buf, buflen, 0, + reinterpret_cast<sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(expected_read_len)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but returns sizeof(sockaddr_ll) since + // https://github.com/torvalds/linux/commit/b2cf86e1563e33a14a1c69b3e508d15dc12f804c. + ASSERT_THAT(src_len, ::testing::AnyOf( + ::testing::Eq(sizeof(src)), + ::testing::Eq(sizeof(src) - sizeof(src.sll_addr) + + ETH_ALEN))); + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, loopback_index); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + // This came from the loopback device, so the address is all 0s. + constexpr uint8_t allZeroesMAC[ETH_ALEN] = {}; + EXPECT_EQ(memcmp(src.sll_addr, allZeroesMAC, sizeof(allZeroesMAC)), 0); + if (kEthHdrIncluded) { + EXPECT_EQ(memcmp(read_pkt.eth.h_dest, allZeroesMAC, sizeof(allZeroesMAC)), + 0); + EXPECT_EQ( + memcmp(read_pkt.eth.h_source, allZeroesMAC, sizeof(allZeroesMAC)), 0); + EXPECT_EQ(ntohs(read_pkt.eth.h_proto), ETH_P_IP); + } + // IHL hold the size of the header in 4 byte units. + EXPECT_EQ(read_pkt.ip.ihl, sizeof(iphdr) / 4); + EXPECT_EQ(read_pkt.ip.version, IPVERSION); + const uint16_t ip_pkt_size = + sizeof(read_pkt) - sizeof(read_pkt.eth) - sizeof(read_pkt.unused); + EXPECT_EQ(ntohs(read_pkt.ip.tot_len), ip_pkt_size); + EXPECT_EQ(read_pkt.ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ntohl(read_pkt.ip.daddr), INADDR_LOOPBACK); + EXPECT_EQ(ntohl(read_pkt.ip.saddr), INADDR_LOOPBACK); + EXPECT_EQ(read_pkt.udp.source, udp_bind_addr.sin_port); + EXPECT_EQ(read_pkt.udp.dest, udp_bind_addr.sin_port); + EXPECT_EQ(ntohs(read_pkt.udp.len), ip_pkt_size - sizeof(read_pkt.ip)); + EXPECT_EQ(read_pkt.payload, v); + }; + + // The packet socket is not bound to IPv4 so we should not receive the sent + // message. + uint64_t counter = 0; + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + + // Bind to IPv4 and expect to receive the UDP packet we send after binding. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); + + // Bind the packet socket to a random protocol. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(255)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + + // Bind back to IPv4 and expect to the UDP packet we send after binding + // back to IPv4. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); +} + +INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, + Values(SOCK_DGRAM, SOCK_RAW)); + } // namespace } // namespace testing |