summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go61
-rw-r--r--test/syscalls/BUILD2
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/raw_socket.cc (renamed from test/syscalls/linux/raw_socket_ipv4.cc)190
4 files changed, 148 insertions, 109 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index dd514d397..766c7648e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -94,7 +94,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -215,6 +215,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -319,12 +324,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -354,17 +353,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
- break
+ if !e.associated {
+ if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
+ return 0, nil, err
}
-
+ } else {
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
@@ -373,9 +368,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -400,11 +392,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -470,14 +457,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -680,9 +661,19 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
},
}
- headers := append(buffer.View(nil), pkt.NetworkHeader...)
- headers = append(headers, pkt.TransportHeader...)
- combinedVV := headers.ToVectorisedView()
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
+ headers = append(headers, pkt.NetworkHeader...)
+ headers = append(headers, pkt.TransportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+ }
combinedVV.Append(pkt.Data)
packet.data = combinedVV
packet.timestampNS = e.stack.NowNanoseconds()
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 23123006f..c4fff0ac8 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -547,7 +547,7 @@ syscall_test(
)
syscall_test(
- test = "//test/syscalls/linux:raw_socket_ipv4_test",
+ test = "//test/syscalls/linux:raw_socket_test",
vfs2 = "True",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 7282d675e..270b9e4c4 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1800,9 +1800,9 @@ cc_binary(
)
cc_binary(
- name = "raw_socket_ipv4_test",
+ name = "raw_socket_test",
testonly = 1,
- srcs = ["raw_socket_ipv4.cc"],
+ srcs = ["raw_socket.cc"],
linkstatic = 1,
deps = [
":socket_test_util",
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket.cc
index 0116c3e94..05c4ed03f 100644
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -15,12 +15,12 @@
#include <linux/capability.h>
#include <netinet/in.h>
#include <netinet/ip.h>
+#include <netinet/ip6.h>
#include <netinet/ip_icmp.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
-
#include <algorithm>
#include "gtest/gtest.h"
@@ -39,7 +39,7 @@ namespace testing {
namespace {
// Fixture for tests parameterized by protocol.
-class RawSocketTest : public ::testing::TestWithParam<int> {
+class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
// Creates a socket to be used in tests.
void SetUp() override;
@@ -50,36 +50,58 @@ class RawSocketTest : public ::testing::TestWithParam<int> {
// Sends buf via s_.
void SendBuf(const char* buf, int buf_len);
- // Sends buf to the provided address via the provided socket.
- void SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf,
- int buf_len);
-
// Reads from s_ into recv_buf.
void ReceiveBuf(char* recv_buf, size_t recv_buf_len);
- int Protocol() { return GetParam(); }
+ void ReceiveBufFrom(int sock, char* recv_buf, size_t recv_buf_len);
+
+ int Protocol() { return std::get<0>(GetParam()); }
+
+ int Family() { return std::get<1>(GetParam()); }
+
+ socklen_t AddrLen() {
+ if (Family() == AF_INET) {
+ return sizeof(sockaddr_in);
+ }
+ return sizeof(sockaddr_in6);
+ }
+
+ int HdrLen() {
+ if (Family() == AF_INET) {
+ return sizeof(struct iphdr);
+ }
+ // IPv6 raw sockets don't include the header.
+ return 0;
+ }
// The socket used for both reading and writing.
int s_;
// The loopback address.
- struct sockaddr_in addr_;
+ struct sockaddr_storage addr_;
};
void RawSocketTest::SetUp() {
if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
- ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()),
+ ASSERT_THAT(socket(Family(), SOCK_RAW, Protocol()),
SyscallFailsWithErrno(EPERM));
GTEST_SKIP();
}
- ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
+ ASSERT_THAT(s_ = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
addr_ = {};
// We don't set ports because raw sockets don't have a notion of ports.
- addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- addr_.sin_family = AF_INET;
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr_);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ } else {
+ struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr_);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_addr = in6addr_loopback;
+ }
}
void RawSocketTest::TearDown() {
@@ -96,7 +118,7 @@ TEST_P(RawSocketTest, MultipleCreation) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
int s2;
- ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
+ ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
ASSERT_THAT(close(s2), SyscallSucceeds());
}
@@ -114,7 +136,7 @@ TEST_P(RawSocketTest, ShutdownWriteNoop) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds());
@@ -129,7 +151,7 @@ TEST_P(RawSocketTest, ShutdownReadNoop) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
@@ -137,9 +159,8 @@ TEST_P(RawSocketTest, ShutdownReadNoop) {
constexpr char kBuf[] = "gdg";
ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
- constexpr size_t kReadSize = sizeof(kBuf) + sizeof(struct iphdr);
- char c[kReadSize];
- ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(kReadSize));
+ std::vector<char> c(sizeof(kBuf) + HdrLen());
+ ASSERT_THAT(read(s_, c.data(), c.size()), SyscallSucceedsWithValue(c.size()));
}
// Test that listen() fails.
@@ -173,7 +194,7 @@ TEST_P(RawSocketTest, GetPeerName) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
struct sockaddr saddr;
socklen_t addrlen = sizeof(saddr);
@@ -223,7 +244,7 @@ TEST_P(RawSocketTest, ConnectToLoopback) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
}
@@ -242,7 +263,7 @@ TEST_P(RawSocketTest, BindToLocalhost) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
}
@@ -250,12 +271,18 @@ TEST_P(RawSocketTest, BindToLocalhost) {
TEST_P(RawSocketTest, BindToInvalid) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- struct sockaddr_in bind_addr = {};
- bind_addr.sin_family = AF_INET;
- bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
+ struct sockaddr_storage bind_addr = addr_;
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&bind_addr);
+ sin->sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
+ } else {
+ struct sockaddr_in6* sin6 =
+ reinterpret_cast<struct sockaddr_in6*>(&bind_addr);
+ memset(&sin6->sin6_addr.s6_addr, 0, sizeof(sin6->sin6_addr.s6_addr));
+ sin6->sin6_addr.s6_addr[0] = 1; // 1: - An address that we can't bind to.
+ }
ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)),
- SyscallFailsWithErrno(EADDRNOTAVAIL));
+ AddrLen()), SyscallFailsWithErrno(EADDRNOTAVAIL));
}
// Send and receive an packet.
@@ -267,9 +294,9 @@ TEST_P(RawSocketTest, SendAndReceive) {
ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
// Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
}
// We should be able to create multiple raw sockets for the same protocol and
@@ -278,22 +305,23 @@ TEST_P(RawSocketTest, MultipleSocketReceive) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
int s2;
- ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
+ ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
// Arbitrary.
constexpr char kBuf[] = "TB10";
ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
// Receive it on socket 1.
- char recv_buf1[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1, sizeof(recv_buf1)));
+ std::vector<char> recv_buf1(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1.data(), recv_buf1.size()));
// Receive it on socket 2.
- char recv_buf2[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s2, recv_buf2, sizeof(recv_buf2)));
+ std::vector<char> recv_buf2(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s2, recv_buf2.data(),
+ recv_buf2.size()));
- EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr),
- recv_buf2 + sizeof(struct iphdr), sizeof(kBuf)),
+ EXPECT_EQ(memcmp(recv_buf1.data() + HdrLen(),
+ recv_buf2.data() + HdrLen(), sizeof(kBuf)),
0);
ASSERT_THAT(close(s2), SyscallSucceeds());
@@ -304,7 +332,7 @@ TEST_P(RawSocketTest, SendAndReceiveViaConnect) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
// Arbitrary.
@@ -313,9 +341,9 @@ TEST_P(RawSocketTest, SendAndReceiveViaConnect) {
SyscallSucceedsWithValue(sizeof(kBuf)));
// Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
}
// Bind to localhost, then send and receive packets.
@@ -323,7 +351,7 @@ TEST_P(RawSocketTest, BindSendAndReceive) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
// Arbitrary.
@@ -331,9 +359,9 @@ TEST_P(RawSocketTest, BindSendAndReceive) {
ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
// Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
}
// Bind and connect to localhost and send/receive packets.
@@ -341,10 +369,10 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
// Arbitrary.
@@ -352,9 +380,9 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) {
ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
// Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
}
// Check that setting SO_RCVBUF below min is clamped to the minimum
@@ -580,20 +608,16 @@ TEST_P(RawSocketTest, SetSocketSendBuf) {
ASSERT_EQ(quarter_sz, val);
}
-void RawSocketTest::SendBuf(const char* buf, int buf_len) {
- ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len));
-}
-
// Test that receive buffer limits are not enforced when the recv buffer is
// empty.
TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
int min = 0;
@@ -616,10 +640,10 @@ TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) {
ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
// Receive the packet and make sure it's identical.
- std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr));
+ std::vector<char> recv_buf(buf.size() + HdrLen());
ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
EXPECT_EQ(
- memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()),
+ memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()),
0);
}
@@ -631,10 +655,10 @@ TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) {
RandomizeBuffer(buf.data(), buf.size());
ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
// Receive the packet and make sure it's identical.
- std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr));
+ std::vector<char> recv_buf(buf.size() + HdrLen());
ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
EXPECT_EQ(
- memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()),
+ memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()),
0);
}
}
@@ -652,10 +676,10 @@ TEST_P(RawSocketTest, RecvBufLimits) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
SyscallSucceeds());
int min = 0;
@@ -716,16 +740,16 @@ TEST_P(RawSocketTest, RecvBufLimits) {
// Verify that the expected number of packets are available to be read.
for (int i = 0; i < sent - 1; i++) {
// Receive the packet and make sure it's identical.
- std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr));
+ std::vector<char> recv_buf(buf.size() + HdrLen());
ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
- EXPECT_EQ(memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(),
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), buf.data(),
buf.size()),
0);
}
// Assert that the last packet is dropped because the receive buffer should
// be full after the first four packets.
- std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr));
+ std::vector<char> recv_buf(buf.size() + HdrLen());
struct iovec iov = {};
iov.iov_base = static_cast<void*>(const_cast<char*>(recv_buf.data()));
iov.iov_len = buf.size();
@@ -740,30 +764,54 @@ TEST_P(RawSocketTest, RecvBufLimits) {
}
}
-void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr,
- const char* buf, int buf_len) {
+void RawSocketTest::SendBuf(const char* buf, int buf_len) {
// It's safe to use const_cast here because sendmsg won't modify the iovec or
// address.
struct iovec iov = {};
iov.iov_base = static_cast<void*>(const_cast<char*>(buf));
iov.iov_len = static_cast<size_t>(buf_len);
struct msghdr msg = {};
- msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr));
- msg.msg_namelen = sizeof(addr);
+ msg.msg_name = static_cast<void*>(&addr_);
+ msg.msg_namelen = AddrLen();
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = NULL;
msg.msg_controllen = 0;
msg.msg_flags = 0;
- ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(buf_len));
+ ASSERT_THAT(sendmsg(s_, &msg, 0), SyscallSucceedsWithValue(buf_len));
}
void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) {
- ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, recv_buf_len));
+ ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s_, recv_buf, recv_buf_len));
+}
+
+void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf,
+ size_t recv_buf_len) {
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len));
}
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest,
- ::testing::Values(IPPROTO_TCP, IPPROTO_UDP));
+ ::testing::Combine(
+ ::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
+ ::testing::Values(AF_INET, AF_INET6)));
+
+// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to.
+TEST(RawSocketTest, IPv6ProtoRaw) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW),
+ SyscallSucceeds());
+
+ // Verify that writing yields EINVAL.
+ char buf[] = "This is such a weird little edge case";
+ struct sockaddr_in6 sin6 = {};
+ sin6.sin6_family = AF_INET6;
+ sin6.sin6_addr = in6addr_loopback;
+ ASSERT_THAT(sendto(sock, buf, sizeof(buf), 0 /* flags */,
+ reinterpret_cast<struct sockaddr*>(&sin6), sizeof(sin6)),
+ SyscallFailsWithErrno(EINVAL));
+}
} // namespace