summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- pkg/sentry/socket/netstack/netstack.go 5
-rw-r--r-- pkg/tcpip/transport/packet/endpoint.go 23
-rw-r--r-- test/syscalls/linux/BUILD 1
-rw-r--r-- test/syscalls/linux/packet_socket.cc 171
4 files changed, 180 insertions, 20 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f79bda922..aa081e90d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -672,13 +672,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
- if a.Protocol != uint16(s.protocol) {
- return syserr.ErrInvalidArgument
- }
-
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: socket.Ntohs(a.Protocol),
}
} else {
if s.minSockAddrLen() > len(sockaddr) {
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 2c9786175..1f30e5adb 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -59,13 +59,11 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions
@@ -84,6 +82,8 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// +checklocks:mu
+ netProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
closed bool
// +checklocks:mu
bound bool
@@ -98,10 +98,7 @@ type endpoint struct {
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
+ stack: s,
cooked: cooked,
netProto: netProto,
waiterQueue: waiterQueue,
@@ -214,13 +211,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
+ proto := ep.netProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
}
var remote tcpip.LinkAddress
- proto := ep.netProto
if to := opts.To; to != nil {
remote = tcpip.LinkAddress(to.Addr)
@@ -296,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto {
// If the NIC and network protocol being bound are the same then just return success.
return nil
}
@@ -306,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.bound = false
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
ep.bound = true
ep.boundNIC = addr.NIC
+ ep.netProto = netProto
return nil
}
@@ -473,10 +472,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.netProto}
}
// Stats returns a pointer to the endpoint stats.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 85fa58970..5efb3e620 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1487,6 +1487,7 @@ cc_binary(
srcs = ["packet_socket.cc"],
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:socket_util",
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index c8d1e1d4a..43828a52e 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -14,10 +14,15 @@
#include <net/if.h>
#include <netinet/if_ether.h>
+#include <netpacket/packet.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
#include <limits>
#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/socket_util.h"
@@ -27,10 +32,13 @@ namespace testing {
namespace {
+using ::testing::AnyOf;
using ::testing::Combine;
+using ::testing::Eq;
using ::testing::Values;
-class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
+class PacketSocketCreationTest
+ : public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
void SetUp() override {
if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
@@ -42,18 +50,175 @@ class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
}
};
-TEST_P(PacketSocketTest, Create) {
+TEST_P(PacketSocketCreationTest, Create) {
const auto [type, protocol] = GetParam();
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, type, htons(protocol)));
EXPECT_GE(fd.get(), 0);
}
-INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest,
+INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketCreationTest,
Combine(Values(SOCK_DGRAM, SOCK_RAW),
Values(0, 1, 255, ETH_P_IP, ETH_P_IPV6,
std::numeric_limits<uint16_t>::max())));
+class PacketSocketTest : public ::testing::TestWithParam<int> {
+ protected:
+ void SetUp() override {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ ASSERT_THAT(socket(AF_PACKET, GetParam(), 0),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP() << "Missing packet socket capability";
+ }
+
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, GetParam(), 0));
+ }
+
+ FileDescriptor socket_;
+};
+
+TEST_P(PacketSocketTest, RebindProtocol) {
+ const bool kEthHdrIncluded = GetParam() == SOCK_RAW;
+
+ sockaddr_in udp_bind_addr = {
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ {
+ // Bind the socket so that we have something to send packets to.
+ //
+ // If we didn't do this, the UDP packets we send will be responded to with
+ // ICMP Destination Port Unreachable errors.
+ ASSERT_THAT(
+ bind(udp_sock.get(), reinterpret_cast<const sockaddr*>(&udp_bind_addr),
+ sizeof(udp_bind_addr)),
+ SyscallSucceeds());
+ socklen_t addrlen = sizeof(udp_bind_addr);
+ ASSERT_THAT(
+ getsockname(udp_sock.get(), reinterpret_cast<sockaddr*>(&udp_bind_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(addrlen, sizeof(udp_bind_addr));
+ }
+
+ const int loopback_index = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+
+ auto send_udp_message = [&](const uint64_t v) {
+ ASSERT_THAT(
+ sendto(udp_sock.get(), reinterpret_cast<const char*>(&v), sizeof(v),
+ 0 /* flags */, reinterpret_cast<const sockaddr*>(&udp_bind_addr),
+ sizeof(udp_bind_addr)),
+ SyscallSucceeds());
+ };
+
+ auto bind_to_network_protocol = [&](uint16_t protocol) {
+ const sockaddr_ll packet_bind_addr = {
+ .sll_family = AF_PACKET,
+ .sll_protocol = htons(protocol),
+ .sll_ifindex = loopback_index,
+ };
+
+ ASSERT_THAT(bind(socket_.get(),
+ reinterpret_cast<const sockaddr*>(&packet_bind_addr),
+ sizeof(packet_bind_addr)),
+ SyscallSucceeds());
+ };
+
+ auto test_recv = [&, this](const uint64_t v) {
+ constexpr int kInfiniteTimeout = -1;
+ pollfd pfd = {
+ .fd = socket_.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfiniteTimeout),
+ SyscallSucceedsWithValue(1));
+
+ struct {
+ ethhdr eth;
+ iphdr ip;
+ udphdr udp;
+ uint64_t payload;
+ char unused;
+ } ABSL_ATTRIBUTE_PACKED read_pkt;
+ sockaddr_ll src;
+ socklen_t src_len = sizeof(src);
+
+ char* buf = reinterpret_cast<char*>(&read_pkt);
+ size_t buflen = sizeof(read_pkt);
+ size_t expected_read_len = sizeof(read_pkt) - sizeof(read_pkt.unused);
+ if (!kEthHdrIncluded) {
+ buf += sizeof(read_pkt.eth);
+ buflen -= sizeof(read_pkt.eth);
+ expected_read_len -= sizeof(read_pkt.eth);
+ }
+
+ ASSERT_THAT(recvfrom(socket_.get(), buf, buflen, 0,
+ reinterpret_cast<sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(expected_read_len));
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but returns sizeof(sockaddr_ll) since
+ // https://github.com/torvalds/linux/commit/b2cf86e1563e33a14a1c69b3e508d15dc12f804c.
+ ASSERT_THAT(src_len, ::testing::AnyOf(
+ ::testing::Eq(sizeof(src)),
+ ::testing::Eq(sizeof(src) - sizeof(src.sll_addr) +
+ ETH_ALEN)));
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, loopback_index);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ // This came from the loopback device, so the address is all 0s.
+ constexpr uint8_t allZeroesMAC[ETH_ALEN] = {};
+ EXPECT_EQ(memcmp(src.sll_addr, allZeroesMAC, sizeof(allZeroesMAC)), 0);
+ if (kEthHdrIncluded) {
+ EXPECT_EQ(memcmp(read_pkt.eth.h_dest, allZeroesMAC, sizeof(allZeroesMAC)),
+ 0);
+ EXPECT_EQ(
+ memcmp(read_pkt.eth.h_source, allZeroesMAC, sizeof(allZeroesMAC)), 0);
+ EXPECT_EQ(ntohs(read_pkt.eth.h_proto), ETH_P_IP);
+ }
+ // IHL holds the size of the header in 4-byte units.
+ EXPECT_EQ(read_pkt.ip.ihl, sizeof(iphdr) / 4);
+ EXPECT_EQ(read_pkt.ip.version, IPVERSION);
+ const uint16_t ip_pkt_size =
+ sizeof(read_pkt) - sizeof(read_pkt.eth) - sizeof(read_pkt.unused);
+ EXPECT_EQ(ntohs(read_pkt.ip.tot_len), ip_pkt_size);
+ EXPECT_EQ(read_pkt.ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ntohl(read_pkt.ip.daddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(read_pkt.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(read_pkt.udp.source, udp_bind_addr.sin_port);
+ EXPECT_EQ(read_pkt.udp.dest, udp_bind_addr.sin_port);
+ EXPECT_EQ(ntohs(read_pkt.udp.len), ip_pkt_size - sizeof(read_pkt.ip));
+ EXPECT_EQ(read_pkt.payload, v);
+ };
+
+ // The packet socket is not bound to IPv4 so we should not receive the sent
+ // message.
+ uint64_t counter = 0;
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+
+ // Bind to IPv4 and expect to receive the UDP packet we send after binding.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
+
+ // Bind the packet socket to an arbitrary non-IPv4 protocol.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(255));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+
+ // Bind back to IPv4 and expect to receive the UDP packet we send after
+ // binding back to IPv4.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest,
+ Values(SOCK_DGRAM, SOCK_RAW));
+
} // namespace
} // namespace testing