diff options
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 13 | ||||
-rw-r--r-- | test/syscalls/linux/packet_socket_raw.cc | 52 | ||||
-rw-r--r-- | test/syscalls/linux/socket_test_util.cc | 13 | ||||
-rw-r--r-- | test/syscalls/linux/socket_test_util.h | 7 |
8 files changed, 111 insertions, 23 deletions
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index da88d65d1..d9b5fe6ed 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -262,13 +262,15 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip localAddr := addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() addressEndpoint = nil - igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ + if err := igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, TOS: stack.DefaultTOS, }, header.IPv4OptionsSerializer{ &header.IPv4SerializableRouterAlertOption{}, - }) + }); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index cc045c7a9..bb25a76fe 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -237,7 +237,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error { hdrLen := header.IPv4MinimumSize var optLen int if options != nil { @@ -245,19 +245,19 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet } hdrLen += optLen if hdrLen > header.IPv4MaximumHeaderSize { - // Since we have no way to report an error we must either panic or create - // a packet which is different to what was requested. Choose panic as this - // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) + return tcpip.ErrMessageTooLong } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) - length := uint16(pkt.Size()) + length := pkt.Size() + if length > math.MaxUint16 { + return tcpip.ErrMessageTooLong + } // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ - TotalLength: length, + TotalLength: uint16(length), ID: uint16(id), TTL: params.TTL, TOS: params.TOS, @@ -268,6 +268,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber + return nil } // handleFragments fragments pkt and calls the handler function on each @@ -295,7 +296,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) + if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { + return err + } // iptables filtering. All packets that reach here are locally // generated. @@ -383,7 +386,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) + if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { + return 0, err + } + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 2f82c3d5f..ae4a8f508 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -553,11 +553,11 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { extHdrsLen := extensionHeaders.Length() length := pkt.Size() + extensionHeaders.Length() if length > math.MaxUint16 { - panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + return tcpip.ErrMessageTooLong } ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ @@ -570,6 +570,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber + return nil } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { @@ -622,7 +623,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) + if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { + return err + } // iptables filtering. All packets that reach here are locally // generated. @@ -711,7 +714,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) + if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { + return 0, err + } networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index e8d1e7a79..ec54d88cc 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -249,10 +249,12 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp Data: buffer.View(icmp).ToVectorisedView(), }) - mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, - }, extensionHeaders) + }, extensionHeaders); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sentStats.Dropped.Increment() return false, err diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d515eb622..1d8fee50b 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -732,10 +732,12 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add }) sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }, nil /* extensionHeaders */) + }, nil /* extensionHeaders */); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() @@ -1854,11 +1856,12 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }, nil /* extensionHeaders */) - + }, nil /* extensionHeaders */); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index a7c46adbf..2ed4f6f9c 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -678,6 +678,58 @@ TEST_P(RawPacketTest, GetSocketAcceptConn) { INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); +class RawPacketMsgSizeTest : public ::testing::TestWithParam<TestAddress> {}; + +TEST_P(RawPacketMsgSizeTest, SendTooLong) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + TestAddress addr = GetParam().WithPort(kPort); + + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(addr.family(), SOCK_RAW, IPPROTO_UDP)); + + ASSERT_THAT( + connect(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + const char buf[65536] = {}; + ASSERT_THAT(send(udp_sock.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(EMSGSIZE)); +} + +TEST_P(RawPacketMsgSizeTest, SpliceTooLong) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + const char buf[65536] = {}; + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + ASSERT_THAT(write(fds[1], buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + TestAddress addr = GetParam().WithPort(kPort); + + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(addr.family(), SOCK_RAW, IPPROTO_UDP)); + + ASSERT_THAT( + connect(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + ssize_t n = splice(fds[0], nullptr, udp_sock.get(), nullptr, sizeof(buf), 0); + if (IsRunningOnGvisor()) { + EXPECT_THAT(n, SyscallFailsWithErrno(EMSGSIZE)); + } else { + // TODO(gvisor.dev/issue/138): Linux sends out multiple UDP datagrams, each + // of the size of a page. + EXPECT_THAT(n, SyscallSucceedsWithValue(sizeof(buf))); + } +} + +INSTANTIATE_TEST_SUITE_P(AllRawPacketMsgSizeTest, RawPacketMsgSizeTest, + ::testing::Values(V4Loopback(), V6Loopback())); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 26dacc95e..b2a96086c 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -791,6 +791,19 @@ void RecvNoData(int sock) { SyscallFailsWithErrno(EAGAIN)); } +TestAddress TestAddress::WithPort(uint16_t port) const { + TestAddress addr = *this; + switch (addr.family()) { + case AF_INET: + reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(port); + break; + case AF_INET6: + reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = htons(port); + break; + } + return addr; +} + TestAddress V4Any() { TestAddress t("V4Any"); t.addr.ss_family = AF_INET; diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 75c0d4735..b3ab286b8 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -486,9 +486,14 @@ struct TestAddress { sockaddr_storage addr; socklen_t addr_len; - int family() const { return addr.ss_family; } explicit TestAddress(std::string description = "") : description(std::move(description)), addr(), addr_len() {} + + int family() const { return addr.ss_family; } + + // Returns a new TestAddress with specified port. If port is not supported, + // the same TestAddress is returned. + TestAddress WithPort(uint16_t port) const; }; constexpr char kMulticastAddress[] = "224.0.2.1"; |