From ac2200b8a9c269926d2eb98a7c23be79b4738fcf Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Mon, 26 Aug 2019 12:28:26 -0700 Subject: Prevent a network endpoint to send/rcv if its address was removed This addresses the problem where an endpoint has its address removed but still has outstanding references held by routes used in connected TCP/UDP sockets which prevent the removal of the endpoint. The fix adds a new "expired" flag to the referenced network endpoint, which is set when an endpoint has its address removed. Incoming packets are not delivered to an expired endpoint (unless in promiscuous mode), while sending outgoing packets triggers an error to the caller (unless in spoofing mode). In addition, a few helper functions were added to stack_test.go to reduce code duplications. PiperOrigin-RevId: 265514326 --- pkg/tcpip/stack/route.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'pkg/tcpip/stack/route.go') diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 391ab4344..e52cdd674 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. func (r *Route) IsResolutionRequired() bool { - return r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err -- cgit v1.2.3 From 3789c34b22e7a7466149bfbeedf05bf49188130c Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 3 Sep 2019 15:59:58 -0700 Subject: Make UDP traceroute work. Adds support to generate Port Unreachable messages for UDP datagrams received on a port for which there is no valid endpoint. Fixes #703 PiperOrigin-RevId: 267034418 --- pkg/tcpip/checker/checker.go | 100 +++++++++++++ pkg/tcpip/header/icmpv4.go | 71 +++++++++- pkg/tcpip/header/icmpv6.go | 88 +++++++++++- pkg/tcpip/header/ipv4.go | 33 +++-- pkg/tcpip/header/ipv6.go | 18 +-- pkg/tcpip/network/ip_test.go | 12 +- pkg/tcpip/network/ipv4/icmp.go | 4 +- pkg/tcpip/network/ipv6/icmp.go | 31 +--- pkg/tcpip/network/ipv6/icmp_test.go | 4 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/icmp_rate_limit.go | 86 ++++++++++++ pkg/tcpip/stack/nic.go | 12 +- pkg/tcpip/stack/registration.go | 2 +- pkg/tcpip/stack/route.go | 5 + pkg/tcpip/stack/stack.go | 36 +++++ pkg/tcpip/stack/transport_test.go | 2 +- pkg/tcpip/tcpip.go | 8 ++ pkg/tcpip/transport/icmp/endpoint.go | 19 ++- pkg/tcpip/transport/icmp/protocol.go | 11 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 8 +- pkg/tcpip/transport/udp/protocol.go | 101 ++++++++++++- pkg/tcpip/transport/udp/udp_test.go | 156 ++++++++++++++++++++- 23 files changed, 717 insertions(+), 94 deletions(-) create mode 100644 pkg/tcpip/stack/icmp_rate_limit.go (limited to 'pkg/tcpip/stack/route.go') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index afcabd51d..096ad71ab 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker { } } } + +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and +// potentially additional ICMPv4 header fields. +func ICMPv4(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) + } + + icmp := header.ICMPv4(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv4Type creates a checker that checks the ICMPv4 Type field. +func ICMPv4Type(want header.ICMPv4Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv4Code creates a checker that checks the ICMPv4 Code field. +func ICMPv4Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} + +// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and +// potentially additional ICMPv6 header fields. +func ICMPv6(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) + } + + icmp := header.ICMPv6(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv6Type creates a checker that checks the ICMPv6 Type field. +func ICMPv6Type(want header.ICMPv6Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv6Code creates a checker that checks the ICMPv6 Code field. +func ICMPv6Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c52c0d851..0cac6c0a5 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -25,13 +26,29 @@ type ICMPv4 []byte const ( // ICMPv4PayloadOffset defines the start of ICMP payload. - ICMPv4PayloadOffset = 4 + ICMPv4PayloadOffset = 8 // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 + + // icmpv4ChecksumOffset is the offset of the checksum field + // in an ICMPv4 message. + icmpv4ChecksumOffset = 2 + + // icmpv4MTUOffset is the offset of the MTU field + // in a ICMPv4FragmentationNeeded message. + icmpv4MTUOffset = 6 + + // icmpv4IdentOffset is the offset of the ident field + // in a ICMPv4EchoRequest/Reply message. + icmpv4IdentOffset = 4 + + // icmpv4SequenceOffset is the offset of the sequence field + // in a ICMPv4EchoRequest/Reply message. + icmpv4SequenceOffset = 6 ) // ICMPv4Type is the ICMP type field described in RFC 792. @@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) } // SetChecksum sets the ICMP checksum field. func (b ICMPv4) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) { func (b ICMPv4) Payload() []byte { return b[ICMPv4PayloadOffset:] } + +// MTU retrieves the MTU field from an ICMPv4 message. +func (b ICMPv4) MTU() uint16 { + return binary.BigEndian.Uint16(b[icmpv4MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv4 message. +func (b ICMPv4) SetMTU(mtu uint16) { + binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv4 message. +func (b ICMPv4) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv4IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv4 message. +func (b ICMPv4) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv4 message. +func (b ICMPv4) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv4 message. +func (b ICMPv4) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence) +} + +// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, +// and payload. +func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := uint16(0) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 3cc57e234..1125a7d14 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -25,14 +26,18 @@ type ICMPv6 []byte const ( // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. - ICMPv6MinimumSize = 4 + ICMPv6MinimumSize = 8 + + // ICMPv6PayloadOffset is the offset of the payload in an + // ICMP packet. + ICMPv6PayloadOffset = 8 // ICMPv6ProtocolNumber is the ICMP transport protocol number. ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 // ICMPv6NeighborSolicitMinimumSize is the minimum size of a // neighbor solicitation packet. - ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16 + ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16 // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. ICMPv6NeighborAdvertSize = 32 @@ -42,11 +47,27 @@ const ( // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. - ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP // packet-too-big packet. - ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + + // icmpv6ChecksumOffset is the offset of the checksum field + // in an ICMPv6 message. + icmpv6ChecksumOffset = 2 + + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 + // PacketTooBig message. + icmpv6MTUOffset = 4 + + // icmpv6IdentOffset is the offset of the ident field + // in a ICMPv6 Echo Request/Reply message. + icmpv6IdentOffset = 4 + + // icmpv6SequenceOffset is the offset of the sequence field + // in a ICMPv6 Echo Request/Reply message. + icmpv6SequenceOffset = 6 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } // SetChecksum calculates and sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) { func (ICMPv6) SetDestinationPort(uint16) { } +// MTU retrieves the MTU field from an ICMPv6 message. +func (b ICMPv6) MTU() uint32 { + return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv6 message. +func (b ICMPv6) SetMTU(mtu uint32) { + binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv6 message. +func (b ICMPv6) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv6 message. +func (b ICMPv6) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv6 message. +func (b ICMPv6) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv6 message. +func (b ICMPv6) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) +} + // Payload implements Transport.Payload. func (b ICMPv6) Payload() []byte { - return b[ICMPv6MinimumSize:] + return b[ICMPv6PayloadOffset:] +} + +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// IPv6 src/dst addresses and the payload. +func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := Checksum([]byte(src), 0) + xsum = Checksum([]byte(dst), xsum) + var upperLayerLength [4]byte + binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) + xsum = Checksum(upperLayerLength[:], xsum) + xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 17fc9c68e..554632a64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -21,16 +21,18 @@ import ( ) const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksum = 10 - srcAddr = 12 - dstAddr = 16 + versIHL = 0 + tos = 1 + // IPv4TotalLenOffset is the offset of the total length field in the + // IPv4 header. + IPv4TotalLenOffset = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksum = 10 + srcAddr = 12 + dstAddr = 16 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -103,6 +105,11 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP + // packet that every IPv4 capable host must be able to + // process/reassemble. + IPv4MinimumProcessableDatagramSize = 576 ) // Flags that may be set in an IPv4 packet. @@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 { // TotalLength returns the "total length" field of the ipv4 header. func (b IPv4) TotalLength() uint16 { - return binary.BigEndian.Uint16(b[totalLen:]) + return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } // Checksum returns the checksum field of the ipv4 header. @@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { - binary.BigEndian.PutUint16(b[totalLen:], totalLength) + binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } // SetChecksum sets the checksum field of the ipv4 header. @@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) { // packets are produced. func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) { b.SetTotalLength(totalLength) - checksum := Checksum(b[totalLen:totalLen+2], partialChecksum) + checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum) b.SetChecksum(^checksum) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index bc4e56535..093850e25 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -22,12 +22,14 @@ import ( ) const ( - versTCFL = 0 - payloadLen = 4 - nextHdr = 6 - hopLimit = 7 - v6SrcAddr = 8 - v6DstAddr = v6SrcAddr + IPv6AddressSize + versTCFL = 0 + // IPv6PayloadLenOffset is the offset of the PayloadLength field in + // IPv6 header. + IPv6PayloadLenOffset = 4 + nextHdr = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -94,7 +96,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet { // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { - return binary.BigEndian.Uint16(b[payloadLen:]) + return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } // HopLimit returns the value of the "hop limit" field of the ipv6 header. @@ -148,7 +150,7 @@ func (b IPv6) SetTOS(t uint8, l uint32) { // SetPayloadLength sets the "payload length" field of the ipv6 header. func (b IPv6) SetPayloadLength(payloadLength uint16) { - binary.BigEndian.PutUint16(b[payloadLen:], payloadLength) + binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) } // SetSourceAddress sets the "source address" field of the ipv6 header. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6bbfcd97f..4b3bd74fa 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -319,7 +319,8 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv4 header. ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) @@ -539,7 +540,7 @@ func TestIPv6ReceiveControl(t *testing.T) { defer ep.Close() - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } @@ -559,10 +560,11 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) icmp.SetType(c.typ) icmp.SetCode(c.code) - copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) ip.Encode(&header.IPv6Fields{ PayloadLength: 100, NextHeader: 10, @@ -574,7 +576,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Build the fragmentation header if needed. if c.fragmentOffset != nil { ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) frag.Encode(&header.IPv6FragmentFields{ NextHeader: 10, FragmentOffset: *c.fragmentOffset, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 497164cbb..a25756443 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,8 +15,6 @@ package ipv4 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) + mtu := uint32(h.MTU()) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1689af16f..346de9ae3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,8 +15,6 @@ package ipv6 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -82,7 +80,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:]) + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) case header.ICMPv6DstUnreachable: @@ -130,7 +128,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { sent.Dropped.Increment() @@ -162,7 +160,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { sent.Dropped.Increment() return @@ -233,7 +231,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr pkt[icmpV6LengthOffset] = 1 copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) length := uint16(hdr.UsedLength()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -272,24 +270,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return "", false } - -func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := header.Checksum([]byte(src), 0) - xsum = header.Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = header.Checksum(upperLayerLength[:], xsum) - xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^header.Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum -} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d0dc72506..227a65cf2 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -153,7 +153,7 @@ func TestICMPCounts(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) handleIPv6Payload(hdr) } @@ -321,7 +321,7 @@ func TestLinkResolution(t *testing.T) { hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index b692c60ce..788de3dfe 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", "nic.go", @@ -42,6 +43,7 @@ go_library( "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go new file mode 100644 index 000000000..f8156be47 --- /dev/null +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -0,0 +1,86 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "sync" + + "golang.org/x/time/rate" +) + +const ( + // icmpLimit is the default maximum number of ICMP messages permitted by this + // rate limiter. + icmpLimit = 1000 + + // icmpBurst is the default number of ICMP messages that can be sent in a single + // burst. + icmpBurst = 50 +) + +// ICMPRateLimiter is a global rate limiter that controls the generation of +// ICMP messages generated by the stack. +type ICMPRateLimiter struct { + mu sync.RWMutex + l *rate.Limiter +} + +// NewICMPRateLimiter returns a global rate limiter for controlling the rate +// at which ICMP messages are generated by the stack. +func NewICMPRateLimiter() *ICMPRateLimiter { + return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)} +} + +// Allow returns true if we are allowed to send at least 1 message at the +// moment. +func (i *ICMPRateLimiter) Allow() bool { + i.mu.RLock() + allow := i.l.Allow() + i.mu.RUnlock() + return allow +} + +// Limit returns the maximum number of ICMP messages that can be sent in one +// second. +func (i *ICMPRateLimiter) Limit() rate.Limit { + i.mu.RLock() + defer i.mu.RUnlock() + return i.l.Limit() +} + +// SetLimit sets the maximum number of ICMP messages that can be sent in one +// second. +func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) { + i.mu.RLock() + defer i.mu.RUnlock() + i.l.SetLimit(newLimit) +} + +// Burst returns how many ICMP messages can be sent at any single instant. +func (i *ICMPRateLimiter) Burst() int { + i.mu.RLock() + defer i.mu.RUnlock() + return i.l.Burst() +} + +// SetBurst sets the maximum number of ICMP messages allowed at any single +// instant. +// +// NOTE: Changing Burst causes the underlying rate limiter to be recreated. +func (i *ICMPRateLimiter) SetBurst(burst int) { + i.mu.Lock() + i.l = rate.NewLimiter(i.l.Limit(), burst) + i.mu.Unlock() +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f947b55db..ae56e0ffd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -679,7 +679,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -720,6 +720,11 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + type networkEndpointKind int32 const ( @@ -823,3 +828,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2037eef9f..67b70b2ee 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -109,7 +109,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e52cdd674..5c8b7977a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -217,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route { } return l } + +// Stack returns the instance of the Stack that owns this route. +func (r *Route) Stack() *Stack { + return r.ref.stack() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d69162ba1..1d5e84a8b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -28,6 +28,7 @@ import ( "sync" "time" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -389,6 +390,10 @@ type Stack struct { // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. resumableEndpoints []ResumableEndpoint + + // icmpRateLimiter is a global rate limiter for all ICMP messages generated + // by the stack. + icmpRateLimiter *ICMPRateLimiter } // Options contains optional Stack configuration. @@ -434,6 +439,7 @@ func New(network []string, transport []string, opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, raw: opts.Raw, + icmpRateLimiter: NewICMPRateLimiter(), } // Add specified network protocols. @@ -1215,3 +1221,33 @@ func (s *Stack) IPTables() iptables.IPTables { func (s *Stack) SetIPTables(ipt iptables.IPTables) { s.tables = ipt } + +// ICMPLimit returns the maximum number of ICMP messages that can be sent +// in one second. +func (s *Stack) ICMPLimit() rate.Limit { + return s.icmpRateLimiter.Limit() +} + +// SetICMPLimit sets the maximum number of ICMP messages that be sent +// in one second. +func (s *Stack) SetICMPLimit(newLimit rate.Limit) { + s.icmpRateLimiter.SetLimit(newLimit) +} + +// ICMPBurst returns the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) ICMPBurst() int { + return s.icmpRateLimiter.Burst() +} + +// SetICMPBurst sets the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) SetICMPBurst(burst int) { + s.icmpRateLimiter.SetBurst(burst) +} + +// AllowICMPMessage returns true if we the rate limiter allows at least one +// ICMP message to be sent at this instant. +func (s *Stack) AllowICMPMessage() bool { + return s.icmpRateLimiter.Allow() +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5335897f5..ca185279e 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -251,7 +251,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 8f9b86cce..05aa42c98 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -720,6 +720,10 @@ type ICMPv4SentPacketStats struct { // Dropped is the total number of ICMPv4 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. @@ -738,6 +742,10 @@ type ICMPv6SentPacketStats struct { // Dropped is the total number of ICMPv6 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 451d3880e..e1f622af6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,7 +15,6 @@ package icmp import ( - "encoding/binary" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -368,14 +367,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // Set the ident to the user-specified port. Sequence number should - // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + icmpv4.SetIdent(ident) data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. @@ -394,14 +392,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) - - hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) copy(icmpv6, data) - data = data[header.ICMPv6EchoMinimumSize:] + // Set the ident. Sequence number is provided by the user. + icmpv6.SetIdent(ident) + data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return tcpip.ErrInvalidEndpointState diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 7fdba5d56..1eb790932 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -23,7 +23,6 @@ package icmp import ( - "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ func (p *protocol) MinimumPacketSize() int { case ProtocolNumber4: return header.ICMPv4MinimumSize case ProtocolNumber6: - return header.ICMPv6EchoMinimumSize + return header.ICMPv6MinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } @@ -101,16 +100,18 @@ func (p *protocol) MinimumPacketSize() int { func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil + hdr := header.ICMPv4(v) + return 0, hdr.Ident(), nil case ProtocolNumber6: - return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil + hdr := header.ICMPv6(v) + return 0, hdr.Ident(), nil } panic(fmt.Sprint("unknown protocol number: ", p.number)) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index ee04dcfcc..2a13b2022 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -129,7 +129,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 272481aa0..18c707a57 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -267,7 +267,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -286,9 +286,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) icmp.SetType(typ) icmp.SetCode(code) - - copy(icmp[header.ICMPv4PayloadOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) + const icmpv4VariableHeaderOffset = 4 + copy(icmp[icmpv4VariableHeaderOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f76e7fbe1..068d9a272 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -69,7 +69,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + // TODO(b/129426613): only send an ICMP message if UDP checksum is valid. + + // Only send ICMP error if the address is not a multicast/broadcast + // v4/v6 address or the source is not the unspecified address. + // + // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { + return true + } + + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported; or + // + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + switch len(id.LocalAddress) { + case header.IPv4AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() + return true + } + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 where it mentioned that at least 8 + // bytes of the payload must be included. Today linux and other + // systems implement the] RFC1812 definition and not the original + // RFC 1122 requirement. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4DstUnreachable) + pkt.SetCode(header.ICMPv4PortUnreachable) + pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + + case header.IPv6AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() + return true + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) + pkt.SetType(header.ICMPv6DstUnreachable) + pkt.SetCode(header.ICMPv6PortUnreachable) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 9da6edce2..995d6e8a1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -461,7 +461,11 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { } func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) + return newMinPayload(30) +} + +func newMinPayload(minSize int) []byte { + b := make([]byte, minSize+rand.Intn(100)) for i := range b { b[i] = byte(rand.Intn(256)) } @@ -1238,3 +1242,153 @@ func TestMulticastInterfaceOption(t *testing.T) { }) } } + +// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV4UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true, will result in a payload large enough + // so that the final generated IPv4 packet is larger than + // header.IPv4MinimumProcessableDatagramSize. + largePayload bool + }{ + {unicastV4, true, false}, + {unicastV4, true, true}, + {multicastV4, false, false}, + {multicastV4, false, true}, + {broadcast, false, false}, + {broadcast, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(576) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } + + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} + +// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV6UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true will result in a payload large enough to + // create an IPv6 packet > header.IPv6MinimumMTU bytes. + largePayload bool + }{ + {unicastV6, true, false}, + {unicastV6, true, true}, + {multicastV6, false, false}, + {multicastV6, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(1280) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} -- cgit v1.2.3