summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorTing-Yu Wang <anivia@google.com>2021-01-15 15:03:30 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-15 15:10:27 -0800
commitec9e263f213c59e93f9c8b8123012b3db2dddc9a (patch)
tree87f0fca5791ece2a138fe822b4067569425f6bd2
parent55c7fe48d223ee5678dff7f5bf9a9e5f0482ab37 (diff)
Correctly return EMSGSIZE when packet is too big in raw socket.
IPv4 previously accepts the packet, while IPv6 panics. Neither is the behavior in Linux. splice() in Linux has different behavior than in gVisor. This change documents it in the SpliceTooLong test. Reported-by: syzbot+b550e78e5c24d1d521f2@syzkaller.appspotmail.com PiperOrigin-RevId: 352091286
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go24
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go13
-rw-r--r--pkg/tcpip/network/ipv6/mld.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go13
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc52
-rw-r--r--test/syscalls/linux/socket_test_util.cc13
-rw-r--r--test/syscalls/linux/socket_test_util.h7
8 files changed, 111 insertions, 23 deletions
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index da88d65d1..d9b5fe6ed 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -262,13 +262,15 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
localAddr := addressEndpoint.AddressWithPrefix().Address
addressEndpoint.DecRef()
addressEndpoint = nil
- igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
+ if err := igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
TOS: stack.DefaultTOS,
}, header.IPv4OptionsSerializer{
&header.IPv4SerializableRouterAlertOption{},
- })
+ }); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index cc045c7a9..bb25a76fe 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -237,7 +237,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error {
hdrLen := header.IPv4MinimumSize
var optLen int
if options != nil {
@@ -245,19 +245,19 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
}
hdrLen += optLen
if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
+ return tcpip.ErrMessageTooLong
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
- length := uint16(pkt.Size())
+ length := pkt.Size()
+ if length > math.MaxUint16 {
+ return tcpip.ErrMessageTooLong
+ }
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
- TotalLength: length,
+ TotalLength: uint16(length),
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
@@ -268,6 +268,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
+ return nil
}
// handleFragments fragments pkt and calls the handler function on each
@@ -295,7 +296,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
+ return err
+ }
// iptables filtering. All packets that reach here are locally
// generated.
@@ -383,7 +386,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
+ return 0, err
+ }
+
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 2f82c3d5f..ae4a8f508 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -553,11 +553,11 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error {
extHdrsLen := extensionHeaders.Length()
length := pkt.Size() + extensionHeaders.Length()
if length > math.MaxUint16 {
- panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ return tcpip.ErrMessageTooLong
}
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
@@ -570,6 +570,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
+ return nil
}
func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
@@ -622,7 +623,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
+ return err
+ }
// iptables filtering. All packets that reach here are locally
// generated.
@@ -711,7 +714,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil {
+ return 0, err
+ }
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index e8d1e7a79..ec54d88cc 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -249,10 +249,12 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
Data: buffer.View(icmp).ToVectorisedView(),
})
- mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.MLDHopLimit,
- }, extensionHeaders)
+ }, extensionHeaders); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sentStats.Dropped.Increment()
return false, err
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d515eb622..1d8fee50b 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -732,10 +732,12 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
})
sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- }, nil /* extensionHeaders */)
+ }, nil /* extensionHeaders */); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
@@ -1854,11 +1856,12 @@ func (ndp *ndpState) startSolicitingRouters() {
})
sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- }, nil /* extensionHeaders */)
-
+ }, nil /* extensionHeaders */); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index a7c46adbf..2ed4f6f9c 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -678,6 +678,58 @@ TEST_P(RawPacketTest, GetSocketAcceptConn) {
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
+class RawPacketMsgSizeTest : public ::testing::TestWithParam<TestAddress> {};
+
+TEST_P(RawPacketMsgSizeTest, SendTooLong) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ TestAddress addr = GetParam().WithPort(kPort);
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(addr.family(), SOCK_RAW, IPPROTO_UDP));
+
+ ASSERT_THAT(
+ connect(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ const char buf[65536] = {};
+ ASSERT_THAT(send(udp_sock.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EMSGSIZE));
+}
+
+TEST_P(RawPacketMsgSizeTest, SpliceTooLong) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ const char buf[65536] = {};
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ ASSERT_THAT(write(fds[1], buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ TestAddress addr = GetParam().WithPort(kPort);
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(addr.family(), SOCK_RAW, IPPROTO_UDP));
+
+ ASSERT_THAT(
+ connect(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ ssize_t n = splice(fds[0], nullptr, udp_sock.get(), nullptr, sizeof(buf), 0);
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(n, SyscallFailsWithErrno(EMSGSIZE));
+ } else {
+ // TODO(gvisor.dev/issue/138): Linux sends out multiple UDP datagrams, each
+ // of the size of a page.
+ EXPECT_THAT(n, SyscallSucceedsWithValue(sizeof(buf)));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(AllRawPacketMsgSizeTest, RawPacketMsgSizeTest,
+ ::testing::Values(V4Loopback(), V6Loopback()));
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 26dacc95e..b2a96086c 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -791,6 +791,19 @@ void RecvNoData(int sock) {
SyscallFailsWithErrno(EAGAIN));
}
+TestAddress TestAddress::WithPort(uint16_t port) const {
+ TestAddress addr = *this;
+ switch (addr.family()) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(port);
+ break;
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = htons(port);
+ break;
+ }
+ return addr;
+}
+
TestAddress V4Any() {
TestAddress t("V4Any");
t.addr.ss_family = AF_INET;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 75c0d4735..b3ab286b8 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -486,9 +486,14 @@ struct TestAddress {
sockaddr_storage addr;
socklen_t addr_len;
- int family() const { return addr.ss_family; }
explicit TestAddress(std::string description = "")
: description(std::move(description)), addr(), addr_len() {}
+
+ int family() const { return addr.ss_family; }
+
+ // Returns a new TestAddress with specified port. If port is not supported,
+ // the same TestAddress is returned.
+ TestAddress WithPort(uint16_t port) const;
};
constexpr char kMulticastAddress[] = "224.0.2.1";