From ef045b914bc8d9795f9184aed4b13351be70a3cf Mon Sep 17 00:00:00 2001
From: Kevin Krakauer <krakauer@google.com>
Date: Thu, 15 Aug 2019 16:30:25 -0700
Subject: Add tests for "cooked" AF_PACKET sockets.

PiperOrigin-RevId: 263666789
---
 test/syscalls/BUILD                     |   2 +
 test/syscalls/linux/BUILD               |  18 ++
 test/syscalls/linux/packet_socket.cc    | 299 ++++++++++++++++++++++++++++++++
 test/syscalls/linux/raw_socket_icmp.cc  |  42 +----
 test/syscalls/linux/socket_test_util.cc |  69 ++++++++
 test/syscalls/linux/socket_test_util.h  |  14 ++
 6 files changed, 410 insertions(+), 34 deletions(-)
 create mode 100644 test/syscalls/linux/packet_socket.cc

diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index aa1e33fb4..f50d83f38 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -249,6 +249,8 @@ syscall_test(
     test = "//test/syscalls/linux:open_test",
 )
 
+syscall_test(test = "//test/syscalls/linux:packet_socket_test")
+
 syscall_test(test = "//test/syscalls/linux:partial_bad_buffer_test")
 
 syscall_test(test = "//test/syscalls/linux:pause_test")
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index d28ce4ba1..db0a1e661 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1208,6 +1208,24 @@ cc_binary(
     ],
 )
 
