diff options
Diffstat (limited to 'test/syscalls/linux')
27 files changed, 2362 insertions, 130 deletions
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 40fc73812..88f3bfcb3 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,3 +1,5 @@ +load("//test/syscalls:build_defs.bzl", "select_for_linux") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], @@ -108,20 +110,27 @@ cc_library( cc_library( name = "socket_test_util", testonly = 1, - srcs = ["socket_test_util.cc"], + srcs = [ + "socket_test_util.cc", + ] + select_for_linux( + [ + "socket_test_util_impl.cc", + ], + ), hdrs = ["socket_test_util.h"], deps = [ + "@com_google_googletest//:gtest", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "//test/util:file_descriptor", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", - ], + ] + select_for_linux([ + ]), ) cc_library( @@ -913,6 +922,24 @@ cc_library( ) cc_binary( + name = "iptables_test", + testonly = 1, + srcs = [ + "iptables.cc", + ], + linkstatic = 1, + deps = [ + ":iptables_types", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "itimer_test", testonly = 1, srcs = ["itimer.cc"], @@ -1209,13 +1236,51 @@ cc_binary( ) cc_binary( + name = "packet_socket_raw_test", + testonly = 1, + srcs = ["packet_socket_raw.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "packet_socket_test", + testonly = 1, + srcs = ["packet_socket.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "pty_test", testonly = 1, srcs = ["pty.cc"], linkstatic = 1, deps = [ + "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:posix_error", + "//test/util:pty_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", @@ -1228,15 +1293,36 @@ cc_binary( ) cc_binary( + name = "pty_root_test", + testonly = 1, + srcs = ["pty_root.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:posix_error", + "//test/util:pty_util", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "partial_bad_buffer_test", testonly = 1, srcs = ["partial_bad_buffer.cc"], linkstatic = 1, deps = [ + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", "//test/util:fs_util", + "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/affinity.cc b/test/syscalls/linux/affinity.cc index f2d8375b6..128364c34 100644 --- a/test/syscalls/linux/affinity.cc +++ b/test/syscalls/linux/affinity.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <sched.h> +#include <sys/syscall.h> #include <sys/types.h> #include <unistd.h> diff --git a/test/syscalls/linux/base_poll_test.h b/test/syscalls/linux/base_poll_test.h index 088831f9f..0d4a6701e 100644 --- a/test/syscalls/linux/base_poll_test.h +++ b/test/syscalls/linux/base_poll_test.h @@ -56,7 +56,7 @@ class TimerThread { private: mutable absl::Mutex mu_; - bool cancel_ GUARDED_BY(mu_) = false; + bool cancel_ ABSL_GUARDED_BY(mu_) = false; // Must be last to ensure that the destructor for the thread is run before // any other member of the object is destroyed. diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index aacbb5e70..d3e3f998c 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -125,6 +125,10 @@ int futex_lock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int zero = 0; + if (uaddr->compare_exchange_strong(zero, gettid())) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -133,6 +137,10 @@ int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int zero = 0; + if (uaddr->compare_exchange_strong(zero, gettid())) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -141,6 +149,10 @@ int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int tid = gettid(); + if (uaddr->compare_exchange_strong(tid, 0)) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -689,11 +701,11 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { std::atomic<int> a = ATOMIC_VAR_INIT(0); const bool is_priv = IsPrivate(); - std::unique_ptr<ScopedThread> threads[100]; + std::unique_ptr<ScopedThread> threads[10]; for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) { threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] { for (size_t j = 0; j < 10;) { - if (futex_trylock_pi(is_priv, &a) >= 0) { + if (futex_trylock_pi(is_priv, &a) == 0) { ++j; EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid()); SleepSafe(absl::Milliseconds(5)); diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc new file mode 100644 index 000000000..b8e4ece64 --- /dev/null +++ b/test/syscalls/linux/iptables.cc @@ -0,0 +1,204 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/iptables.h" + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/netfilter/x_tables.h> +#include <net/if.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <stdio.h> +#include <sys/poll.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kNatTablename[] = "nat"; +constexpr char kErrorTarget[] = "ERROR"; +constexpr size_t kEmptyStandardEntrySize = + sizeof(struct ipt_entry) + sizeof(struct ipt_standard_target); +constexpr size_t kEmptyErrorEntrySize = + sizeof(struct ipt_entry) + sizeof(struct ipt_error_target); + +TEST(IPTablesBasic, CreateSocket) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallSucceeds()); + + ASSERT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IPTablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // iptables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + ASSERT_THAT(close(sock), SyscallSucceeds()); +} + +// Fixture for iptables tests. +class IPTablesTest : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // The socket via which to manipulate iptables. + int s_; +}; + +void IPTablesTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); +} + +void IPTablesTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + EXPECT_THAT(close(s_), SyscallSucceeds()); +} + +// This tests the initial state of a machine with empty iptables. We don't have +// a guarantee that the iptables are empty when running in native, but we can +// test that gVisor has the same initial state that a newly-booted Linux machine +// would have. +TEST_F(IPTablesTest, InitialState) { + SKIP_IF(!IsRunningOnGvisor()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // + // Get info via sockopt. + // + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // The nat table supports PREROUTING, and OUTPUT. + unsigned int valid_hooks = (1 << NF_IP_PRE_ROUTING) | (1 << NF_IP_LOCAL_OUT) | + (1 << NF_IP_POST_ROUTING) | (1 << NF_IP_LOCAL_IN); + + EXPECT_EQ(info.valid_hooks, valid_hooks); + + // Each chain consists of an empty entry with a standard target.. + EXPECT_EQ(info.hook_entry[NF_IP_PRE_ROUTING], 0); + EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.hook_entry[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // The underflow points are the same as the entry points. + EXPECT_EQ(info.underflow[NF_IP_PRE_ROUTING], 0); + EXPECT_EQ(info.underflow[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.underflow[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.underflow[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // One entry for each chain, plus an error entry at the end. + EXPECT_EQ(info.num_entries, 5); + + EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); + EXPECT_EQ(strcmp(info.name, kNatTablename), 0); + + // + // Use info to get entries. + // + socklen_t entries_size = sizeof(struct ipt_get_entries) + info.size; + struct ipt_get_entries* entries = + static_cast<struct ipt_get_entries*>(malloc(entries_size)); + snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + entries->size = info.size; + ASSERT_THAT( + getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), + SyscallSucceeds()); + + // Verify the name and size. + ASSERT_EQ(info.size, entries->size); + ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); + + // Verify that the entrytable is 4 entries with accept targets and no matches + // followed by a single error target. + size_t entry_offset = 0; + while (entry_offset < entries->size) { + struct ipt_entry* entry = reinterpret_cast<struct ipt_entry*>( + reinterpret_cast<char*>(entries->entrytable) + entry_offset); + + // ip should be zeroes. + struct ipt_ip zeroed = {}; + EXPECT_EQ(memcmp(static_cast<void*>(&zeroed), + static_cast<void*>(&entry->ip), sizeof(zeroed)), + 0); + + // target_offset should be zero. + EXPECT_EQ(entry->target_offset, sizeof(ipt_entry)); + + if (entry_offset < kEmptyStandardEntrySize * 4) { + // The first 4 entries are standard targets + struct ipt_standard_target* target = + reinterpret_cast<struct ipt_standard_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + // This is what's returned for an accept verdict. I don't know why. + EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); + } else { + // The last entry is an error target + struct ipt_error_target* target = + reinterpret_cast<struct ipt_error_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); + } + + entry_offset += entry->next_offset; + } + + free(entries); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc new file mode 100644 index 000000000..7a3379b9e --- /dev/null +++ b/test/syscalls/linux/packet_socket.cc @@ -0,0 +1,299 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Some of these tests involve sending packets via AF_PACKET sockets and the +// loopback interface. Because AF_PACKET circumvents so much of the networking +// stack, Linux sees these packets as "martian", i.e. they claim to be to/from +// localhost but don't have the usual associated data. Thus Linux drops them by +// default. You can see where this happens by following the code at: +// +// - net/ipv4/ip_input.c:ip_rcv_finish, which calls +// - net/ipv4/route.c:ip_route_input_noref, which calls +// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian +// packets. +// +// To tell Linux not to drop these packets, you need to tell it to accept our +// funny packets (which are completely valid and correct, but lack associated +// in-kernel data because we use AF_PACKET): +// +// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local +// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet +// +// These tests require CAP_NET_RAW to run. + +// TODO(gvisor.dev/issue/173): gVisor support. + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kMessage[] = "soweoneul malhaebwa"; +constexpr in_port_t kPort = 0x409c; // htons(40000) + +// +// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer +// headers, and provide link layer destination/source information via a +// returned struct sockaddr_ll. +// + +// Send kMessage via sock to loopback +void SendUDPMessage(int sock) { + struct sockaddr_in dest = {}; + dest.sin_port = kPort; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); +} + +// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up. +TEST(BasicCookedPacketTest, WrongType) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP)); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait and make sure the socket never becomes readable. + struct pollfd pfd = {}; + pfd.fd = sock.get(); + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets. +class CookedPacketTest : public ::testing::TestWithParam<int> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Gets the device index of the loopback device. + int GetLoopbackIndex(); + + // The socket used for both reading and writing. + int socket_; +}; + +void CookedPacketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), + SyscallSucceeds()); +} + +void CookedPacketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +int CookedPacketTest::GetLoopbackIndex() { + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + return ifr.ifr_ifindex; +} + +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait for the socket to become readable. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); + + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + ASSERT_EQ(src_len, sizeof(src)); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); + EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + // This came from the loopback device, so the address is all 0s. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], 0); + } + + // Verify the IP header. We memcpy to deal with pointer aligment. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. We memcpy to deal with pointer aligment. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Send via a packet socket. +TEST_P(CookedPacketTest, Send) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's send a UDP packet and receive it using a regular UDP socket. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + bind_addr.sin_port = kPort; + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Set up the destination physical address. + struct sockaddr_ll dest = {}; + dest.sll_family = AF_PACKET; + dest.sll_halen = ETH_ALEN; + dest.sll_ifindex = GetLoopbackIndex(); + dest.sll_protocol = htons(ETH_P_IP); + // We're sending to the loopback device, so the address is all 0s. + memset(dest.sll_addr, 0x00, ETH_ALEN); + + // Set up the IP header. + struct iphdr iphdr = {0}; + iphdr.ihl = 5; + iphdr.version = 4; + iphdr.tos = 0; + iphdr.tot_len = + htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); + // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, + // but we have no way of getting an ID we know to be good. + srand(*reinterpret_cast<unsigned int*>(&iphdr)); + iphdr.id = rand(); + // Linux sets this bit ("do not fragment") for small packets. + iphdr.frag_off = 1 << 6; + iphdr.ttl = 64; + iphdr.protocol = IPPROTO_UDP; + iphdr.daddr = htonl(INADDR_LOOPBACK); + iphdr.saddr = htonl(INADDR_LOOPBACK); + iphdr.check = IPChecksum(iphdr); + + // Set up the UDP header. + struct udphdr udphdr = {}; + udphdr.source = kPort; + udphdr.dest = kPort; + udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); + udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); + + // Copy both headers and the payload into our packet buffer. + char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; + memcpy(send_buf, &iphdr, sizeof(iphdr)); + memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr)); + memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage)); + + // Send it. + ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Wait for the packet to become available on both sockets. + struct pollfd pfd = {}; + pfd.fd = udp_sock.get(); + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + + // Receive on the packet socket. + char recv_buf[sizeof(send_buf)]; + ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); + + // Receive on the UDP socket. + struct sockaddr_in src; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(sizeof(kMessage))); + // Check src and payload. + EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(src.sin_port, kPort); + EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, + ::testing::Values(ETH_P_IP, ETH_P_ALL)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc new file mode 100644 index 000000000..9e96460ee --- /dev/null +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Some of these tests involve sending packets via AF_PACKET sockets and the +// loopback interface. Because AF_PACKET circumvents so much of the networking +// stack, Linux sees these packets as "martian", i.e. they claim to be to/from +// localhost but don't have the usual associated data. Thus Linux drops them by +// default. You can see where this happens by following the code at: +// +// - net/ipv4/ip_input.c:ip_rcv_finish, which calls +// - net/ipv4/route.c:ip_route_input_noref, which calls +// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian +// packets. +// +// To tell Linux not to drop these packets, you need to tell it to accept our +// funny packets (which are completely valid and correct, but lack associated +// in-kernel data because we use AF_PACKET): +// +// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local +// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet +// +// These tests require CAP_NET_RAW to run. + +// TODO(gvisor.dev/issue/173): gVisor support. + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kMessage[] = "soweoneul malhaebwa"; +constexpr in_port_t kPort = 0x409c; // htons(40000) + +// Send kMessage via sock to loopback +void SendUDPMessage(int sock) { + struct sockaddr_in dest = {}; + dest.sin_port = kPort; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); +} + +// +// Raw tests. Packets sent with raw AF_PACKET sockets always include link layer +// headers. +// + +// Tests for "raw" (SOCK_RAW) packet(7) sockets. +class RawPacketTest : public ::testing::TestWithParam<int> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Gets the device index of the loopback device. + int GetLoopbackIndex(); + + // The socket used for both reading and writing. + int socket_; +}; + +void RawPacketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + if (!IsRunningOnGvisor()) { + FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + char enabled; + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + SyscallSucceeds()); +} + +void RawPacketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +int RawPacketTest::GetLoopbackIndex() { + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + return ifr.ifr_ifindex; +} + +// Receive via a packet socket. +TEST_P(RawPacketTest, Receive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait for the socket to become readable. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); + + // Read and verify the data. + constexpr size_t packet_size = sizeof(struct ethhdr) + sizeof(struct iphdr) + + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + // sizeof(src) is the size of a struct sockaddr_ll. sockaddr_ll ends with an 8 + // byte physical address field, but ethernet (MAC) addresses only use 6 bytes. + // Thus src_len should get modified to be 2 less than the size of sockaddr_ll. + ASSERT_EQ(src_len, sizeof(src) - 2); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); + EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + // This came from the loopback device, so the address is all 0s. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], 0); + } + + // Verify the ethernet header. We memcpy to deal with pointer alignment. + struct ethhdr eth = {}; + memcpy(ð, buf, sizeof(eth)); + // The destination and source address should be 0, for loopback. + for (int i = 0; i < ETH_ALEN; i++) { + EXPECT_EQ(eth.h_dest[i], 0); + EXPECT_EQ(eth.h_source[i], 0); + } + EXPECT_EQ(eth.h_proto, htons(ETH_P_IP)); + + // Verify the IP header. We memcpy to deal with pointer aligment. + struct iphdr ip = {}; + memcpy(&ip, buf + sizeof(ethhdr), sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size - sizeof(eth))); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. We memcpy to deal with pointer aligment. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(eth) + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(eth) + sizeof(iphdr) + + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Send via a packet socket. +TEST_P(RawPacketTest, Send) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's send a UDP packet and receive it using a regular UDP socket. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + bind_addr.sin_port = kPort; + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Set up the destination physical address. + struct sockaddr_ll dest = {}; + dest.sll_family = AF_PACKET; + dest.sll_halen = ETH_ALEN; + dest.sll_ifindex = GetLoopbackIndex(); + dest.sll_protocol = htons(ETH_P_IP); + // We're sending to the loopback device, so the address is all 0s. + memset(dest.sll_addr, 0x00, ETH_ALEN); + + // Set up the ethernet header. The kernel takes care of the footer. + // We're sending to and from hardware address 0 (loopback). + struct ethhdr eth = {}; + eth.h_proto = htons(ETH_P_IP); + + // Set up the IP header. + struct iphdr iphdr = {}; + iphdr.ihl = 5; + iphdr.version = 4; + iphdr.tos = 0; + iphdr.tot_len = + htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); + // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, + // but we have no way of getting an ID we know to be good. + srand(*reinterpret_cast<unsigned int*>(&iphdr)); + iphdr.id = rand(); + // Linux sets this bit ("do not fragment") for small packets. + iphdr.frag_off = 1 << 6; + iphdr.ttl = 64; + iphdr.protocol = IPPROTO_UDP; + iphdr.daddr = htonl(INADDR_LOOPBACK); + iphdr.saddr = htonl(INADDR_LOOPBACK); + iphdr.check = IPChecksum(iphdr); + + // Set up the UDP header. + struct udphdr udphdr = {}; + udphdr.source = kPort; + udphdr.dest = kPort; + udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); + udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); + + // Copy both headers and the payload into our packet buffer. + char + send_buf[sizeof(eth) + sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; + memcpy(send_buf, ð, sizeof(eth)); + memcpy(send_buf + sizeof(ethhdr), &iphdr, sizeof(iphdr)); + memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr), &udphdr, sizeof(udphdr)); + memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage, + sizeof(kMessage)); + + // Send it. + ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Wait for the packet to become available on both sockets. + struct pollfd pfd = {}; + pfd.fd = udp_sock.get(); + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + + // Receive on the packet socket. + char recv_buf[sizeof(send_buf)]; + ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); + + // Receive on the UDP socket. + struct sockaddr_in src; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(sizeof(kMessage))); + // Check src and payload. + EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(src.sin_port, kPort); + EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, + ::testing::Values(ETH_P_IP /*, ETH_P_ALL*/)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc index 83b1ad4e4..33822ee57 100644 --- a/test/syscalls/linux/partial_bad_buffer.cc +++ b/test/syscalls/linux/partial_bad_buffer.cc @@ -14,13 +14,20 @@ #include <errno.h> #include <fcntl.h> +#include <netinet/in.h> +#include <netinet/tcp.h> #include <sys/mman.h> +#include <sys/socket.h> #include <sys/syscall.h> #include <sys/uio.h> #include <unistd.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -299,6 +306,109 @@ TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) { EXPECT_STREQ(buf, kMessage); } +PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// SendMsgTCP verifies that calling sendmsg with a bad address returns an +// EFAULT. It also verifies that passing a buffer which is made up of 2 +// pages one valid and one guard page succeeds as long as the write is +// for exactly the size of 1 page. +TEST_F(PartialBadBufferTest, SendMsgTCP) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // TODO(gvisor.dev/issue/674): Update this once Netstack matches linux + // behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF. + // + // Set SO_SNDBUF for socket to exactly kPageSize+1. + // + // gVisor does not double the value passed in SO_SNDBUF like linux does so we + // just increase it by 1 byte here for gVisor so that we can test writing 1 + // byte past the valid page and check that it triggers an EFAULT + // correctly. Otherwise in gVisor the sendmsg call will just return with no + // error with kPageSize bytes written successfully. + const uint32_t buf_size = kPageSize + 1; + ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceedsWithValue(0)); + + struct msghdr hdr = {}; + struct iovec iov = {}; + iov.iov_base = bad_buffer_; + iov.iov_len = kPageSize; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallFailsWithErrno(EFAULT)); + + // Now assert that writing kPageSize from addr_ succeeds. + iov.iov_base = addr_; + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallSucceedsWithValue(kPageSize)); + // Read all the data out so that we drain the socket SND_BUF on the sender. + std::vector<char> buffer(kPageSize); + ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize), + SyscallSucceedsWithValue(kPageSize)); + + // Sleep for a shortwhile to ensure that we have time to process the + // ACKs. This is not strictly required unless running under gotsan which is a + // lot slower and can result in the next write to write only 1 byte instead of + // our intended kPageSize + 1. + absl::SleepFor(absl::Milliseconds(50)); + + // Now assert that writing > kPageSize results in EFAULT. + iov.iov_len = kPageSize + 1; + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallFailsWithErrno(EFAULT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index b440ba0df..2b753b7d1 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1602,9 +1602,9 @@ class BlockingChild { } mutable absl::Mutex mu_; - bool stop_ GUARDED_BY(mu_) = false; + bool stop_ ABSL_GUARDED_BY(mu_) = false; pid_t tid_; - bool tid_ready_ GUARDED_BY(mu_) = false; + bool tid_ready_ ABSL_GUARDED_BY(mu_) = false; // Must be last to ensure that the destructor for the thread is run before // any other member of the object is destroyed. diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index 578b20680..498f62d9c 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -187,9 +187,9 @@ TEST(ProcNetTCP, EntryUID) { std::vector<TCPEntry> entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); TCPEntry e; - EXPECT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())); + ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())); EXPECT_EQ(e.uid, geteuid()); - EXPECT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr())); + ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr())); EXPECT_EQ(e.uid, geteuid()); } @@ -249,7 +249,8 @@ TEST(ProcNetTCP, State) { std::unique_ptr<FileDescriptor> client = ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create()); - ASSERT_THAT(connect(client->get(), &addr, addrlen), SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(client->get(), &addr, addrlen), + SyscallSucceeds()); entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr)); EXPECT_EQ(listen_entry.state, TCP_LISTEN); diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index d1ab4703f..bd6907876 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -13,13 +13,17 @@ // limitations under the License. #include <fcntl.h> +#include <linux/capability.h> #include <linux/major.h> #include <poll.h> +#include <sched.h> +#include <signal.h> #include <sys/ioctl.h> #include <sys/mman.h> #include <sys/stat.h> #include <sys/sysmacros.h> #include <sys/types.h> +#include <sys/wait.h> #include <termios.h> #include <unistd.h> @@ -31,8 +35,10 @@ #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" +#include "test/util/pty_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -370,25 +376,6 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, return PosixError(ETIMEDOUT, "Poll timed out"); } -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - // Get pty index. - int n; - int ret = ioctl(master.get(), TIOCGPTN, &n); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOCGPTN) failed"); - } - - // Unlock pts. - int unlock = 0; - ret = ioctl(master.get(), TIOCSPTLCK, &unlock); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOSPTLCK) failed"); - } - - return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); -} - TEST(BasicPtyTest, StatUnopenedMaster) { struct stat s; ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds()); @@ -1233,6 +1220,340 @@ TEST_F(PtyTest, SetMasterWindowSize) { EXPECT_EQ(retrieved_ws.ws_col, kCols); } +class JobControlTest : public ::testing::Test { + protected: + void SetUp() override { + master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). + if (!IsRunningOnGvisor()) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + } + } + + // Master and slave ends of the PTY. Non-blocking. + FileDescriptor master_; + FileDescriptor slave_; +}; + +TEST_F(JobControlTest, SetTTYMaster) { + ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTYNonLeader) { + // Fork a process that won't be the session leader. + pid_t child = fork(); + if (!child) { + // We shouldn't be able to set the terminal. + TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, SetTTYBadArg) { + // Despite the man page saying arg should be 0 here, Linux doesn't actually + // check. + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTYDifferentSession) { + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, ReleaseTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect they TTY. + struct sigaction sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + struct sigaction old_sa; + EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); + EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); +} + +TEST_F(JobControlTest, ReleaseUnsetTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, ReleaseWrongTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, ReleaseTTYNonLeader) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, ReleaseTTYDifferentSession) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t child = fork(); + if (!child) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +// Used by the child process spawned in ReleaseTTYSignals to track received +// signals. +static int received; + +void sig_handler(int signum) { received |= signum; } + +// When the session leader releases its controlling terminal, the foreground +// process group gets SIGHUP, then SIGCONT. This test: +// - Spawns 2 threads +// - Has thread 1 return 0 if it gets both SIGHUP and SIGCONT +// - Has thread 2 leave the foreground process group, and return non-zero if it +// receives any signals. +// - Has the parent thread release its controlling terminal +// - Checks that thread 1 got both signals +// - Checks that thread 2 didn't get any signals. +TEST_F(JobControlTest, ReleaseTTYSignals) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + received = 0; + struct sigaction sa = { + .sa_handler = sig_handler, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + sigaddset(&sa.sa_mask, SIGHUP); + sigaddset(&sa.sa_mask, SIGCONT); + sigprocmask(SIG_BLOCK, &sa.sa_mask, NULL); + + pid_t same_pgrp_child = fork(); + if (!same_pgrp_child) { + // The child will wait for SIGHUP and SIGCONT, then return 0. It begins with + // SIGHUP and SIGCONT blocked. We install signal handlers for those signals, + // then use sigsuspend to wait for those specific signals. + TEST_PCHECK(!sigaction(SIGHUP, &sa, NULL)); + TEST_PCHECK(!sigaction(SIGCONT, &sa, NULL)); + sigset_t mask; + sigfillset(&mask); + sigdelset(&mask, SIGHUP); + sigdelset(&mask, SIGCONT); + while (received != (SIGHUP | SIGCONT)) { + sigsuspend(&mask); + } + _exit(0); + } + + // We don't want to block these anymore. + sigprocmask(SIG_UNBLOCK, &sa.sa_mask, NULL); + + // This child will return non-zero if either SIGHUP or SIGCONT are received. + pid_t diff_pgrp_child = fork(); + if (!diff_pgrp_child) { + TEST_PCHECK(!setpgid(0, 0)); + TEST_PCHECK(pause()); + _exit(1); + } + + EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds()); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect they TTY. + struct sigaction sighup_sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sighup_sa.sa_mask); + struct sigaction old_sa; + EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds()); + + // Release the controlling terminal, sending SIGHUP and SIGCONT to all other + // processes in this process group. + EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + + EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); + + // The child in the same process group will get signaled. + int wstatus; + EXPECT_THAT(waitpid(same_pgrp_child, &wstatus, 0), + SyscallSucceedsWithValue(same_pgrp_child)); + EXPECT_EQ(wstatus, 0); + + // The other child will not get signaled. + EXPECT_THAT(waitpid(diff_pgrp_child, &wstatus, WNOHANG), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(kill(diff_pgrp_child, SIGKILL), SyscallSucceeds()); +} + +TEST_F(JobControlTest, GetForegroundProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + pid_t foreground_pgid; + pid_t pid; + ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + SyscallSucceeds()); + ASSERT_THAT(pid = getpid(), SyscallSucceeds()); + + ASSERT_EQ(foreground_pgid, pid); +} + +TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { + // At this point there's no controlling terminal, so TIOCGPGRP should fail. + pid_t foreground_pgid; + ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + SyscallFailsWithErrno(ENOTTY)); +} + +// This test: +// - sets itself as the foreground process group +// - creates a child process in a new process group +// - sets that child as the foreground process group +// - kills its child and sets itself as the foreground process group. +TEST_F(JobControlTest, SetForegroundProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); + + // Create a new process that just waits to be signaled. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } + + // Make the child its own process group, then make it the controlling process + // group of the terminal. + ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); + ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + + // Sanity check - we're still the controlling session. + ASSERT_EQ(getsid(0), getsid(child)); + + // Signal the child, wait for it to exit, then retake the terminal. + ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_TRUE(WIFSIGNALED(wstatus)); + ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); + ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { + pid_t pid = getpid(); + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t pid = -1; + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } + + // Wait for the child to exit. + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + // The child's process group doesn't exist anymore - this should fail. + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), + SyscallFailsWithErrno(ESRCH)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Create a new process and put it in a new session. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + TEST_PCHECK(!raise(SIGSTOP)); + TEST_PCHECK(!pause()); + _exit(1); + } + + // Wait for the child to tell us it's in a new session. + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED), + SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WSTOPSIG(wstatus)); + + // Child is in a new session, so we can't make it the foregroup process group. + EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), + SyscallFailsWithErrno(EPERM)); + + EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds()); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc new file mode 100644 index 000000000..d2a321a6e --- /dev/null +++ b/test/syscalls/linux/pty_root.cc @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/ioctl.h> +#include <termios.h> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/pty_util.h" + +namespace gvisor { +namespace testing { + +// These tests should be run as root. +namespace { + +TEST(JobControlRootTest, StealTTY) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). + if (!IsRunningOnGvisor()) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + } + + FileDescriptor master = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + + // Make slave the controlling terminal. + ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg + // of 1. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal with the wrong arg value. + TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + // We should be able to steal it here. + TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index db519f4e0..f6a0fc96c 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -244,8 +244,10 @@ TEST(Pwritev2Test, TestInvalidOffset) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/static_cast<off_t>(-8), /*flags=*/0), @@ -286,8 +288,10 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); int pipe_fds[2]; + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); @@ -307,8 +311,10 @@ TEST(Pwritev2Test, TestReadOnlyFile) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0), @@ -324,8 +330,10 @@ TEST(Pwritev2Test, TestInvalidFlag) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR | O_DIRECT)); + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0xF0), diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index ad19120d5..971592d7d 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -35,32 +35,6 @@ namespace testing { namespace { -// Compute the internet checksum of the ICMP header (assuming no payload). -static uint16_t Checksum(struct icmphdr* icmp) { - uint32_t total = 0; - uint16_t* num = reinterpret_cast<uint16_t*>(icmp); - - // This is just the ICMP header, so there's an even number of bytes. - static_assert( - sizeof(*icmp) % sizeof(*num) == 0, - "sizeof(struct icmphdr) is not an integer multiple of sizeof(uint16_t)"); - for (unsigned int i = 0; i < sizeof(*icmp); i += sizeof(*num)) { - total += *num; - num++; - } - - // Combine the upper and lower 16 bits. This happens twice in case the first - // combination causes a carry. - unsigned short upper = total >> 16; - unsigned short lower = total & 0xffff; - total = upper + lower; - upper = total >> 16; - lower = total & 0xffff; - total = upper + lower; - - return ~total; -} - // The size of an empty ICMP packet and IP header together. constexpr size_t kEmptyICMPSize = 28; @@ -164,7 +138,7 @@ TEST_F(RawSocketICMPTest, SendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2012; icmp.un.echo.id = 2014; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -187,7 +161,7 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2016; icmp.un.echo.id = 2018; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); // Both sockets will receive the echo request and reply in indeterminate @@ -297,7 +271,7 @@ TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) { icmp.un.echo.sequence = 0; icmp.un.echo.id = 6789; icmp.checksum = 0; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); // Omit 2 bytes from ICMP packet. constexpr int kShortICMPSize = sizeof(icmp) - 2; @@ -338,7 +312,7 @@ TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) { icmp.un.echo.sequence = 0; icmp.un.echo.id = 6789; icmp.checksum = 0; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); // Omit 2 bytes from ICMP packet. constexpr int kShortICMPSize = sizeof(icmp) - 2; @@ -381,7 +355,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveViaConnect) { icmp.checksum = 0; icmp.un.echo.sequence = 2003; icmp.un.echo.id = 2004; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0), SyscallSucceedsWithValue(sizeof(icmp))); @@ -405,7 +379,7 @@ TEST_F(RawSocketICMPTest, BindSendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2004; icmp.un.echo.id = 2007; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -431,7 +405,7 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2010; icmp.un.echo.id = 7; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -470,9 +444,8 @@ void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id); // A couple are different. EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY); - // The checksum is computed in such a way that it is guaranteed to have - // changed. - EXPECT_NE(recvd_icmp->checksum, icmp.checksum); + // The checksum computed over the reply should still be valid. + EXPECT_EQ(ICMPChecksum(*recvd_icmp, NULL, 0), 0); break; } } diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index e5d72e28a..9167ab066 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -299,10 +299,30 @@ TEST(SendFileTest, DoNotSendfileIfOutfileIsAppendOnly) { // Open the output file as append only. const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_APPEND)); + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY | O_APPEND)); // Send data and verify that sendfile returns the correct errno. EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(SendFileTest, AppendCheckOrdering) { + constexpr char kData[] = "And by opposing end them: to die, to sleep"; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + const FileDescriptor read = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + const FileDescriptor write = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + const FileDescriptor append = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_APPEND)); + + // Check that read/write file mode is verified before append. + EXPECT_THAT(sendfile(append.get(), read.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(sendfile(write.get(), write.get(), nullptr, kDataSize), SyscallFailsWithErrno(EBADF)); } diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index 0404190a0..caae215b8 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -30,12 +30,25 @@ TEST(SocketTest, UnixSocketPairProtocol) { close(socks[1]); } -TEST(SocketTest, Protocol) { +TEST(SocketTest, ProtocolUnix) { struct { int domain, type, protocol; } tests[] = { - {AF_UNIX, SOCK_STREAM, PF_UNIX}, {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, - {AF_UNIX, SOCK_DGRAM, PF_UNIX}, {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, + {AF_UNIX, SOCK_STREAM, PF_UNIX}, + {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, + {AF_UNIX, SOCK_DGRAM, PF_UNIX}, + }; + for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { + ASSERT_NO_ERRNO_AND_VALUE( + Socket(tests[i].domain, tests[i].type, tests[i].protocol)); + } +} + +TEST(SocketTest, ProtocolInet) { + struct { + int domain, type, protocol; + } tests[] = { + {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, {AF_INET, SOCK_STREAM, IPPROTO_TCP}, }; for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index df31d25b5..322ee07ad 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -145,6 +145,67 @@ TEST_P(SocketInetLoopbackTest, TCP) { ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); } +TEST_P(SocketInetLoopbackTest, TCPListenClose) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + DisableSave ds; // Too many system calls. + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + constexpr int kFDs = 2048; + constexpr int kThreadCount = 4; + constexpr int kFDsPerThread = kFDs / kThreadCount; + FileDescriptor clients[kFDs]; + std::unique_ptr<ScopedThread> threads[kThreadCount]; + for (int i = 0; i < kFDs; i++) { + clients[i] = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + } + for (int i = 0; i < kThreadCount; i++) { + threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, + &clients, i]() { + for (int j = 0; j < kFDsPerThread; j++) { + int k = i * kFDsPerThread + j; + int ret = + connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + } + }); + } + for (int i = 0; i < kThreadCount; i++) { + threads[i]->Join(); + } + for (int i = 0; i < 32; i++) { + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + } + // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked + // before function end. + // ds.reset() +} + TEST_P(SocketInetLoopbackTest, TCPbacklog) { auto const& param = GetParam(); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index d9aa7ff3f..67d29af0a 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -30,6 +30,7 @@ namespace gvisor { namespace testing { constexpr char kMulticastAddress[] = "224.0.2.1"; +constexpr char kBroadcastAddress[] = "255.255.255.255"; TestAddress V4Multicast() { TestAddress t("V4Multicast"); @@ -40,6 +41,15 @@ TestAddress V4Multicast() { return t; } +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kBroadcastAddress); + return t; +} + // Check that packets are not received without a group membership. Default send // interface configured by bind. TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { @@ -1426,5 +1436,249 @@ TEST_P(IPv4UDPUnboundSocketPairTest, } } +// Check that a receiving socket can bind to the multicast address before +// joining the group and receive data once the group has been joined. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the multicast address and won't +// receive multicast data if it hasn't joined the group. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we don't receive the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Check that a socket can bind to a multicast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the multicast address. + auto sender_addr = V4Multicast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the broadcast address and receive +// broadcast packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the broadcast address. + auto receiver_addr = V4Broadcast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a broadcast packet on the first socket out the loopback interface. + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + // Note: Binding to the loopback interface makes the broadcast go out of it. + auto sender_bind_addr = V4Loopback(); + ASSERT_THAT(bind(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); + auto sendto_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a socket can bind to the broadcast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the broadcast address. + auto sender_addr = V4Broadcast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 6a5fa8965..765f8e0e4 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -89,7 +89,8 @@ TEST(NetdeviceTest, Netmask) { // (i.e. netmask) for the loopback device. int prefixlen = -1; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr *hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr *hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -111,7 +112,8 @@ TEST(NetdeviceTest, Netmask) { ifaddrmsg->ifa_family == AF_INET) { prefixlen = ifaddrmsg->ifa_prefixlen; } - })); + }, + false)); ASSERT_GE(prefixlen, 0); diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index b5c38f27e..32fe0d6d1 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> #include <ifaddrs.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> @@ -237,7 +238,8 @@ TEST(NetlinkRouteTest, GetLinkDump) { // Loopback is common among all tests, check that it's found. bool loopbackFound = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { CheckGetLinkResponse(hdr, kSeq, port); if (hdr->nlmsg_type != RTM_NEWLINK) { return; @@ -251,10 +253,44 @@ TEST(NetlinkRouteTest, GetLinkDump) { loopbackFound = true; EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); } - })); + }, + false)); EXPECT_TRUE(loopbackFound); } +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + // If type & 0x3 is equal to 0x2, this means a get request + // which doesn't require CAP_SYS_ADMIN. + req.hdr.nlmsg_type = ((__RTM_MAX + 1024) & (~0x3)) | 0x2; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->error, -EOPNOTSUPP); + }, + true)); +} + TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); @@ -363,9 +399,11 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { req.ifm.ifi_family = AF_UNSPEC; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { CheckGetLinkResponse(hdr, kSeq, port); - })); + }, + false)); } TEST(NetlinkRouteTest, GetAddrDump) { @@ -387,7 +425,8 @@ TEST(NetlinkRouteTest, GetAddrDump) { req.rgm.rtgen_family = AF_UNSPEC; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -404,7 +443,8 @@ TEST(NetlinkRouteTest, GetAddrDump) { EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg)); // TODO(mpratt): Check ifaddrmsg contents and following attrs. - })); + }, + false)); } TEST(NetlinkRouteTest, LookupAll) { @@ -425,6 +465,80 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } +// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. +TEST(NetlinkRouteTest, GetRouteDump) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + + struct request { + struct nlmsghdr hdr; + struct rtmsg rtm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETROUTE; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rtm.rtm_family = AF_UNSPEC; + + bool routeFound = false; + bool dstFound = true; + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + // Validate the reponse to RTM_GETROUTE + NLM_F_DUMP. + EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWROUTE), Eq(NLMSG_DONE))); + + EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) + << std::hex << hdr->nlmsg_flags; + + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_EQ(hdr->nlmsg_pid, port); + + // The test should not proceed if it's not a RTM_NEWROUTE message. + if (hdr->nlmsg_type != RTM_NEWROUTE) { + return; + } + + // RTM_NEWROUTE contains at least the header and rtmsg. + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg))); + const struct rtmsg* msg = + reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr)); + // NOTE: rtmsg fields are char fields. + std::cout << "Found route table=" << static_cast<int>(msg->rtm_table) + << ", protocol=" << static_cast<int>(msg->rtm_protocol) + << ", scope=" << static_cast<int>(msg->rtm_scope) + << ", type=" << static_cast<int>(msg->rtm_type); + + int len = RTM_PAYLOAD(hdr); + bool rtDstFound = false; + for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len); + attr = RTA_NEXT(attr, len)) { + if (attr->rta_type == RTA_DST) { + char address[INET_ADDRSTRLEN] = {}; + inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address)); + std::cout << ", dst=" << address; + rtDstFound = true; + } + } + + std::cout << std::endl; + + if (msg->rtm_table == RT_TABLE_MAIN) { + routeFound = true; + dstFound = rtDstFound && dstFound; + } + }, + false)); + // At least one route found in main route table. + EXPECT_TRUE(routeFound); + // Found RTA_DST for each route in main table. + EXPECT_TRUE(dstFound); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index 728d25434..36b6560c2 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -54,7 +54,8 @@ PosixErrorOr<uint32_t> NetlinkPortID(int fd) { PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { + const std::function<void(const struct nlmsghdr* hdr)>& fn, + bool expect_nlmsgerr) { struct iovec iov = {}; iov.iov_base = request; iov.iov_len = len; @@ -93,7 +94,11 @@ PosixError NetlinkRequestResponse( } } while (type != NLMSG_DONE && type != NLMSG_ERROR); - EXPECT_EQ(type, NLMSG_DONE); + if (expect_nlmsgerr) { + EXPECT_EQ(type, NLMSG_ERROR); + } else { + EXPECT_EQ(type, NLMSG_DONE); + } return NoError(); } diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index bea449107..db8639a2f 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -34,7 +34,8 @@ PosixErrorOr<uint32_t> NetlinkPortID(int fd); // Send the passed request and call fn will all response netlink messages. PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn); + const std::function<void(const struct nlmsghdr* hdr)>& fn, + bool expect_nlmsgerr); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 4f65cf5ae..eff7d577e 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -588,8 +588,9 @@ ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets, return RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0); } -PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, - bool reuse_addr) { +namespace internal { +PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, + SocketType type, bool reuse_addr) { if (port < 0) { return PosixError(EINVAL, "Invalid port"); } @@ -664,10 +665,7 @@ PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, return available_port; } - -PosixError FreeAvailablePort(int port) { - return NoError(); -} +} // namespace internal PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { struct iovec iov; @@ -744,5 +742,74 @@ TestAddress V6Loopback() { return t; } +// Checksum computes the internet checksum of a buffer. +uint16_t Checksum(uint16_t* buf, ssize_t buf_size) { + // Add up the 16-bit values in the buffer. + uint32_t total = 0; + for (unsigned int i = 0; i < buf_size; i += sizeof(*buf)) { + total += *buf; + buf++; + } + + // If buf has an odd size, add the remaining byte. + if (buf_size % 2) { + total += *(reinterpret_cast<unsigned char*>(buf) - 1); + } + + // This carries any bits past the lower 16 until everything fits in 16 bits. + while (total >> 16) { + uint16_t lower = total & 0xffff; + uint16_t upper = total >> 16; + total = lower + upper; + } + + return ~total; +} + +uint16_t IPChecksum(struct iphdr ip) { + return Checksum(reinterpret_cast<uint16_t*>(&ip), sizeof(ip)); +} + +// The pseudo-header defined in RFC 768 for calculating the UDP checksum. +struct udp_pseudo_hdr { + uint32_t srcip; + uint32_t destip; + char zero; + char protocol; + uint16_t udplen; +}; + +uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, + const char* payload, ssize_t payload_len) { + struct udp_pseudo_hdr phdr = {}; + phdr.srcip = iphdr.saddr; + phdr.destip = iphdr.daddr; + phdr.zero = 0; + phdr.protocol = IPPROTO_UDP; + phdr.udplen = udphdr.len; + + ssize_t buf_size = sizeof(phdr) + sizeof(udphdr) + payload_len; + char* buf = static_cast<char*>(malloc(buf_size)); + memcpy(buf, &phdr, sizeof(phdr)); + memcpy(buf + sizeof(phdr), &udphdr, sizeof(udphdr)); + memcpy(buf + sizeof(phdr) + sizeof(udphdr), payload, payload_len); + + uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); + free(buf); + return csum; +} + +uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, + ssize_t payload_len) { + ssize_t buf_size = sizeof(icmphdr) + payload_len; + char* buf = static_cast<char*>(malloc(buf_size)); + memcpy(buf, &icmphdr, sizeof(icmphdr)); + memcpy(buf + sizeof(icmphdr), payload, payload_len); + + uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); + free(buf); + return csum; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 4fd59767a..6efa8055f 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -17,9 +17,12 @@ #include <errno.h> #include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <netinet/udp.h> #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <functional> #include <memory> #include <string> @@ -478,6 +481,22 @@ TestAddress V4MappedLoopback(); TestAddress V6Any(); TestAddress V6Loopback(); +// Compute the internet checksum of an IP header. +uint16_t IPChecksum(struct iphdr ip); + +// Compute the internet checksum of a UDP header. +uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, + const char* payload, ssize_t payload_len); + +// Compute the internet checksum of an ICMP header. +uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, + ssize_t payload_len); + +namespace internal { +PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, + SocketType type, bool reuse_addr); +} // namespace internal + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util_impl.cc b/test/syscalls/linux/socket_test_util_impl.cc new file mode 100644 index 000000000..ef661a0e3 --- /dev/null +++ b/test/syscalls/linux/socket_test_util_impl.cc @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, + bool reuse_addr) { + return internal::TryPortAvailable(port, family, type, reuse_addr); +} + +PosixError FreeAvailablePort(int port) { return NoError(); } + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index 1875f4533..e25f264f6 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -135,6 +136,80 @@ TEST(SpliceTest, PipeOffsets) { SyscallFailsWithErrno(ESPIPE)); } +// Event FDs may be used with splice without an offset. +TEST(SpliceTest, FromEventFD) { + // Open the input eventfd with an initial value so that it is readable. + constexpr uint64_t kEventFDValue = 1; + int efd; + ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); + const FileDescriptor inf(efd); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Splice 8-byte eventfd value to pipe. + constexpr int kEventFDSize = 8; + EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), + SyscallSucceedsWithValue(kEventFDSize)); + + // Contents should be equal. + std::vector<char> rbuf(kEventFDSize); + ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kEventFDSize)); + EXPECT_EQ(memcmp(rbuf.data(), &kEventFDValue, rbuf.size()), 0); +} + +// Event FDs may not be used with splice with an offset. +TEST(SpliceTest, FromEventFDOffset) { + int efd; + ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); + const FileDescriptor inf(efd); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Attempt to splice 8-byte eventfd value to pipe with offset. + // + // This is not allowed because eventfd doesn't support pread. + constexpr int kEventFDSize = 8; + loff_t in_off = 0; + EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); +} + +// Event FDs may not be used with splice with an offset. +TEST(SpliceTest, ToEventFDOffset) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill with a value. + constexpr int kEventFDSize = 8; + std::vector<char> buf(kEventFDSize); + buf[0] = 1; + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kEventFDSize)); + + int efd; + ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); + const FileDescriptor outf(efd); + + // Attempt to splice 8-byte eventfd value to pipe with offset. + // + // This is not allowed because eventfd doesn't support pwrite. + loff_t out_off = 0; + EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SpliceTest, ToPipe) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 1bb0307c4..111dbacdf 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -39,7 +39,7 @@ constexpr int TestPort = 40000; // Fixture for tests parameterized by the address family to use (AF_INET and // AF_INET6) when creating sockets. -class UdpSocketTest : public ::testing::TestWithParam<int> { +class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> { protected: // Creates two sockets that will be used by test cases. void SetUp() override; @@ -97,31 +97,32 @@ uint16_t* Port(struct sockaddr_storage* addr) { } void UdpSocketTest::SetUp() { - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + int type; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); + addrlen_ = sizeof(*sin); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + type = AF_INET6; + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); + addrlen_ = sizeof(*sin6); + if (GetParam() == AddressFamily::kIpv6) { + sin6->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + sin6->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - ASSERT_THAT(t_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); - anyaddr_->sa_family = GetParam(); - - // Initialize address-family-specific values. - switch (GetParam()) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - sin6->sin6_addr = in6addr_any; - break; - } - } + anyaddr_->sa_family = type; if (gvisor::testing::IsRunningOnGvisor()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { @@ -154,9 +155,9 @@ void UdpSocketTest::SetUp() { memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); - addr_[i]->sa_family = GetParam(); + addr_[i]->sa_family = type; - switch (GetParam()) { + switch (type) { case AF_INET: { auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); @@ -174,17 +175,20 @@ void UdpSocketTest::SetUp() { } TEST_P(UdpSocketTest, Creation) { + int type = AF_INET6; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + } + int s_; - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); EXPECT_THAT(close(s_), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, 0), SyscallSucceeds()); + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); EXPECT_THAT(close(s_), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_UDP), - SyscallFails()); + ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); } TEST_P(UdpSocketTest, Getsockname) { @@ -374,6 +378,178 @@ TEST_P(UdpSocketTest, Connect) { EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); } +void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { + struct sockaddr_storage addr = {}; + + // Precondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr any = IN6ADDR_ANY_INIT; + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); + } + + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } + + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + } + + // Postcondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback; + if (family == AddressFamily::kIpv6) { + loopback = IN6ADDR_LOOPBACK_INIT; + } else { + TestAddress const& v4_mapped_loopback = V4MappedLoopback(); + loopback = reinterpret_cast<const struct sockaddr_in6*>( + &v4_mapped_loopback.addr) + ->sin6_addr; + } + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } + + addrlen = sizeof(addr); + if (port == 0) { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } else { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + } + } +} + +TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + ConnectAny(GetParam(), s_, port); +} + +void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { + struct sockaddr_storage addr = {}; + + socklen_t addrlen = sizeof(addr); + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + // Now the socket is bound to the loopback address. + + // Disconnect + addrlen = sizeof(addr); + addr.ss_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + DisconnectAfterConnectAny(GetParam(), s_, 0); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + DisconnectAfterConnectAny(GetParam(), s_, port); +} + TEST_P(UdpSocketTest, DisconnectAfterBind) { ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); // Connect the socket. @@ -402,19 +578,17 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { struct sockaddr_storage baddr = {}; socklen_t addrlen; auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (addr_[0]->sa_family == AF_INET) { + if (GetParam() == AddressFamily::kIpv4) { auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); addr_in->sin_family = AF_INET; addr_in->sin_port = port; - inet_pton(AF_INET, "0.0.0.0", - reinterpret_cast<void*>(&addr_in->sin_addr.s_addr)); + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); } else { auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); addr_in->sin6_family = AF_INET6; addr_in->sin6_port = port; - inet_pton(AF_INET6, - "::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr)); addr_in->sin6_scope_id = 0; + addr_in->sin6_addr = IN6ADDR_ANY_INIT; } ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), SyscallSucceeds()); @@ -1165,7 +1339,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { } INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, - ::testing::Values(AF_INET, AF_INET6)); + ::testing::Values(AddressFamily::kIpv4, + AddressFamily::kIpv6, + AddressFamily::kDualStack)); } // namespace |