diff options
Diffstat (limited to 'test/syscalls/linux')
-rw-r--r-- | test/syscalls/linux/BUILD | 58 | ||||
-rw-r--r-- | test/syscalls/linux/futex.cc | 16 | ||||
-rw-r--r-- | test/syscalls/linux/iptables.cc | 204 | ||||
-rw-r--r-- | test/syscalls/linux/packet_socket.cc | 299 | ||||
-rw-r--r-- | test/syscalls/linux/packet_socket_raw.cc | 314 | ||||
-rw-r--r-- | test/syscalls/linux/partial_bad_buffer.cc | 110 | ||||
-rw-r--r-- | test/syscalls/linux/proc_net_tcp.cc | 3 | ||||
-rw-r--r-- | test/syscalls/linux/pty_root.cc | 2 | ||||
-rw-r--r-- | test/syscalls/linux/raw_socket_icmp.cc | 42 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_udp_unbound.cc | 254 | ||||
-rw-r--r-- | test/syscalls/linux/socket_test_util.cc | 69 | ||||
-rw-r--r-- | test/syscalls/linux/socket_test_util.h | 14 | ||||
-rw-r--r-- | test/syscalls/linux/udp_socket.cc | 124 |
13 files changed, 1452 insertions, 57 deletions
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 16666e772..ca4344139 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -913,6 +913,24 @@ cc_library( ) cc_binary( + name = "iptables_test", + testonly = 1, + srcs = [ + "iptables.cc", + ], + linkstatic = 1, + deps = [ + ":iptables_types", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "itimer_test", testonly = 1, srcs = ["itimer.cc"], @@ -1209,6 +1227,42 @@ cc_binary( ) cc_binary( + name = "packet_socket_raw_test", + testonly = 1, + srcs = ["packet_socket_raw.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "packet_socket_test", + testonly = 1, + srcs = ["packet_socket.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "pty_test", testonly = 1, srcs = ["pty.cc"], @@ -1252,10 +1306,14 @@ cc_binary( srcs = ["partial_bad_buffer.cc"], linkstatic = 1, deps = [ + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", "//test/util:fs_util", + "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index aacbb5e70..d3e3f998c 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -125,6 +125,10 @@ int futex_lock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int zero = 0; + if (uaddr->compare_exchange_strong(zero, gettid())) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -133,6 +137,10 @@ int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int zero = 0; + if (uaddr->compare_exchange_strong(zero, gettid())) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -141,6 +149,10 @@ int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int tid = gettid(); + if (uaddr->compare_exchange_strong(tid, 0)) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -689,11 +701,11 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { std::atomic<int> a = ATOMIC_VAR_INIT(0); const bool is_priv = IsPrivate(); - std::unique_ptr<ScopedThread> threads[100]; + std::unique_ptr<ScopedThread> threads[10]; for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) { threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] { for (size_t j = 0; j < 10;) { - if (futex_trylock_pi(is_priv, &a) >= 0) { + if (futex_trylock_pi(is_priv, &a) == 0) { ++j; EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid()); SleepSafe(absl::Milliseconds(5)); diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc new file mode 100644 index 000000000..b8e4ece64 --- /dev/null +++ b/test/syscalls/linux/iptables.cc @@ -0,0 +1,204 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/iptables.h" + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/netfilter/x_tables.h> +#include <net/if.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <stdio.h> +#include <sys/poll.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kNatTablename[] = "nat"; +constexpr char kErrorTarget[] = "ERROR"; +constexpr size_t kEmptyStandardEntrySize = + sizeof(struct ipt_entry) + sizeof(struct ipt_standard_target); +constexpr size_t kEmptyErrorEntrySize = + sizeof(struct ipt_entry) + sizeof(struct ipt_error_target); + +TEST(IPTablesBasic, CreateSocket) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallSucceeds()); + + ASSERT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IPTablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // iptables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + ASSERT_THAT(close(sock), SyscallSucceeds()); +} + +// Fixture for iptables tests. +class IPTablesTest : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // The socket via which to manipulate iptables. + int s_; +}; + +void IPTablesTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); +} + +void IPTablesTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + EXPECT_THAT(close(s_), SyscallSucceeds()); +} + +// This tests the initial state of a machine with empty iptables. We don't have +// a guarantee that the iptables are empty when running in native, but we can +// test that gVisor has the same initial state that a newly-booted Linux machine +// would have. +TEST_F(IPTablesTest, InitialState) { + SKIP_IF(!IsRunningOnGvisor()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // + // Get info via sockopt. + // + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // The nat table supports PREROUTING, and OUTPUT. + unsigned int valid_hooks = (1 << NF_IP_PRE_ROUTING) | (1 << NF_IP_LOCAL_OUT) | + (1 << NF_IP_POST_ROUTING) | (1 << NF_IP_LOCAL_IN); + + EXPECT_EQ(info.valid_hooks, valid_hooks); + + // Each chain consists of an empty entry with a standard target.. + EXPECT_EQ(info.hook_entry[NF_IP_PRE_ROUTING], 0); + EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.hook_entry[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // The underflow points are the same as the entry points. + EXPECT_EQ(info.underflow[NF_IP_PRE_ROUTING], 0); + EXPECT_EQ(info.underflow[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.underflow[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.underflow[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // One entry for each chain, plus an error entry at the end. + EXPECT_EQ(info.num_entries, 5); + + EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); + EXPECT_EQ(strcmp(info.name, kNatTablename), 0); + + // + // Use info to get entries. + // + socklen_t entries_size = sizeof(struct ipt_get_entries) + info.size; + struct ipt_get_entries* entries = + static_cast<struct ipt_get_entries*>(malloc(entries_size)); + snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + entries->size = info.size; + ASSERT_THAT( + getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), + SyscallSucceeds()); + + // Verify the name and size. + ASSERT_EQ(info.size, entries->size); + ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); + + // Verify that the entrytable is 4 entries with accept targets and no matches + // followed by a single error target. + size_t entry_offset = 0; + while (entry_offset < entries->size) { + struct ipt_entry* entry = reinterpret_cast<struct ipt_entry*>( + reinterpret_cast<char*>(entries->entrytable) + entry_offset); + + // ip should be zeroes. + struct ipt_ip zeroed = {}; + EXPECT_EQ(memcmp(static_cast<void*>(&zeroed), + static_cast<void*>(&entry->ip), sizeof(zeroed)), + 0); + + // target_offset should be zero. + EXPECT_EQ(entry->target_offset, sizeof(ipt_entry)); + + if (entry_offset < kEmptyStandardEntrySize * 4) { + // The first 4 entries are standard targets + struct ipt_standard_target* target = + reinterpret_cast<struct ipt_standard_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + // This is what's returned for an accept verdict. I don't know why. + EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); + } else { + // The last entry is an error target + struct ipt_error_target* target = + reinterpret_cast<struct ipt_error_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); + } + + entry_offset += entry->next_offset; + } + + free(entries); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc new file mode 100644 index 000000000..7a3379b9e --- /dev/null +++ b/test/syscalls/linux/packet_socket.cc @@ -0,0 +1,299 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Some of these tests involve sending packets via AF_PACKET sockets and the +// loopback interface. Because AF_PACKET circumvents so much of the networking +// stack, Linux sees these packets as "martian", i.e. they claim to be to/from +// localhost but don't have the usual associated data. Thus Linux drops them by +// default. You can see where this happens by following the code at: +// +// - net/ipv4/ip_input.c:ip_rcv_finish, which calls +// - net/ipv4/route.c:ip_route_input_noref, which calls +// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian +// packets. +// +// To tell Linux not to drop these packets, you need to tell it to accept our +// funny packets (which are completely valid and correct, but lack associated +// in-kernel data because we use AF_PACKET): +// +// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local +// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet +// +// These tests require CAP_NET_RAW to run. + +// TODO(gvisor.dev/issue/173): gVisor support. + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kMessage[] = "soweoneul malhaebwa"; +constexpr in_port_t kPort = 0x409c; // htons(40000) + +// +// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer +// headers, and provide link layer destination/source information via a +// returned struct sockaddr_ll. +// + +// Send kMessage via sock to loopback +void SendUDPMessage(int sock) { + struct sockaddr_in dest = {}; + dest.sin_port = kPort; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); +} + +// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up. +TEST(BasicCookedPacketTest, WrongType) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP)); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait and make sure the socket never becomes readable. + struct pollfd pfd = {}; + pfd.fd = sock.get(); + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets. +class CookedPacketTest : public ::testing::TestWithParam<int> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Gets the device index of the loopback device. + int GetLoopbackIndex(); + + // The socket used for both reading and writing. + int socket_; +}; + +void CookedPacketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), + SyscallSucceeds()); +} + +void CookedPacketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +int CookedPacketTest::GetLoopbackIndex() { + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + return ifr.ifr_ifindex; +} + +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait for the socket to become readable. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); + + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + ASSERT_EQ(src_len, sizeof(src)); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); + EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + // This came from the loopback device, so the address is all 0s. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], 0); + } + + // Verify the IP header. We memcpy to deal with pointer aligment. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. We memcpy to deal with pointer aligment. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Send via a packet socket. +TEST_P(CookedPacketTest, Send) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's send a UDP packet and receive it using a regular UDP socket. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + bind_addr.sin_port = kPort; + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Set up the destination physical address. + struct sockaddr_ll dest = {}; + dest.sll_family = AF_PACKET; + dest.sll_halen = ETH_ALEN; + dest.sll_ifindex = GetLoopbackIndex(); + dest.sll_protocol = htons(ETH_P_IP); + // We're sending to the loopback device, so the address is all 0s. + memset(dest.sll_addr, 0x00, ETH_ALEN); + + // Set up the IP header. + struct iphdr iphdr = {0}; + iphdr.ihl = 5; + iphdr.version = 4; + iphdr.tos = 0; + iphdr.tot_len = + htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); + // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, + // but we have no way of getting an ID we know to be good. + srand(*reinterpret_cast<unsigned int*>(&iphdr)); + iphdr.id = rand(); + // Linux sets this bit ("do not fragment") for small packets. + iphdr.frag_off = 1 << 6; + iphdr.ttl = 64; + iphdr.protocol = IPPROTO_UDP; + iphdr.daddr = htonl(INADDR_LOOPBACK); + iphdr.saddr = htonl(INADDR_LOOPBACK); + iphdr.check = IPChecksum(iphdr); + + // Set up the UDP header. + struct udphdr udphdr = {}; + udphdr.source = kPort; + udphdr.dest = kPort; + udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); + udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); + + // Copy both headers and the payload into our packet buffer. + char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; + memcpy(send_buf, &iphdr, sizeof(iphdr)); + memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr)); + memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage)); + + // Send it. + ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Wait for the packet to become available on both sockets. + struct pollfd pfd = {}; + pfd.fd = udp_sock.get(); + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + + // Receive on the packet socket. + char recv_buf[sizeof(send_buf)]; + ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); + + // Receive on the UDP socket. + struct sockaddr_in src; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(sizeof(kMessage))); + // Check src and payload. + EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(src.sin_port, kPort); + EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, + ::testing::Values(ETH_P_IP, ETH_P_ALL)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc new file mode 100644 index 000000000..9e96460ee --- /dev/null +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Some of these tests involve sending packets via AF_PACKET sockets and the +// loopback interface. Because AF_PACKET circumvents so much of the networking +// stack, Linux sees these packets as "martian", i.e. they claim to be to/from +// localhost but don't have the usual associated data. Thus Linux drops them by +// default. You can see where this happens by following the code at: +// +// - net/ipv4/ip_input.c:ip_rcv_finish, which calls +// - net/ipv4/route.c:ip_route_input_noref, which calls +// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian +// packets. +// +// To tell Linux not to drop these packets, you need to tell it to accept our +// funny packets (which are completely valid and correct, but lack associated +// in-kernel data because we use AF_PACKET): +// +// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local +// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet +// +// These tests require CAP_NET_RAW to run. + +// TODO(gvisor.dev/issue/173): gVisor support. + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kMessage[] = "soweoneul malhaebwa"; +constexpr in_port_t kPort = 0x409c; // htons(40000) + +// Send kMessage via sock to loopback +void SendUDPMessage(int sock) { + struct sockaddr_in dest = {}; + dest.sin_port = kPort; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); +} + +// +// Raw tests. Packets sent with raw AF_PACKET sockets always include link layer +// headers. +// + +// Tests for "raw" (SOCK_RAW) packet(7) sockets. +class RawPacketTest : public ::testing::TestWithParam<int> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Gets the device index of the loopback device. + int GetLoopbackIndex(); + + // The socket used for both reading and writing. + int socket_; +}; + +void RawPacketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + if (!IsRunningOnGvisor()) { + FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + char enabled; + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + SyscallSucceeds()); +} + +void RawPacketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +int RawPacketTest::GetLoopbackIndex() { + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + return ifr.ifr_ifindex; +} + +// Receive via a packet socket. +TEST_P(RawPacketTest, Receive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait for the socket to become readable. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); + + // Read and verify the data. + constexpr size_t packet_size = sizeof(struct ethhdr) + sizeof(struct iphdr) + + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + // sizeof(src) is the size of a struct sockaddr_ll. sockaddr_ll ends with an 8 + // byte physical address field, but ethernet (MAC) addresses only use 6 bytes. + // Thus src_len should get modified to be 2 less than the size of sockaddr_ll. + ASSERT_EQ(src_len, sizeof(src) - 2); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); + EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + // This came from the loopback device, so the address is all 0s. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], 0); + } + + // Verify the ethernet header. We memcpy to deal with pointer alignment. + struct ethhdr eth = {}; + memcpy(ð, buf, sizeof(eth)); + // The destination and source address should be 0, for loopback. + for (int i = 0; i < ETH_ALEN; i++) { + EXPECT_EQ(eth.h_dest[i], 0); + EXPECT_EQ(eth.h_source[i], 0); + } + EXPECT_EQ(eth.h_proto, htons(ETH_P_IP)); + + // Verify the IP header. We memcpy to deal with pointer aligment. + struct iphdr ip = {}; + memcpy(&ip, buf + sizeof(ethhdr), sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size - sizeof(eth))); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. We memcpy to deal with pointer aligment. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(eth) + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(eth) + sizeof(iphdr) + + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Send via a packet socket. +TEST_P(RawPacketTest, Send) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's send a UDP packet and receive it using a regular UDP socket. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + bind_addr.sin_port = kPort; + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Set up the destination physical address. + struct sockaddr_ll dest = {}; + dest.sll_family = AF_PACKET; + dest.sll_halen = ETH_ALEN; + dest.sll_ifindex = GetLoopbackIndex(); + dest.sll_protocol = htons(ETH_P_IP); + // We're sending to the loopback device, so the address is all 0s. + memset(dest.sll_addr, 0x00, ETH_ALEN); + + // Set up the ethernet header. The kernel takes care of the footer. + // We're sending to and from hardware address 0 (loopback). + struct ethhdr eth = {}; + eth.h_proto = htons(ETH_P_IP); + + // Set up the IP header. + struct iphdr iphdr = {}; + iphdr.ihl = 5; + iphdr.version = 4; + iphdr.tos = 0; + iphdr.tot_len = + htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); + // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, + // but we have no way of getting an ID we know to be good. + srand(*reinterpret_cast<unsigned int*>(&iphdr)); + iphdr.id = rand(); + // Linux sets this bit ("do not fragment") for small packets. + iphdr.frag_off = 1 << 6; + iphdr.ttl = 64; + iphdr.protocol = IPPROTO_UDP; + iphdr.daddr = htonl(INADDR_LOOPBACK); + iphdr.saddr = htonl(INADDR_LOOPBACK); + iphdr.check = IPChecksum(iphdr); + + // Set up the UDP header. + struct udphdr udphdr = {}; + udphdr.source = kPort; + udphdr.dest = kPort; + udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); + udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); + + // Copy both headers and the payload into our packet buffer. + char + send_buf[sizeof(eth) + sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; + memcpy(send_buf, ð, sizeof(eth)); + memcpy(send_buf + sizeof(ethhdr), &iphdr, sizeof(iphdr)); + memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr), &udphdr, sizeof(udphdr)); + memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage, + sizeof(kMessage)); + + // Send it. + ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Wait for the packet to become available on both sockets. + struct pollfd pfd = {}; + pfd.fd = udp_sock.get(); + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + + // Receive on the packet socket. + char recv_buf[sizeof(send_buf)]; + ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); + + // Receive on the UDP socket. + struct sockaddr_in src; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(sizeof(kMessage))); + // Check src and payload. + EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(src.sin_port, kPort); + EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, + ::testing::Values(ETH_P_IP /*, ETH_P_ALL*/)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc index 83b1ad4e4..33822ee57 100644 --- a/test/syscalls/linux/partial_bad_buffer.cc +++ b/test/syscalls/linux/partial_bad_buffer.cc @@ -14,13 +14,20 @@ #include <errno.h> #include <fcntl.h> +#include <netinet/in.h> +#include <netinet/tcp.h> #include <sys/mman.h> +#include <sys/socket.h> #include <sys/syscall.h> #include <sys/uio.h> #include <unistd.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -299,6 +306,109 @@ TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) { EXPECT_STREQ(buf, kMessage); } +PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// SendMsgTCP verifies that calling sendmsg with a bad address returns an +// EFAULT. It also verifies that passing a buffer which is made up of 2 +// pages one valid and one guard page succeeds as long as the write is +// for exactly the size of 1 page. +TEST_F(PartialBadBufferTest, SendMsgTCP) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // TODO(gvisor.dev/issue/674): Update this once Netstack matches linux + // behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF. + // + // Set SO_SNDBUF for socket to exactly kPageSize+1. + // + // gVisor does not double the value passed in SO_SNDBUF like linux does so we + // just increase it by 1 byte here for gVisor so that we can test writing 1 + // byte past the valid page and check that it triggers an EFAULT + // correctly. Otherwise in gVisor the sendmsg call will just return with no + // error with kPageSize bytes written successfully. + const uint32_t buf_size = kPageSize + 1; + ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceedsWithValue(0)); + + struct msghdr hdr = {}; + struct iovec iov = {}; + iov.iov_base = bad_buffer_; + iov.iov_len = kPageSize; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallFailsWithErrno(EFAULT)); + + // Now assert that writing kPageSize from addr_ succeeds. + iov.iov_base = addr_; + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallSucceedsWithValue(kPageSize)); + // Read all the data out so that we drain the socket SND_BUF on the sender. + std::vector<char> buffer(kPageSize); + ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize), + SyscallSucceedsWithValue(kPageSize)); + + // Sleep for a shortwhile to ensure that we have time to process the + // ACKs. This is not strictly required unless running under gotsan which is a + // lot slower and can result in the next write to write only 1 byte instead of + // our intended kPageSize + 1. + absl::SleepFor(absl::Milliseconds(50)); + + // Now assert that writing > kPageSize results in EFAULT. + iov.iov_len = kPageSize + 1; + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallFailsWithErrno(EFAULT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index 2ca7b6ad7..498f62d9c 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -249,7 +249,8 @@ TEST(ProcNetTCP, State) { std::unique_ptr<FileDescriptor> client = ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create()); - ASSERT_THAT(connect(client->get(), &addr, addrlen), SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(client->get(), &addr, addrlen), + SyscallSucceeds()); entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr)); EXPECT_EQ(listen_entry.state, TCP_LISTEN); diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 14a4af980..d2a321a6e 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -50,7 +50,7 @@ TEST(JobControlRootTest, StealTTY) { // of 1. pid_t child = fork(); if (!child) { - ASSERT_THAT(setsid(), SyscallSucceeds()); + TEST_PCHECK(setsid() >= 0); // We shouldn't be able to steal the terminal with the wrong arg value. TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); // We should be able to steal it here. diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 1c07bacc2..971592d7d 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -35,32 +35,6 @@ namespace testing { namespace { -// Compute the internet checksum of the ICMP header (assuming no payload). -static uint16_t Checksum(struct icmphdr* icmp) { - uint32_t total = 0; - uint16_t* num = reinterpret_cast<uint16_t*>(icmp); - - // This is just the ICMP header, so there's an even number of bytes. - static_assert( - sizeof(*icmp) % sizeof(*num) == 0, - "sizeof(struct icmphdr) is not an integer multiple of sizeof(uint16_t)"); - for (unsigned int i = 0; i < sizeof(*icmp); i += sizeof(*num)) { - total += *num; - num++; - } - - // Combine the upper and lower 16 bits. This happens twice in case the first - // combination causes a carry. - unsigned short upper = total >> 16; - unsigned short lower = total & 0xffff; - total = upper + lower; - upper = total >> 16; - lower = total & 0xffff; - total = upper + lower; - - return ~total; -} - // The size of an empty ICMP packet and IP header together. constexpr size_t kEmptyICMPSize = 28; @@ -164,7 +138,7 @@ TEST_F(RawSocketICMPTest, SendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2012; icmp.un.echo.id = 2014; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -187,7 +161,7 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2016; icmp.un.echo.id = 2018; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); // Both sockets will receive the echo request and reply in indeterminate @@ -297,7 +271,7 @@ TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) { icmp.un.echo.sequence = 0; icmp.un.echo.id = 6789; icmp.checksum = 0; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); // Omit 2 bytes from ICMP packet. constexpr int kShortICMPSize = sizeof(icmp) - 2; @@ -338,7 +312,7 @@ TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) { icmp.un.echo.sequence = 0; icmp.un.echo.id = 6789; icmp.checksum = 0; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); // Omit 2 bytes from ICMP packet. constexpr int kShortICMPSize = sizeof(icmp) - 2; @@ -381,7 +355,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveViaConnect) { icmp.checksum = 0; icmp.un.echo.sequence = 2003; icmp.un.echo.id = 2004; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0), SyscallSucceedsWithValue(sizeof(icmp))); @@ -405,7 +379,7 @@ TEST_F(RawSocketICMPTest, BindSendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2004; icmp.un.echo.id = 2007; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -431,7 +405,7 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2010; icmp.un.echo.id = 7; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -471,7 +445,7 @@ void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { // A couple are different. EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY); // The checksum computed over the reply should still be valid. - EXPECT_EQ(Checksum(recvd_icmp), 0); + EXPECT_EQ(ICMPChecksum(*recvd_icmp, NULL, 0), 0); break; } } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index d9aa7ff3f..67d29af0a 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -30,6 +30,7 @@ namespace gvisor { namespace testing { constexpr char kMulticastAddress[] = "224.0.2.1"; +constexpr char kBroadcastAddress[] = "255.255.255.255"; TestAddress V4Multicast() { TestAddress t("V4Multicast"); @@ -40,6 +41,15 @@ TestAddress V4Multicast() { return t; } +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kBroadcastAddress); + return t; +} + // Check that packets are not received without a group membership. Default send // interface configured by bind. TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { @@ -1426,5 +1436,249 @@ TEST_P(IPv4UDPUnboundSocketPairTest, } } +// Check that a receiving socket can bind to the multicast address before +// joining the group and receive data once the group has been joined. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the multicast address and won't +// receive multicast data if it hasn't joined the group. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we don't receive the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Check that a socket can bind to a multicast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the multicast address. + auto sender_addr = V4Multicast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the broadcast address and receive +// broadcast packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the broadcast address. + auto receiver_addr = V4Broadcast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a broadcast packet on the first socket out the loopback interface. + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + // Note: Binding to the loopback interface makes the broadcast go out of it. + auto sender_bind_addr = V4Loopback(); + ASSERT_THAT(bind(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); + auto sendto_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a socket can bind to the broadcast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the broadcast address. + auto sender_addr = V4Broadcast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 4f65cf5ae..3c716235b 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -744,5 +744,74 @@ TestAddress V6Loopback() { return t; } +// Checksum computes the internet checksum of a buffer. +uint16_t Checksum(uint16_t* buf, ssize_t buf_size) { + // Add up the 16-bit values in the buffer. + uint32_t total = 0; + for (unsigned int i = 0; i < buf_size; i += sizeof(*buf)) { + total += *buf; + buf++; + } + + // If buf has an odd size, add the remaining byte. + if (buf_size % 2) { + total += *(reinterpret_cast<unsigned char*>(buf) - 1); + } + + // This carries any bits past the lower 16 until everything fits in 16 bits. + while (total >> 16) { + uint16_t lower = total & 0xffff; + uint16_t upper = total >> 16; + total = lower + upper; + } + + return ~total; +} + +uint16_t IPChecksum(struct iphdr ip) { + return Checksum(reinterpret_cast<uint16_t*>(&ip), sizeof(ip)); +} + +// The pseudo-header defined in RFC 768 for calculating the UDP checksum. +struct udp_pseudo_hdr { + uint32_t srcip; + uint32_t destip; + char zero; + char protocol; + uint16_t udplen; +}; + +uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, + const char* payload, ssize_t payload_len) { + struct udp_pseudo_hdr phdr = {}; + phdr.srcip = iphdr.saddr; + phdr.destip = iphdr.daddr; + phdr.zero = 0; + phdr.protocol = IPPROTO_UDP; + phdr.udplen = udphdr.len; + + ssize_t buf_size = sizeof(phdr) + sizeof(udphdr) + payload_len; + char* buf = static_cast<char*>(malloc(buf_size)); + memcpy(buf, &phdr, sizeof(phdr)); + memcpy(buf + sizeof(phdr), &udphdr, sizeof(udphdr)); + memcpy(buf + sizeof(phdr) + sizeof(udphdr), payload, payload_len); + + uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); + free(buf); + return csum; +} + +uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, + ssize_t payload_len) { + ssize_t buf_size = sizeof(icmphdr) + payload_len; + char* buf = static_cast<char*>(malloc(buf_size)); + memcpy(buf, &icmphdr, sizeof(icmphdr)); + memcpy(buf + sizeof(icmphdr), payload, payload_len); + + uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); + free(buf); + return csum; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 4fd59767a..ae0da2679 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -17,9 +17,12 @@ #include <errno.h> #include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <netinet/udp.h> #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <functional> #include <memory> #include <string> @@ -478,6 +481,17 @@ TestAddress V4MappedLoopback(); TestAddress V6Any(); TestAddress V6Loopback(); +// Compute the internet checksum of an IP header. +uint16_t IPChecksum(struct iphdr ip); + +// Compute the internet checksum of a UDP header. +uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, + const char* payload, ssize_t payload_len); + +// Compute the internet checksum of an ICMP header. +uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, + ssize_t payload_len); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 6ffb65168..111dbacdf 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -378,16 +378,17 @@ TEST_P(UdpSocketTest, Connect) { EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); } -TEST_P(UdpSocketTest, ConnectAny) { +void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { struct sockaddr_storage addr = {}; // Precondition check. { socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); - if (GetParam() == AddressFamily::kIpv4) { + if (family == AddressFamily::kIpv4) { auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); EXPECT_EQ(addrlen, sizeof(*addr_out)); EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); @@ -400,21 +401,24 @@ TEST_P(UdpSocketTest, ConnectAny) { { socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); } struct sockaddr_storage baddr = {}; - if (GetParam() == AddressFamily::kIpv4) { + if (family == AddressFamily::kIpv4) { auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); addrlen = sizeof(*addr_in); addr_in->sin_family = AF_INET; addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; } else { auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); addrlen = sizeof(*addr_in); addr_in->sin6_family = AF_INET6; - if (GetParam() == AddressFamily::kIpv6) { + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { addr_in->sin6_addr = IN6ADDR_ANY_INIT; } else { TestAddress const& v4_mapped_any = V4MappedAny(); @@ -424,21 +428,23 @@ TEST_P(UdpSocketTest, ConnectAny) { } } - ASSERT_THAT(connect(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen), + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), SyscallSucceeds()); } - // TODO(b/138658473): gVisor doesn't return the correct local address after - // connecting to the any address. - SKIP_IF(IsRunningOnGvisor()); - // Postcondition check. { socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); - if (GetParam() == AddressFamily::kIpv4) { + if (family == AddressFamily::kIpv4) { auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); EXPECT_EQ(addrlen, sizeof(*addr_out)); EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); @@ -446,7 +452,7 @@ TEST_P(UdpSocketTest, ConnectAny) { auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); EXPECT_EQ(addrlen, sizeof(*addr_out)); struct in6_addr loopback; - if (GetParam() == AddressFamily::kIpv6) { + if (family == AddressFamily::kIpv6) { loopback = IN6ADDR_LOOPBACK_INIT; } else { TestAddress const& v4_mapped_loopback = V4MappedLoopback(); @@ -459,11 +465,91 @@ TEST_P(UdpSocketTest, ConnectAny) { } addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + if (port == 0) { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } else { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + } + } +} + +TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + ConnectAny(GetParam(), s_, port); +} + +void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { + struct sockaddr_storage addr = {}; + + socklen_t addrlen = sizeof(addr); + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + // Now the socket is bound to the loopback address. + + // Disconnect + addrlen = sizeof(addr); + addr.ss_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); } } +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + DisconnectAfterConnectAny(GetParam(), s_, 0); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + DisconnectAfterConnectAny(GetParam(), s_, port); +} + TEST_P(UdpSocketTest, DisconnectAfterBind) { ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); // Connect the socket. |