diff options
author | Kevin Krakauer <krakauer@google.com> | 2020-06-26 19:04:59 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-06-26 19:07:02 -0700 |
commit | 66d1665441461a5226ba0c884e22888d58f393b6 (patch) | |
tree | d0e014b672eecefe76a3414973cd3fa641cde816 | |
parent | 8dbeac53ce1b3c1cf4a5f2f0ccdd7196f4656fd8 (diff) |
IPv6 raw sockets. Needed for ip6tables.
IPv6 raw sockets never include the IPv6 header.
PiperOrigin-RevId: 318582989
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 61 | ||||
-rw-r--r-- | test/syscalls/BUILD | 2 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 4 | ||||
-rw-r--r-- | test/syscalls/linux/raw_socket.cc (renamed from test/syscalls/linux/raw_socket_ipv4.cc) | 190 |
4 files changed, 148 insertions, 109 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index dd514d397..766c7648e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -94,7 +94,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if netProto != header.IPv4ProtocolNumber { + if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -215,6 +215,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // We can create, but not write to, unassociated IPv6 endpoints. + if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + return 0, nil, tcpip.ErrInvalidOptionValue + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -319,12 +324,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrNoRoute } - // We don't support IPv6 yet, so this has to be an IPv4 address. - if len(opts.To.Addr) != header.IPv4AddressSize { - e.mu.RUnlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) @@ -354,17 +353,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch e.NetProto { - case header.IPv4ProtocolNumber: - if !e.associated { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ - Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { - return 0, nil, err - } - break + if !e.associated { + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { + return 0, nil, err } - + } else { hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, @@ -373,9 +368,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, }); err != nil { return 0, nil, err } - - default: - return 0, nil, tcpip.ErrUnknownProtocol } return int64(len(payloadBytes)), nil, nil @@ -400,11 +392,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // We don't support IPv6 yet. - if len(addr.Addr) != header.IPv4AddressSize { - return tcpip.ErrInvalidEndpointState - } - nic := addr.NIC if e.bound { if e.BindNICID == 0 { @@ -470,14 +457,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // Callers must provide an IPv4 address or no network address (for - // binding to a NIC, but not an address). - if len(addr.Addr) != 0 && len(addr.Addr) != 4 { - return tcpip.ErrInvalidEndpointState - } - // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } @@ -680,9 +661,19 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { }, } - headers := append(buffer.View(nil), pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) - combinedVV := headers.ToVectorisedView() + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + var combinedVV buffer.VectorisedView + if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) + headers = append(headers, pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + } combinedVV.Append(pkt.Data) packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 23123006f..c4fff0ac8 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -547,7 +547,7 @@ syscall_test( ) syscall_test( - test = "//test/syscalls/linux:raw_socket_ipv4_test", + test = "//test/syscalls/linux:raw_socket_test", vfs2 = "True", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 7282d675e..270b9e4c4 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1800,9 +1800,9 @@ cc_binary( ) cc_binary( - name = "raw_socket_ipv4_test", + name = "raw_socket_test", testonly = 1, - srcs = ["raw_socket_ipv4.cc"], + srcs = ["raw_socket.cc"], linkstatic = 1, deps = [ ":socket_test_util", diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket.cc index 0116c3e94..05c4ed03f 100644 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -15,12 +15,12 @@ #include <linux/capability.h> #include <netinet/in.h> #include <netinet/ip.h> +#include <netinet/ip6.h> #include <netinet/ip_icmp.h> #include <poll.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> - #include <algorithm> #include "gtest/gtest.h" @@ -39,7 +39,7 @@ namespace testing { namespace { // Fixture for tests parameterized by protocol. -class RawSocketTest : public ::testing::TestWithParam<int> { +class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { protected: // Creates a socket to be used in tests. void SetUp() override; @@ -50,36 +50,58 @@ class RawSocketTest : public ::testing::TestWithParam<int> { // Sends buf via s_. void SendBuf(const char* buf, int buf_len); - // Sends buf to the provided address via the provided socket. - void SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf, - int buf_len); - // Reads from s_ into recv_buf. void ReceiveBuf(char* recv_buf, size_t recv_buf_len); - int Protocol() { return GetParam(); } + void ReceiveBufFrom(int sock, char* recv_buf, size_t recv_buf_len); + + int Protocol() { return std::get<0>(GetParam()); } + + int Family() { return std::get<1>(GetParam()); } + + socklen_t AddrLen() { + if (Family() == AF_INET) { + return sizeof(sockaddr_in); + } + return sizeof(sockaddr_in6); + } + + int HdrLen() { + if (Family() == AF_INET) { + return sizeof(struct iphdr); + } + // IPv6 raw sockets don't include the header. + return 0; + } // The socket used for both reading and writing. int s_; // The loopback address. - struct sockaddr_in addr_; + struct sockaddr_storage addr_; }; void RawSocketTest::SetUp() { if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()), + ASSERT_THAT(socket(Family(), SOCK_RAW, Protocol()), SyscallFailsWithErrno(EPERM)); GTEST_SKIP(); } - ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); + ASSERT_THAT(s_ = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); addr_ = {}; // We don't set ports because raw sockets don't have a notion of ports. - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr_); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr_); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_loopback; + } } void RawSocketTest::TearDown() { @@ -96,7 +118,7 @@ TEST_P(RawSocketTest, MultipleCreation) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); + ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); ASSERT_THAT(close(s2), SyscallSucceeds()); } @@ -114,7 +136,7 @@ TEST_P(RawSocketTest, ShutdownWriteNoop) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); @@ -129,7 +151,7 @@ TEST_P(RawSocketTest, ShutdownReadNoop) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); @@ -137,9 +159,8 @@ TEST_P(RawSocketTest, ShutdownReadNoop) { constexpr char kBuf[] = "gdg"; ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - constexpr size_t kReadSize = sizeof(kBuf) + sizeof(struct iphdr); - char c[kReadSize]; - ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(kReadSize)); + std::vector<char> c(sizeof(kBuf) + HdrLen()); + ASSERT_THAT(read(s_, c.data(), c.size()), SyscallSucceedsWithValue(c.size())); } // Test that listen() fails. @@ -173,7 +194,7 @@ TEST_P(RawSocketTest, GetPeerName) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); struct sockaddr saddr; socklen_t addrlen = sizeof(saddr); @@ -223,7 +244,7 @@ TEST_P(RawSocketTest, ConnectToLoopback) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); } @@ -242,7 +263,7 @@ TEST_P(RawSocketTest, BindToLocalhost) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); } @@ -250,12 +271,18 @@ TEST_P(RawSocketTest, BindToLocalhost) { TEST_P(RawSocketTest, BindToInvalid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + struct sockaddr_storage bind_addr = addr_; + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&bind_addr); + sin->sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + } else { + struct sockaddr_in6* sin6 = + reinterpret_cast<struct sockaddr_in6*>(&bind_addr); + memset(&sin6->sin6_addr.s6_addr, 0, sizeof(sin6->sin6_addr.s6_addr)); + sin6->sin6_addr.s6_addr[0] = 1; // 1: - An address that we can't bind to. + } ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); + AddrLen()), SyscallFailsWithErrno(EADDRNOTAVAIL)); } // Send and receive an packet. @@ -267,9 +294,9 @@ TEST_P(RawSocketTest, SendAndReceive) { ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // We should be able to create multiple raw sockets for the same protocol and @@ -278,22 +305,23 @@ TEST_P(RawSocketTest, MultipleSocketReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); + ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); // Arbitrary. constexpr char kBuf[] = "TB10"; ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive it on socket 1. - char recv_buf1[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1, sizeof(recv_buf1))); + std::vector<char> recv_buf1(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1.data(), recv_buf1.size())); // Receive it on socket 2. - char recv_buf2[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s2, recv_buf2, sizeof(recv_buf2))); + std::vector<char> recv_buf2(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s2, recv_buf2.data(), + recv_buf2.size())); - EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr), - recv_buf2 + sizeof(struct iphdr), sizeof(kBuf)), + EXPECT_EQ(memcmp(recv_buf1.data() + HdrLen(), + recv_buf2.data() + HdrLen(), sizeof(kBuf)), 0); ASSERT_THAT(close(s2), SyscallSucceeds()); @@ -304,7 +332,7 @@ TEST_P(RawSocketTest, SendAndReceiveViaConnect) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); // Arbitrary. @@ -313,9 +341,9 @@ TEST_P(RawSocketTest, SendAndReceiveViaConnect) { SyscallSucceedsWithValue(sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // Bind to localhost, then send and receive packets. @@ -323,7 +351,7 @@ TEST_P(RawSocketTest, BindSendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); // Arbitrary. @@ -331,9 +359,9 @@ TEST_P(RawSocketTest, BindSendAndReceive) { ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // Bind and connect to localhost and send/receive packets. @@ -341,10 +369,10 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); // Arbitrary. @@ -352,9 +380,9 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) { ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // Check that setting SO_RCVBUF below min is clamped to the minimum @@ -580,20 +608,16 @@ TEST_P(RawSocketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } -void RawSocketTest::SendBuf(const char* buf, int buf_len) { - ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len)); -} - // Test that receive buffer limits are not enforced when the recv buffer is // empty. TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); int min = 0; @@ -616,10 +640,10 @@ TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); // Receive the packet and make sure it's identical. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); EXPECT_EQ( - memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()), + memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), 0); } @@ -631,10 +655,10 @@ TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { RandomizeBuffer(buf.data(), buf.size()); ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); // Receive the packet and make sure it's identical. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); EXPECT_EQ( - memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()), + memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), 0); } } @@ -652,10 +676,10 @@ TEST_P(RawSocketTest, RecvBufLimits) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); int min = 0; @@ -716,16 +740,16 @@ TEST_P(RawSocketTest, RecvBufLimits) { // Verify that the expected number of packets are available to be read. for (int i = 0; i < sent - 1; i++) { // Receive the packet and make sure it's identical. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); - EXPECT_EQ(memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), 0); } // Assert that the last packet is dropped because the receive buffer should // be full after the first four packets. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); struct iovec iov = {}; iov.iov_base = static_cast<void*>(const_cast<char*>(recv_buf.data())); iov.iov_len = buf.size(); @@ -740,30 +764,54 @@ TEST_P(RawSocketTest, RecvBufLimits) { } } -void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr, - const char* buf, int buf_len) { +void RawSocketTest::SendBuf(const char* buf, int buf_len) { // It's safe to use const_cast here because sendmsg won't modify the iovec or // address. struct iovec iov = {}; iov.iov_base = static_cast<void*>(const_cast<char*>(buf)); iov.iov_len = static_cast<size_t>(buf_len); struct msghdr msg = {}; - msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr)); - msg.msg_namelen = sizeof(addr); + msg.msg_name = static_cast<void*>(&addr_); + msg.msg_namelen = AddrLen(); msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; - ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(buf_len)); + ASSERT_THAT(sendmsg(s_, &msg, 0), SyscallSucceedsWithValue(buf_len)); } void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) { - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, recv_buf_len)); + ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s_, recv_buf, recv_buf_len)); +} + +void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, + size_t recv_buf_len) { + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); } INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest, - ::testing::Values(IPPROTO_TCP, IPPROTO_UDP)); + ::testing::Combine( + ::testing::Values(IPPROTO_TCP, IPPROTO_UDP), + ::testing::Values(AF_INET, AF_INET6))); + +// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. +TEST(RawSocketTest, IPv6ProtoRaw) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + // Verify that writing yields EINVAL. + char buf[] = "This is such a weird little edge case"; + struct sockaddr_in6 sin6 = {}; + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = in6addr_loopback; + ASSERT_THAT(sendto(sock, buf, sizeof(buf), 0 /* flags */, + reinterpret_cast<struct sockaddr*>(&sin6), sizeof(sin6)), + SyscallFailsWithErrno(EINVAL)); +} } // namespace |