+cc_binary(
+    name = "packet_socket_test",
+    testonly = 1,
+    srcs = ["packet_socket.cc"],
+    linkstatic = 1,
+    deps = [
+        ":socket_test_util",
+        ":unix_domain_socket_test_util",
+        "//test/util:capability_util",
+        "//test/util:file_descriptor",
+        "//test/util:test_main",
+        "//test/util:test_util",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/base:endian",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_binary(
     name = "pty_test",
     testonly = 1,
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
new file mode 100644
index 000000000..7a3379b9e
--- /dev/null
+++ b/test/syscalls/linux/packet_socket.cc
@@ -0,0 +1,299 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_packet.h>
+#include <net/ethernet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/udp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Some of these tests involve sending packets via AF_PACKET sockets and the
+// loopback interface. Because AF_PACKET circumvents so much of the networking
+// stack, Linux sees these packets as "martian", i.e. they claim to be to/from
+// localhost but don't have the usual associated data. Thus Linux drops them by
+// default. You can see where this happens by following the code at:
+//
+// - net/ipv4/ip_input.c:ip_rcv_finish, which calls
+// - net/ipv4/route.c:ip_route_input_noref, which calls
+// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian
+//   packets.
+//
+// To tell Linux not to drop these packets, you need to tell it to accept our
+// funny packets (which are completely valid and correct, but lack associated
+// in-kernel data because we use AF_PACKET):
+//
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet
+//
+// These tests require CAP_NET_RAW to run.
+
+// TODO(gvisor.dev/issue/173): gVisor support.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr char kMessage[] = "soweoneul malhaebwa";
+constexpr in_port_t kPort = 0x409c;  // htons(40000)
+
+//
+// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer
+// headers, and provide link layer destination/source information via a
+// returned struct sockaddr_ll.
+//
+
+// Send kMessage via sock to loopback
+void SendUDPMessage(int sock) {
+  struct sockaddr_in dest = {};
+  dest.sin_port = kPort;
+  dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+  dest.sin_family = AF_INET;
+  EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0,
+                     reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+              SyscallSucceedsWithValue(sizeof(kMessage)));
+}
+
+// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
+TEST(BasicCookedPacketTest, WrongType) {
+  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+  SKIP_IF(IsRunningOnGvisor());
+
+  FileDescriptor sock =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP));
+
+  // Let's use a simple IP payload: a UDP datagram.
+  FileDescriptor udp_sock =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+  SendUDPMessage(udp_sock.get());
+
+  // Wait and make sure the socket never becomes readable.
+  struct pollfd pfd = {};
+  pfd.fd = sock.get();
+  pfd.events = POLLIN;
+  EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets.
+class CookedPacketTest : public ::testing::TestWithParam<int> {
+ protected:
+  // Creates a socket to be used in tests.
+  void SetUp() override;
+
+  // Closes the socket created by SetUp().
+  void TearDown() override;
+
+  // Gets the device index of the loopback device.
+  int GetLoopbackIndex();
+
+  // The socket used for both reading and writing.
+  int socket_;
+};
+
+void CookedPacketTest::SetUp() {
+  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+  SKIP_IF(IsRunningOnGvisor());
+
+  ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+              SyscallSucceeds());
+}
+
+void CookedPacketTest::TearDown() {
+  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+  SKIP_IF(IsRunningOnGvisor());
+
+  EXPECT_THAT(close(socket_), SyscallSucceeds());
+}
+
+int CookedPacketTest::GetLoopbackIndex() {
+  struct ifreq ifr;
+  snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+  EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+  EXPECT_NE(ifr.ifr_ifindex, 0);
+  return ifr.ifr_ifindex;
+}
+
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+  SKIP_IF(IsRunningOnGvisor());
+
+  // Let's use a simple IP payload: a UDP datagram.
+  FileDescriptor udp_sock =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+  SendUDPMessage(udp_sock.get());
+
+  // Wait for the socket to become readable.
+  struct pollfd pfd = {};
+  pfd.fd = socket_;
+  pfd.events = POLLIN;
+  EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
+
+  // Read and verify the data.
+  constexpr size_t packet_size =
+      sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
+  char buf[64];
+  struct sockaddr_ll src = {};
+  socklen_t src_len = sizeof(src);
+  ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+                       reinterpret_cast<struct sockaddr*>(&src), &src_len),
+              SyscallSucceedsWithValue(packet_size));
+  ASSERT_EQ(src_len, sizeof(src));
+
+  // Verify the source address.
+  EXPECT_EQ(src.sll_family, AF_PACKET);
+  EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP));
+  EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
+  EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK);
+  EXPECT_EQ(src.sll_halen, ETH_ALEN);
+  // This came from the loopback device, so the address is all 0s.
+  for (int i = 0; i < src.sll_halen; i++) {
+    EXPECT_EQ(src.sll_addr[i], 0);
+  }
+
+  // Verify the IP header. We memcpy to deal with pointer aligment.
+  struct iphdr ip = {};
+  memcpy(&ip, buf, sizeof(ip));
+  EXPECT_EQ(ip.ihl, 5);
+  EXPECT_EQ(ip.version, 4);
+  EXPECT_EQ(ip.tot_len, htons(packet_size));
+  EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+  EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK));
+  EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK));
+
+  // Verify the UDP header. We memcpy to deal with pointer aligment.
+  struct udphdr udp = {};
+  memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
+  EXPECT_EQ(udp.dest, kPort);
+  EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+  // Verify the payload.
+  char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
+  EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
+// Send via a packet socket.
+TEST_P(CookedPacketTest, Send) {
+  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+  SKIP_IF(IsRunningOnGvisor());
+
+  // Let's send a UDP packet and receive it using a regular UDP socket.
+  FileDescriptor udp_sock =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+  struct sockaddr_in bind_addr = {};
+  bind_addr.sin_family = AF_INET;
+  bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+  bind_addr.sin_port = kPort;
+  ASSERT_THAT(
+      bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr),
+           sizeof(bind_addr)),
+      SyscallSucceeds());
+
+  // Set up the destination physical address.
+  struct sockaddr_ll dest = {};
+  dest.sll_family = AF_PACKET;
+  dest.sll_halen = ETH_ALEN;
+  dest.sll_ifindex = GetLoopbackIndex();
+  dest.sll_protocol = htons(ETH_P_IP);
+  // We're sending to the loopback device, so the address is all 0s.
+  memset(dest.sll_addr, 0x00, ETH_ALEN);
+
+  // Set up the IP header.
+  struct iphdr iphdr = {0};
+  iphdr.ihl = 5;
+  iphdr.version = 4;
+  iphdr.tos = 0;
+  iphdr.tot_len =
+      htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage));
+  // Get a pseudo-random ID. If we clash with an in-use ID the test will fail,
+  // but we have no way of getting an ID we know to be good.
+  srand(*reinterpret_cast<unsigned int*>(&iphdr));
+  iphdr.id = rand();
+  // Linux sets this bit ("do not fragment") for small packets.
+  iphdr.frag_off = 1 << 6;
+  iphdr.ttl = 64;
+  iphdr.protocol = IPPROTO_UDP;
+  iphdr.daddr = htonl(INADDR_LOOPBACK);
+  iphdr.saddr = htonl(INADDR_LOOPBACK);
+  iphdr.check = IPChecksum(iphdr);
+
+  // Set up the UDP header.
+  struct udphdr udphdr = {};
+  udphdr.source = kPort;
+  udphdr.dest = kPort;
+  udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage));
+  udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage));
+
+  // Copy both headers and the payload into our packet buffer.
+  char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)];
+  memcpy(send_buf, &iphdr, sizeof(iphdr));
+  memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr));
+  memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage));
+
+  // Send it.
+  ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
+                     reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+              SyscallSucceedsWithValue(sizeof(send_buf)));
+
+  // Wait for the packet to become available on both sockets.
+  struct pollfd pfd = {};
+  pfd.fd = udp_sock.get();
+  pfd.events = POLLIN;
+  ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+  pfd.fd = socket_;
+  pfd.events = POLLIN;
+  ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+
+  // Receive on the packet socket.
+  char recv_buf[sizeof(send_buf)];
+  ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0),
+              SyscallSucceedsWithValue(sizeof(recv_buf)));
+  ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0);
+
+  // Receive on the UDP socket.
+  struct sockaddr_in src;
+  socklen_t src_len = sizeof(src);
+  ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+                       reinterpret_cast<struct sockaddr*>(&src), &src_len),
+              SyscallSucceedsWithValue(sizeof(kMessage)));
+  // Check src and payload.
+  EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0);
+  EXPECT_EQ(src.sin_family, AF_INET);
+  EXPECT_EQ(src.sin_port, kPort);
+  EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
+                         ::testing::Values(ETH_P_IP, ETH_P_ALL));
+
+}  // namespace
+
+}  // namespace testing
+}  // namespace gvisor
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 1c07bacc2..971592d7d 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -35,32 +35,6 @@ namespace testing {
 
 namespace {
 
-// Compute the internet checksum of the ICMP header (assuming no payload).
-static uint16_t Checksum(struct icmphdr* icmp) {
-  uint32_t total = 0;
-  uint16_t* num = reinterpret_cast<uint16_t*>(icmp);
-
-  // This is just the ICMP header, so there's an even number of bytes.
-  static_assert(
-      sizeof(*icmp) % sizeof(*num) == 0,
-      "sizeof(struct icmphdr) is not an integer multiple of sizeof(uint16_t)");
-  for (unsigned int i = 0; i < sizeof(*icmp); i += sizeof(*num)) {
-    total += *num;
-    num++;
-  }
-
-  // Combine the upper and lower 16 bits. This happens twice in case the first
-  // combination causes a carry.
-  unsigned short upper = total >> 16;
-  unsigned short lower = total & 0xffff;
-  total = upper + lower;
-  upper = total >> 16;
-  lower = total & 0xffff;
-  total = upper + lower;
-
-  return ~total;
-}
-
 // The size of an empty ICMP packet and IP header together.
 constexpr size_t kEmptyICMPSize = 28;
 
@@ -164,7 +138,7 @@ TEST_F(RawSocketICMPTest, SendAndReceive) {
   icmp.checksum = 0;
   icmp.un.echo.sequence = 2012;
   icmp.un.echo.id = 2014;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
   ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
 
   ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
@@ -187,7 +161,7 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) {
   icmp.checksum = 0;
   icmp.un.echo.sequence = 2016;
   icmp.un.echo.id = 2018;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
   ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
 
   // Both sockets will receive the echo request and reply in indeterminate
@@ -297,7 +271,7 @@ TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) {
   icmp.un.echo.sequence = 0;
   icmp.un.echo.id = 6789;
   icmp.checksum = 0;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
 
   // Omit 2 bytes from ICMP packet.
   constexpr int kShortICMPSize = sizeof(icmp) - 2;
@@ -338,7 +312,7 @@ TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) {
   icmp.un.echo.sequence = 0;
   icmp.un.echo.id = 6789;
   icmp.checksum = 0;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
 
   // Omit 2 bytes from ICMP packet.
   constexpr int kShortICMPSize = sizeof(icmp) - 2;
@@ -381,7 +355,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveViaConnect) {
   icmp.checksum = 0;
   icmp.un.echo.sequence = 2003;
   icmp.un.echo.id = 2004;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
   ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0),
               SyscallSucceedsWithValue(sizeof(icmp)));
 
@@ -405,7 +379,7 @@ TEST_F(RawSocketICMPTest, BindSendAndReceive) {
   icmp.checksum = 0;
   icmp.un.echo.sequence = 2004;
   icmp.un.echo.id = 2007;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
   ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
 
   ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
@@ -431,7 +405,7 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) {
   icmp.checksum = 0;
   icmp.un.echo.sequence = 2010;
   icmp.un.echo.id = 7;
-  icmp.checksum = Checksum(&icmp);
+  icmp.checksum = ICMPChecksum(icmp, NULL, 0);
   ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
 
   ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
@@ -471,7 +445,7 @@ void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) {
         // A couple are different.
         EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY);
         // The checksum computed over the reply should still be valid.
-        EXPECT_EQ(Checksum(recvd_icmp), 0);
+        EXPECT_EQ(ICMPChecksum(*recvd_icmp, NULL, 0), 0);
         break;
     }
   }
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 4f65cf5ae..3c716235b 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -744,5 +744,74 @@ TestAddress V6Loopback() {
   return t;
 }
 
