summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go24
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go13
-rw-r--r--pkg/tcpip/network/ipv6/mld.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go13
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc52
-rw-r--r--test/syscalls/linux/socket_test_util.cc13
-rw-r--r--test/syscalls/linux/socket_test_util.h7
8 files changed, 111 insertions, 23 deletions
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index da88d65d1..d9b5fe6ed 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -262,13 +262,15 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
localAddr := addressEndpoint.AddressWithPrefix().Address
addressEndpoint.DecRef()
addressEndpoint = nil
- igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
+ if err := igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
TOS: stack.DefaultTOS,
}, header.IPv4OptionsSerializer{
&header.IPv4SerializableRouterAlertOption{},
- })
+ }); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index cc045c7a9..bb25a76fe 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -237,7 +237,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error {
hdrLen := header.IPv4MinimumSize
var optLen int
if options != nil {
@@ -245,19 +245,19 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
}
hdrLen += optLen
if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
+ return tcpip.ErrMessageTooLong
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
- length := uint16(pkt.Size())
+ length := pkt.Size()
+ if length > math.MaxUint16 {
+ return tcpip.ErrMessageTooLong
+ }
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
- TotalLength: length,
+ TotalLength: uint16(length),
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
@@ -268,6 +268,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
+ return nil
}
// handleFragments fragments pkt and calls the handler function on each
@@ -295,7 +296,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
+ return err
+ }
// iptables filtering. All packets that reach here are locally
// generated.
@@ -383,7 +386,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
+ return 0, err
+ }
+
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 2f82c3d5f..ae4a8f508 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -553,11 +553,11 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error {
extHdrsLen := extensionHeaders.Length()
length := pkt.Size() + extensionHeaders.Length()
if length > math.MaxUint16 {
- panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ return tcpip.ErrMessageTooLong
}
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
@@ -570,6 +570,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
+ return nil
}
func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
@@ -622,7 +623,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
+ return err
+ }
// iptables filtering. All packets that reach here are locally
// generated.
@@ -711,7 +714,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
+ if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil {
+ return 0, err
+ }
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index e8d1e7a79..ec54d88cc 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -249,10 +249,12 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
Data: buffer.View(icmp).ToVectorisedView(),
})
- mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.MLDHopLimit,
- }, extensionHeaders)
+ }, extensionHeaders); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sentStats.Dropped.Increment()
return false, err
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d515eb622..1d8fee50b 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -732,10 +732,12 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
})
sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- }, nil /* extensionHeaders */)
+ }, nil /* extensionHeaders */); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
@@ -1854,11 +1856,12 @@ func (ndp *ndpState) startSolicitingRouters() {
})
sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- }, nil /* extensionHeaders */)
-
+ }, nil /* extensionHeaders */); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index a7c46adbf..2ed4f6f9c 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -678,6 +678,58 @@ TEST_P(RawPacketTest, GetSocketAcceptConn) {
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
+class RawPacketMsgSizeTest : public ::testing::TestWithParam<TestAddress> {};
+
+TEST_P(RawPacketMsgSizeTest, SendTooLong) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ TestAddress addr = GetParam().WithPort(kPort);
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(addr.family(), SOCK_RAW, IPPROTO_UDP));
+
+ ASSERT_THAT(
+ connect(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ const char buf[65536] = {};
+ ASSERT_THAT(send(udp_sock.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EMSGSIZE));
+}
+
+TEST_P(RawPacketMsgSizeTest, SpliceTooLong) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ const char buf[65536] = {};
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ ASSERT_THAT(write(fds[1], buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ TestAddress addr = GetParam().WithPort(kPort);
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(addr.family(), SOCK_RAW, IPPROTO_UDP));
+
+ ASSERT_THAT(
+ connect(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ ssize_t n = splice(fds[0], nullptr, udp_sock.get(), nullptr, sizeof(buf), 0);
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(n, SyscallFailsWithErrno(EMSGSIZE));
+ } else {
+ // TODO(gvisor.dev/issue/138): Linux sends out multiple UDP datagrams, each
+ // of the size of a page.
+ EXPECT_THAT(n, SyscallSucceedsWithValue(sizeof(buf)));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(AllRawPacketMsgSizeTest, RawPacketMsgSizeTest,
+ ::testing::Values(V4Loopback(), V6Loopback()));
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 26dacc95e..b2a96086c 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -791,6 +791,19 @@ void RecvNoData(int sock) {
SyscallFailsWithErrno(EAGAIN));
}
+TestAddress TestAddress::WithPort(uint16_t port) const {
+ TestAddress addr = *this;
+ switch (addr.family()) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(port);
+ break;
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = htons(port);
+ break;
+ }
+ return addr;
+}
+
TestAddress V4Any() {
TestAddress t("V4Any");
t.addr.ss_family = AF_INET;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 75c0d4735..b3ab286b8 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -486,9 +486,14 @@ struct TestAddress {
sockaddr_storage addr;
socklen_t addr_len;
- int family() const { return addr.ss_family; }
explicit TestAddress(std::string description = "")
: description(std::move(description)), addr(), addr_len() {}
+
+ int family() const { return addr.ss_family; }
+
+ // Returns a new TestAddress with specified port. If port is not supported,
+ // the same TestAddress is returned.
+ TestAddress WithPort(uint16_t port) const;
};
constexpr char kMulticastAddress[] = "224.0.2.1";