summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-09-17 15:31:19 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-17 15:37:08 -0700
commit7dacdbef528f7b556f23c1b02a360363dc556e31 (patch)
tree39f1b8fb51f3aaa830fcfc5459c15ff4bb75223f
parent4076153be6840c50ade746087b221a12d7bd2b3b (diff)
Allow rebinding packet socket protocol
...to change the network protocol a packet socket may receive packets from. This CL is a portion of an originally larger CL that was split with https://github.com/google/gvisor/commit/a8ad692fd36cbaf7f5a6b9af39d601053dbee338 being the dependent CL. That CL (accidentally) included the change in the endpoint's `afterLoad` method to take the required lock when accessing the endpoint's netProto field. That change should have been in this CL. The CL that made the change mentioned in the commit message is cl/396946187. PiperOrigin-RevId: 397412582
-rw-r--r--pkg/sentry/socket/netstack/netstack.go5
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go23
-rw-r--r--test/syscalls/linux/BUILD1
-rw-r--r--test/syscalls/linux/packet_socket.cc171
4 files changed, 180 insertions, 20 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f79bda922..aa081e90d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -672,13 +672,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
- if a.Protocol != uint16(s.protocol) {
- return syserr.ErrInvalidArgument
- }
-
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: socket.Ntohs(a.Protocol),
}
} else {
if s.minSockAddrLen() > len(sockaddr) {
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 2c9786175..1f30e5adb 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -59,13 +59,11 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions
@@ -84,6 +82,8 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// +checklocks:mu
+ netProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
closed bool
// +checklocks:mu
bound bool
@@ -98,10 +98,7 @@ type endpoint struct {
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
+ stack: s,
cooked: cooked,
netProto: netProto,
waiterQueue: waiterQueue,
@@ -214,13 +211,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
+ proto := ep.netProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
}
var remote tcpip.LinkAddress
- proto := ep.netProto
if to := opts.To; to != nil {
remote = tcpip.LinkAddress(to.Addr)
@@ -296,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto {
// If the NIC being bound is the same then just return success.
return nil
}
@@ -306,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.bound = false
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
ep.bound = true
ep.boundNIC = addr.NIC
+ ep.netProto = netProto
return nil
}
@@ -473,10 +472,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.netProto}
}
// Stats returns a pointer to the endpoint stats.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 85fa58970..5efb3e620 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1487,6 +1487,7 @@ cc_binary(
srcs = ["packet_socket.cc"],
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:socket_util",
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index c8d1e1d4a..43828a52e 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -14,10 +14,15 @@
#include <net/if.h>
#include <netinet/if_ether.h>
+#include <netpacket/packet.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
#include <limits>
#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/socket_util.h"
@@ -27,10 +32,13 @@ namespace testing {
namespace {
+using ::testing::AnyOf;
using ::testing::Combine;
+using ::testing::Eq;
using ::testing::Values;
-class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
+class PacketSocketCreationTest
+ : public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
void SetUp() override {
if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
@@ -42,18 +50,175 @@ class PacketSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
}
};
-TEST_P(PacketSocketTest, Create) {
+TEST_P(PacketSocketCreationTest, Create) {
const auto [type, protocol] = GetParam();
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, type, htons(protocol)));
EXPECT_GE(fd.get(), 0);
}
-INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest,
+INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketCreationTest,
Combine(Values(SOCK_DGRAM, SOCK_RAW),
Values(0, 1, 255, ETH_P_IP, ETH_P_IPV6,
std::numeric_limits<uint16_t>::max())));
+class PacketSocketTest : public ::testing::TestWithParam<int> {
+ protected:
+ void SetUp() override {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ ASSERT_THAT(socket(AF_PACKET, GetParam(), 0),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP() << "Missing packet socket capability";
+ }
+
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, GetParam(), 0));
+ }
+
+ FileDescriptor socket_;
+};
+
+TEST_P(PacketSocketTest, RebindProtocol) {
+ const bool kEthHdrIncluded = GetParam() == SOCK_RAW;
+
+ sockaddr_in udp_bind_addr = {
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ {
+ // Bind the socket so that we have something to send packets to.
+ //
+ // If we didn't do this, the UDP packets we send will be responded to with
+ // ICMP Destination Port Unreachable errors.
+ ASSERT_THAT(
+ bind(udp_sock.get(), reinterpret_cast<const sockaddr*>(&udp_bind_addr),
+ sizeof(udp_bind_addr)),
+ SyscallSucceeds());
+ socklen_t addrlen = sizeof(udp_bind_addr);
+ ASSERT_THAT(
+ getsockname(udp_sock.get(), reinterpret_cast<sockaddr*>(&udp_bind_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(addrlen, sizeof(udp_bind_addr));
+ }
+
+ const int loopback_index = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+
+ auto send_udp_message = [&](const uint64_t v) {
+ ASSERT_THAT(
+ sendto(udp_sock.get(), reinterpret_cast<const char*>(&v), sizeof(v),
+ 0 /* flags */, reinterpret_cast<const sockaddr*>(&udp_bind_addr),
+ sizeof(udp_bind_addr)),
+ SyscallSucceeds());
+ };
+
+ auto bind_to_network_protocol = [&](uint16_t protocol) {
+ const sockaddr_ll packet_bind_addr = {
+ .sll_family = AF_PACKET,
+ .sll_protocol = htons(protocol),
+ .sll_ifindex = loopback_index,
+ };
+
+ ASSERT_THAT(bind(socket_.get(),
+ reinterpret_cast<const sockaddr*>(&packet_bind_addr),
+ sizeof(packet_bind_addr)),
+ SyscallSucceeds());
+ };
+
+ auto test_recv = [&, this](const uint64_t v) {
+ constexpr int kInfiniteTimeout = -1;
+ pollfd pfd = {
+ .fd = socket_.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfiniteTimeout),
+ SyscallSucceedsWithValue(1));
+
+ struct {
+ ethhdr eth;
+ iphdr ip;
+ udphdr udp;
+ uint64_t payload;
+ char unused;
+ } ABSL_ATTRIBUTE_PACKED read_pkt;
+ sockaddr_ll src;
+ socklen_t src_len = sizeof(src);
+
+ char* buf = reinterpret_cast<char*>(&read_pkt);
+ size_t buflen = sizeof(read_pkt);
+ size_t expected_read_len = sizeof(read_pkt) - sizeof(read_pkt.unused);
+ if (!kEthHdrIncluded) {
+ buf += sizeof(read_pkt.eth);
+ buflen -= sizeof(read_pkt.eth);
+ expected_read_len -= sizeof(read_pkt.eth);
+ }
+
+ ASSERT_THAT(recvfrom(socket_.get(), buf, buflen, 0,
+ reinterpret_cast<sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(expected_read_len));
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but returns sizeof(sockaddr_ll) since
+ // https://github.com/torvalds/linux/commit/b2cf86e1563e33a14a1c69b3e508d15dc12f804c.
+ ASSERT_THAT(src_len, ::testing::AnyOf(
+ ::testing::Eq(sizeof(src)),
+ ::testing::Eq(sizeof(src) - sizeof(src.sll_addr) +
+ ETH_ALEN)));
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, loopback_index);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ // This came from the loopback device, so the address is all 0s.
+ constexpr uint8_t allZeroesMAC[ETH_ALEN] = {};
+ EXPECT_EQ(memcmp(src.sll_addr, allZeroesMAC, sizeof(allZeroesMAC)), 0);
+ if (kEthHdrIncluded) {
+ EXPECT_EQ(memcmp(read_pkt.eth.h_dest, allZeroesMAC, sizeof(allZeroesMAC)),
+ 0);
+ EXPECT_EQ(
+ memcmp(read_pkt.eth.h_source, allZeroesMAC, sizeof(allZeroesMAC)), 0);
+ EXPECT_EQ(ntohs(read_pkt.eth.h_proto), ETH_P_IP);
+ }
+ // IHL holds the size of the header in 4-byte units.
+ EXPECT_EQ(read_pkt.ip.ihl, sizeof(iphdr) / 4);
+ EXPECT_EQ(read_pkt.ip.version, IPVERSION);
+ const uint16_t ip_pkt_size =
+ sizeof(read_pkt) - sizeof(read_pkt.eth) - sizeof(read_pkt.unused);
+ EXPECT_EQ(ntohs(read_pkt.ip.tot_len), ip_pkt_size);
+ EXPECT_EQ(read_pkt.ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ntohl(read_pkt.ip.daddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(read_pkt.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(read_pkt.udp.source, udp_bind_addr.sin_port);
+ EXPECT_EQ(read_pkt.udp.dest, udp_bind_addr.sin_port);
+ EXPECT_EQ(ntohs(read_pkt.udp.len), ip_pkt_size - sizeof(read_pkt.ip));
+ EXPECT_EQ(read_pkt.payload, v);
+ };
+
+ // The packet socket is not bound to IPv4 so we should not receive the sent
+ // message.
+ uint64_t counter = 0;
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+
+ // Bind to IPv4 and expect to receive the UDP packet we send after binding.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
+
+ // Bind the packet socket to an arbitrary protocol other than IPv4.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(255));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+
+ // Bind back to IPv4 and expect to receive the UDP packet we send after
+ // binding back to IPv4.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest,
+ Values(SOCK_DGRAM, SOCK_RAW));
+
} // namespace
} // namespace testing