diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 23 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 1 | ||||
-rw-r--r-- | test/syscalls/linux/packet_socket.cc | 171 |
4 files changed, 180 insertions, 20 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f79bda922..aa081e90d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -672,13 +672,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) - if a.Protocol != uint16(s.protocol) { - return syserr.ErrInvalidArgument - } - addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: socket.Ntohs(a.Protocol), } } else { if s.minSockAddrLen() > len(sockaddr) { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 2c9786175..1f30e5adb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -59,13 +59,11 @@ type packet struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool ops tcpip.SocketOptions @@ -84,6 +82,8 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu + netProto tcpip.NetworkProtocolNumber + // +checklocks:mu closed bool // +checklocks:mu bound bool @@ -98,10 +98,7 @@ type endpoint struct { // NewEndpoint returns a new packet endpoint. 
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - }, + stack: s, cooked: cooked, netProto: netProto, waiterQueue: waiterQueue, @@ -214,13 +211,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC + proto := ep.netProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} } var remote tcpip.LinkAddress - proto := ep.netProto if to := opts.To; to != nil { remote = tcpip.LinkAddress(to.Addr) @@ -296,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound && ep.boundNIC == addr.NIC { + netProto := tcpip.NetworkProtocolNumber(addr.Port) + if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto { // If the NIC being bound is the same then just return success. return nil } @@ -306,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.bound = false // Bind endpoint to receive packets from specific interface. - if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } ep.bound = true ep.boundNIC = addr.NIC + ep.netProto = netProto return nil } @@ -473,10 +472,8 @@ func (*endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() - // Make a copy of the endpoint info. - ret := ep.TransportEndpointInfo - ep.mu.RUnlock() - return &ret + defer ep.mu.RUnlock() + return &stack.TransportEndpointInfo{NetProto: ep.netProto} } // Stats returns a pointer to the endpoint stats. 
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 85fa58970..5efb3e620 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1487,6 +1487,7 @@ cc_binary( srcs = ["packet_socket.cc"], linkstatic = 1, deps = [ + ":ip_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:socket_util", diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index c8d1e1d4a..43828a52e 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -14,10 +14,15 @@ #include <net/if.h> #include <netinet/if_ether.h> +#include <netpacket/packet.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/types.h> #include <limits> #include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/socket_util.h" @@ -27,10 +32,13 @@ namespace testing { namespace { +using ::testing::AnyOf; using ::testing::Combine; +using ::testing::Eq; using ::testing::Values; -class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { +class PacketSocketCreationTest + : public ::testing::TestWithParam<std::tuple<int, int>> { protected: void SetUp() override { if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) { @@ -42,18 +50,175 @@ class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { } }; -TEST_P(PacketSocketTest, Create) { +TEST_P(PacketSocketCreationTest, Create) { const auto [type, protocol] = GetParam(); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, type, htons(protocol))); EXPECT_GE(fd.get(), 0); } -INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, +INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketCreationTest, Combine(Values(SOCK_DGRAM, SOCK_RAW), Values(0, 1, 255, ETH_P_IP, ETH_P_IPV6, std::numeric_limits<uint16_t>::max()))); +class 
PacketSocketTest : public ::testing::TestWithParam<int> { + protected: + void SetUp() override { + if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) { + ASSERT_THAT(socket(AF_PACKET, GetParam(), 0), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP() << "Missing packet socket capability"; + } + + socket_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, GetParam(), 0)); + } + + FileDescriptor socket_; +}; + +TEST_P(PacketSocketTest, RebindProtocol) { + const bool kEthHdrIncluded = GetParam() == SOCK_RAW; + + sockaddr_in udp_bind_addr = { + .sin_family = AF_INET, + .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}, + }; + + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + { + // Bind the socket so that we have something to send packets to. + // + // If we didn't do this, the UDP packets we send will be responded to with + // ICMP Destination Port Unreachable errors. + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<const sockaddr*>(&udp_bind_addr), + sizeof(udp_bind_addr)), + SyscallSucceeds()); + socklen_t addrlen = sizeof(udp_bind_addr); + ASSERT_THAT( + getsockname(udp_sock.get(), reinterpret_cast<sockaddr*>(&udp_bind_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_THAT(addrlen, sizeof(udp_bind_addr)); + } + + const int loopback_index = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + + auto send_udp_message = [&](const uint64_t v) { + ASSERT_THAT( + sendto(udp_sock.get(), reinterpret_cast<const char*>(&v), sizeof(v), + 0 /* flags */, reinterpret_cast<const sockaddr*>(&udp_bind_addr), + sizeof(udp_bind_addr)), + SyscallSucceeds()); + }; + + auto bind_to_network_protocol = [&](uint16_t protocol) { + const sockaddr_ll packet_bind_addr = { + .sll_family = AF_PACKET, + .sll_protocol = htons(protocol), + .sll_ifindex = loopback_index, + }; + + ASSERT_THAT(bind(socket_.get(), + reinterpret_cast<const sockaddr*>(&packet_bind_addr), + sizeof(packet_bind_addr)), + SyscallSucceeds()); + }; + + auto test_recv = [&, 
this](const uint64_t v) { + constexpr int kInfiniteTimeout = -1; + pollfd pfd = { + .fd = socket_.get(), + .events = POLLIN, + }; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfiniteTimeout), + SyscallSucceedsWithValue(1)); + + struct { + ethhdr eth; + iphdr ip; + udphdr udp; + uint64_t payload; + char unused; + } ABSL_ATTRIBUTE_PACKED read_pkt; + sockaddr_ll src; + socklen_t src_len = sizeof(src); + + char* buf = reinterpret_cast<char*>(&read_pkt); + size_t buflen = sizeof(read_pkt); + size_t expected_read_len = sizeof(read_pkt) - sizeof(read_pkt.unused); + if (!kEthHdrIncluded) { + buf += sizeof(read_pkt.eth); + buflen -= sizeof(read_pkt.eth); + expected_read_len -= sizeof(read_pkt.eth); + } + + ASSERT_THAT(recvfrom(socket_.get(), buf, buflen, 0, + reinterpret_cast<sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(expected_read_len)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but returns sizeof(sockaddr_ll) since + // https://github.com/torvalds/linux/commit/b2cf86e1563e33a14a1c69b3e508d15dc12f804c. + ASSERT_THAT(src_len, ::testing::AnyOf( + ::testing::Eq(sizeof(src)), + ::testing::Eq(sizeof(src) - sizeof(src.sll_addr) + + ETH_ALEN))); + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, loopback_index); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + // This came from the loopback device, so the address is all 0s. + constexpr uint8_t allZeroesMAC[ETH_ALEN] = {}; + EXPECT_EQ(memcmp(src.sll_addr, allZeroesMAC, sizeof(allZeroesMAC)), 0); + if (kEthHdrIncluded) { + EXPECT_EQ(memcmp(read_pkt.eth.h_dest, allZeroesMAC, sizeof(allZeroesMAC)), + 0); + EXPECT_EQ( + memcmp(read_pkt.eth.h_source, allZeroesMAC, sizeof(allZeroesMAC)), 0); + EXPECT_EQ(ntohs(read_pkt.eth.h_proto), ETH_P_IP); + } + // IHL holds the size of the header in 4 byte units. 
+ EXPECT_EQ(read_pkt.ip.ihl, sizeof(iphdr) / 4); + EXPECT_EQ(read_pkt.ip.version, IPVERSION); + const uint16_t ip_pkt_size = + sizeof(read_pkt) - sizeof(read_pkt.eth) - sizeof(read_pkt.unused); + EXPECT_EQ(ntohs(read_pkt.ip.tot_len), ip_pkt_size); + EXPECT_EQ(read_pkt.ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ntohl(read_pkt.ip.daddr), INADDR_LOOPBACK); + EXPECT_EQ(ntohl(read_pkt.ip.saddr), INADDR_LOOPBACK); + EXPECT_EQ(read_pkt.udp.source, udp_bind_addr.sin_port); + EXPECT_EQ(read_pkt.udp.dest, udp_bind_addr.sin_port); + EXPECT_EQ(ntohs(read_pkt.udp.len), ip_pkt_size - sizeof(read_pkt.ip)); + EXPECT_EQ(read_pkt.payload, v); + }; + + // The packet socket is not bound to IPv4 so we should not receive the sent + // message. + uint64_t counter = 0; + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + + // Bind to IPv4 and expect to receive the UDP packet we send after binding. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); + + // Bind the packet socket to a random protocol. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(255)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + + // Bind back to IPv4 and expect to receive the UDP packet we send after binding + // back to IPv4. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); +} + +INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, + Values(SOCK_DGRAM, SOCK_RAW)); + } // namespace } // namespace testing |