+// Checksum computes the internet checksum of a buffer.
+uint16_t Checksum(uint16_t* buf, ssize_t buf_size) {
+  // Add up the 16-bit values in the buffer.
+  uint32_t total = 0;
+  for (unsigned int i = 0; i < buf_size; i += sizeof(*buf)) {
+    total += *buf;
+    buf++;
+  }
+
+  // If buf has an odd size, add the remaining byte.
+  if (buf_size % 2) {
+    total += *(reinterpret_cast<unsigned char*>(buf) - 1);
+  }
+
+  // This carries any bits past the lower 16 until everything fits in 16 bits.
+  while (total >> 16) {
+    uint16_t lower = total & 0xffff;
+    uint16_t upper = total >> 16;
+    total = lower + upper;
+  }
+
+  return ~total;
+}
+
+uint16_t IPChecksum(struct iphdr ip) {
+  return Checksum(reinterpret_cast<uint16_t*>(&ip), sizeof(ip));
+}
+
+// The pseudo-header defined in RFC 768 for calculating the UDP checksum.
+struct udp_pseudo_hdr {
+  uint32_t srcip;
+  uint32_t destip;
+  char zero;
+  char protocol;
+  uint16_t udplen;
+};
+
+uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
+                     const char* payload, ssize_t payload_len) {
+  struct udp_pseudo_hdr phdr = {};
+  phdr.srcip = iphdr.saddr;
+  phdr.destip = iphdr.daddr;
+  phdr.zero = 0;
+  phdr.protocol = IPPROTO_UDP;
+  phdr.udplen = udphdr.len;
+
+  ssize_t buf_size = sizeof(phdr) + sizeof(udphdr) + payload_len;
+  char* buf = static_cast<char*>(malloc(buf_size));
+  memcpy(buf, &phdr, sizeof(phdr));
+  memcpy(buf + sizeof(phdr), &udphdr, sizeof(udphdr));
+  memcpy(buf + sizeof(phdr) + sizeof(udphdr), payload, payload_len);
+
+  uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size);
+  free(buf);
+  return csum;
+}
+
+uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
+                      ssize_t payload_len) {
+  ssize_t buf_size = sizeof(icmphdr) + payload_len;
+  char* buf = static_cast<char*>(malloc(buf_size));
+  memcpy(buf, &icmphdr, sizeof(icmphdr));
+  memcpy(buf + sizeof(icmphdr), payload, payload_len);
+
+  uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size);
+  free(buf);
+  return csum;
+}
+
 }  // namespace testing
 }  // namespace gvisor
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 4fd59767a..ae0da2679 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -17,9 +17,12 @@
 
 #include <errno.h>
 #include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/udp.h>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <sys/un.h>
+
 #include <functional>
 #include <memory>
 #include <string>
@@ -478,6 +481,17 @@ TestAddress V4MappedLoopback();
 TestAddress V6Any();
 TestAddress V6Loopback();
 
+// Compute the internet checksum of an IP header.
+uint16_t IPChecksum(struct iphdr ip);
+
+// Compute the internet checksum of a UDP header.
+uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
+                     const char* payload, ssize_t payload_len);
+
+// Compute the internet checksum of an ICMP header.
+uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
+                      ssize_t payload_len);
+
 }  // namespace testing
 }  // namespace gvisor
 
-- 
cgit v1.2.3