diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 100 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 71 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv6.go | 88 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv6.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/icmp_rate_limit.go | 86 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 101 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 156 |
23 files changed, 717 insertions, 94 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index afcabd51d..096ad71ab 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker { } } } + +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and +// potentially additional ICMPv4 header fields. +func ICMPv4(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) + } + + icmp := header.ICMPv4(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv4Type creates a checker that checks the ICMPv4 Type field. +func ICMPv4Type(want header.ICMPv4Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv4Code creates a checker that checks the ICMPv4 Code field. +func ICMPv4Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} + +// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and +// potentially additional ICMPv6 header fields. +func ICMPv6(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) + } + + icmp := header.ICMPv6(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv6Type creates a checker that checks the ICMPv6 Type field. +func ICMPv6Type(want header.ICMPv6Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv6Code creates a checker that checks the ICMPv6 Code field. +func ICMPv6Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c52c0d851..0cac6c0a5 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -25,13 +26,29 @@ type ICMPv4 []byte const ( // ICMPv4PayloadOffset defines the start of ICMP payload. - ICMPv4PayloadOffset = 4 + ICMPv4PayloadOffset = 8 // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 + + // icmpv4ChecksumOffset is the offset of the checksum field + // in an ICMPv4 message. + icmpv4ChecksumOffset = 2 + + // icmpv4MTUOffset is the offset of the MTU field + // in a ICMPv4FragmentationNeeded message. + icmpv4MTUOffset = 6 + + // icmpv4IdentOffset is the offset of the ident field + // in a ICMPv4EchoRequest/Reply message. + icmpv4IdentOffset = 4 + + // icmpv4SequenceOffset is the offset of the sequence field + // in a ICMPv4EchoRequest/Reply message. + icmpv4SequenceOffset = 6 ) // ICMPv4Type is the ICMP type field described in RFC 792. @@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) } // SetChecksum sets the ICMP checksum field. func (b ICMPv4) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) { func (b ICMPv4) Payload() []byte { return b[ICMPv4PayloadOffset:] } + +// MTU retrieves the MTU field from an ICMPv4 message. +func (b ICMPv4) MTU() uint16 { + return binary.BigEndian.Uint16(b[icmpv4MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv4 message. +func (b ICMPv4) SetMTU(mtu uint16) { + binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv4 message. +func (b ICMPv4) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv4IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv4 message. +func (b ICMPv4) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv4 message. +func (b ICMPv4) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv4 message. +func (b ICMPv4) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence) +} + +// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, +// and payload. +func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := uint16(0) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 3cc57e234..1125a7d14 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -25,14 +26,18 @@ type ICMPv6 []byte const ( // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. - ICMPv6MinimumSize = 4 + ICMPv6MinimumSize = 8 + + // ICMPv6PayloadOffset is the offset of the payload in an + // ICMP packet. + ICMPv6PayloadOffset = 8 // ICMPv6ProtocolNumber is the ICMP transport protocol number. ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 // ICMPv6NeighborSolicitMinimumSize is the minimum size of a // neighbor solicitation packet. - ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16 + ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16 // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. ICMPv6NeighborAdvertSize = 32 @@ -42,11 +47,27 @@ const ( // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. - ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP // packet-too-big packet. - ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + + // icmpv6ChecksumOffset is the offset of the checksum field + // in an ICMPv6 message. + icmpv6ChecksumOffset = 2 + + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 + // PacketTooBig message. + icmpv6MTUOffset = 4 + + // icmpv6IdentOffset is the offset of the ident field + // in a ICMPv6 Echo Request/Reply message. + icmpv6IdentOffset = 4 + + // icmpv6SequenceOffset is the offset of the sequence field + // in a ICMPv6 Echo Request/Reply message. + icmpv6SequenceOffset = 6 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } // SetChecksum calculates and sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) { func (ICMPv6) SetDestinationPort(uint16) { } +// MTU retrieves the MTU field from an ICMPv6 message. +func (b ICMPv6) MTU() uint32 { + return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv6 message. +func (b ICMPv6) SetMTU(mtu uint32) { + binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv6 message. +func (b ICMPv6) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv6 message. +func (b ICMPv6) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv6 message. +func (b ICMPv6) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv6 message. +func (b ICMPv6) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) +} + // Payload implements Transport.Payload. func (b ICMPv6) Payload() []byte { - return b[ICMPv6MinimumSize:] + return b[ICMPv6PayloadOffset:] +} + +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// IPv6 src/dst addresses and the payload. +func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := Checksum([]byte(src), 0) + xsum = Checksum([]byte(dst), xsum) + var upperLayerLength [4]byte + binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) + xsum = Checksum(upperLayerLength[:], xsum) + xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 17fc9c68e..554632a64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -21,16 +21,18 @@ import ( ) const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksum = 10 - srcAddr = 12 - dstAddr = 16 + versIHL = 0 + tos = 1 + // IPv4TotalLenOffset is the offset of the total length field in the + // IPv4 header. + IPv4TotalLenOffset = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksum = 10 + srcAddr = 12 + dstAddr = 16 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -103,6 +105,11 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP + // packet that every IPv4 capable host must be able to + // process/reassemble. + IPv4MinimumProcessableDatagramSize = 576 ) // Flags that may be set in an IPv4 packet. @@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 { // TotalLength returns the "total length" field of the ipv4 header. func (b IPv4) TotalLength() uint16 { - return binary.BigEndian.Uint16(b[totalLen:]) + return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } // Checksum returns the checksum field of the ipv4 header. @@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { - binary.BigEndian.PutUint16(b[totalLen:], totalLength) + binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } // SetChecksum sets the checksum field of the ipv4 header. @@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) { // packets are produced. func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) { b.SetTotalLength(totalLength) - checksum := Checksum(b[totalLen:totalLen+2], partialChecksum) + checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum) b.SetChecksum(^checksum) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index bc4e56535..093850e25 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -22,12 +22,14 @@ import ( ) const ( - versTCFL = 0 - payloadLen = 4 - nextHdr = 6 - hopLimit = 7 - v6SrcAddr = 8 - v6DstAddr = v6SrcAddr + IPv6AddressSize + versTCFL = 0 + // IPv6PayloadLenOffset is the offset of the PayloadLength field in + // IPv6 header. + IPv6PayloadLenOffset = 4 + nextHdr = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -94,7 +96,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet { // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { - return binary.BigEndian.Uint16(b[payloadLen:]) + return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } // HopLimit returns the value of the "hop limit" field of the ipv6 header. @@ -148,7 +150,7 @@ func (b IPv6) SetTOS(t uint8, l uint32) { // SetPayloadLength sets the "payload length" field of the ipv6 header. func (b IPv6) SetPayloadLength(payloadLength uint16) { - binary.BigEndian.PutUint16(b[payloadLen:], payloadLength) + binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) } // SetSourceAddress sets the "source address" field of the ipv6 header. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6bbfcd97f..4b3bd74fa 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -319,7 +319,8 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv4 header. ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) @@ -539,7 +540,7 @@ func TestIPv6ReceiveControl(t *testing.T) { defer ep.Close() - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } @@ -559,10 +560,11 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) icmp.SetType(c.typ) icmp.SetCode(c.code) - copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) ip.Encode(&header.IPv6Fields{ PayloadLength: 100, NextHeader: 10, @@ -574,7 +576,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Build the fragmentation header if needed. if c.fragmentOffset != nil { ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) frag.Encode(&header.IPv6FragmentFields{ NextHeader: 10, FragmentOffset: *c.fragmentOffset, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 497164cbb..a25756443 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,8 +15,6 @@ package ipv4 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) + mtu := uint32(h.MTU()) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1689af16f..346de9ae3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,8 +15,6 @@ package ipv6 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -82,7 +80,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:]) + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) case header.ICMPv6DstUnreachable: @@ -130,7 +128,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { sent.Dropped.Increment() @@ -162,7 +160,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { sent.Dropped.Increment() return @@ -233,7 +231,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr pkt[icmpV6LengthOffset] = 1 copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) length := uint16(hdr.UsedLength()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -272,24 +270,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return "", false } - -func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := header.Checksum([]byte(src), 0) - xsum = header.Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = header.Checksum(upperLayerLength[:], xsum) - xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^header.Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum -} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d0dc72506..227a65cf2 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -153,7 +153,7 @@ func TestICMPCounts(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) handleIPv6Payload(hdr) } @@ -321,7 +321,7 @@ func TestLinkResolution(t *testing.T) { hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index b692c60ce..788de3dfe 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", "nic.go", @@ -42,6 +43,7 @@ go_library( "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go new file mode 100644 index 000000000..f8156be47 --- /dev/null +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -0,0 +1,86 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "sync" + + "golang.org/x/time/rate" +) + +const ( + // icmpLimit is the default maximum number of ICMP messages permitted by this + // rate limiter. + icmpLimit = 1000 + + // icmpBurst is the default number of ICMP messages that can be sent in a single + // burst. + icmpBurst = 50 +) + +// ICMPRateLimiter is a global rate limiter that controls the generation of +// ICMP messages generated by the stack. +type ICMPRateLimiter struct { + mu sync.RWMutex + l *rate.Limiter +} + +// NewICMPRateLimiter returns a global rate limiter for controlling the rate +// at which ICMP messages are generated by the stack. +func NewICMPRateLimiter() *ICMPRateLimiter { + return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)} +} + +// Allow returns true if we are allowed to send at least 1 message at the +// moment. +func (i *ICMPRateLimiter) Allow() bool { + i.mu.RLock() + allow := i.l.Allow() + i.mu.RUnlock() + return allow +} + +// Limit returns the maximum number of ICMP messages that can be sent in one +// second. +func (i *ICMPRateLimiter) Limit() rate.Limit { + i.mu.RLock() + defer i.mu.RUnlock() + return i.l.Limit() +} + +// SetLimit sets the maximum number of ICMP messages that can be sent in one +// second. +func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) { + i.mu.RLock() + defer i.mu.RUnlock() + i.l.SetLimit(newLimit) +} + +// Burst returns how many ICMP messages can be sent at any single instant. +func (i *ICMPRateLimiter) Burst() int { + i.mu.RLock() + defer i.mu.RUnlock() + return i.l.Burst() +} + +// SetBurst sets the maximum number of ICMP messages allowed at any single +// instant. +// +// NOTE: Changing Burst causes the underlying rate limiter to be recreated. +func (i *ICMPRateLimiter) SetBurst(burst int) { + i.mu.Lock() + i.l = rate.NewLimiter(i.l.Limit(), burst) + i.mu.Unlock() +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f947b55db..ae56e0ffd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -679,7 +679,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -720,6 +720,11 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + type networkEndpointKind int32 const ( @@ -823,3 +828,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2037eef9f..67b70b2ee 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -109,7 +109,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e52cdd674..5c8b7977a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -217,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route { } return l } + +// Stack returns the instance of the Stack that owns this route. +func (r *Route) Stack() *Stack { + return r.ref.stack() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d69162ba1..1d5e84a8b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -28,6 +28,7 @@ import ( "sync" "time" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -389,6 +390,10 @@ type Stack struct { // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. resumableEndpoints []ResumableEndpoint + + // icmpRateLimiter is a global rate limiter for all ICMP messages generated + // by the stack. + icmpRateLimiter *ICMPRateLimiter } // Options contains optional Stack configuration. @@ -434,6 +439,7 @@ func New(network []string, transport []string, opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, raw: opts.Raw, + icmpRateLimiter: NewICMPRateLimiter(), } // Add specified network protocols. @@ -1215,3 +1221,33 @@ func (s *Stack) IPTables() iptables.IPTables { func (s *Stack) SetIPTables(ipt iptables.IPTables) { s.tables = ipt } + +// ICMPLimit returns the maximum number of ICMP messages that can be sent +// in one second. +func (s *Stack) ICMPLimit() rate.Limit { + return s.icmpRateLimiter.Limit() +} + +// SetICMPLimit sets the maximum number of ICMP messages that be sent +// in one second. +func (s *Stack) SetICMPLimit(newLimit rate.Limit) { + s.icmpRateLimiter.SetLimit(newLimit) +} + +// ICMPBurst returns the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) ICMPBurst() int { + return s.icmpRateLimiter.Burst() +} + +// SetICMPBurst sets the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) SetICMPBurst(burst int) { + s.icmpRateLimiter.SetBurst(burst) +} + +// AllowICMPMessage returns true if we the rate limiter allows at least one +// ICMP message to be sent at this instant. +func (s *Stack) AllowICMPMessage() bool { + return s.icmpRateLimiter.Allow() +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5335897f5..ca185279e 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -251,7 +251,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 8f9b86cce..05aa42c98 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -720,6 +720,10 @@ type ICMPv4SentPacketStats struct { // Dropped is the total number of ICMPv4 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. @@ -738,6 +742,10 @@ type ICMPv6SentPacketStats struct { // Dropped is the total number of ICMPv6 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 451d3880e..e1f622af6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,7 +15,6 @@ package icmp import ( - "encoding/binary" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -368,14 +367,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // Set the ident to the user-specified port. Sequence number should - // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + icmpv4.SetIdent(ident) data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. @@ -394,14 +392,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) - - hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) copy(icmpv6, data) - data = data[header.ICMPv6EchoMinimumSize:] + // Set the ident. Sequence number is provided by the user. + icmpv6.SetIdent(ident) + data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return tcpip.ErrInvalidEndpointState diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 7fdba5d56..1eb790932 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -23,7 +23,6 @@ package icmp import ( - "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ func (p *protocol) MinimumPacketSize() int { case ProtocolNumber4: return header.ICMPv4MinimumSize case ProtocolNumber6: - return header.ICMPv6EchoMinimumSize + return header.ICMPv6MinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } @@ -101,16 +100,18 @@ func (p *protocol) MinimumPacketSize() int { func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil + hdr := header.ICMPv4(v) + return 0, hdr.Ident(), nil case ProtocolNumber6: - return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil + hdr := header.ICMPv6(v) + return 0, hdr.Ident(), nil } panic(fmt.Sprint("unknown protocol number: ", p.number)) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index ee04dcfcc..2a13b2022 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -129,7 +129,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 272481aa0..18c707a57 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -267,7 +267,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -286,9 +286,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) icmp.SetType(typ) icmp.SetCode(code) - - copy(icmp[header.ICMPv4PayloadOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) + const icmpv4VariableHeaderOffset = 4 + copy(icmp[icmpv4VariableHeaderOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f76e7fbe1..068d9a272 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -69,7 +69,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + // TODO(b/129426613): only send an ICMP message if UDP checksum is valid. + + // Only send ICMP error if the address is not a multicast/broadcast + // v4/v6 address or the source is not the unspecified address. + // + // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { + return true + } + + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported; or + // + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + switch len(id.LocalAddress) { + case header.IPv4AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() + return true + } + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 where it mentioned that at least 8 + // bytes of the payload must be included. Today linux and other + // systems implement the] RFC1812 definition and not the original + // RFC 1122 requirement. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4DstUnreachable) + pkt.SetCode(header.ICMPv4PortUnreachable) + pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + + case header.IPv6AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() + return true + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) + pkt.SetType(header.ICMPv6DstUnreachable) + pkt.SetCode(header.ICMPv6PortUnreachable) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 9da6edce2..995d6e8a1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -461,7 +461,11 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { } func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) + return newMinPayload(30) +} + +func newMinPayload(minSize int) []byte { + b := make([]byte, minSize+rand.Intn(100)) for i := range b { b[i] = byte(rand.Intn(256)) } @@ -1238,3 +1242,153 @@ func TestMulticastInterfaceOption(t *testing.T) { }) } } + +// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV4UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true, will result in a payload large enough + // so that the final generated IPv4 packet is larger than + // header.IPv4MinimumProcessableDatagramSize. + largePayload bool + }{ + {unicastV4, true, false}, + {unicastV4, true, true}, + {multicastV4, false, false}, + {multicastV4, false, true}, + {broadcast, false, false}, + {broadcast, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(576) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } + + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} + +// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV6UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true will result in a payload large enough to + // create an IPv6 packet > header.IPv6MinimumMTU bytes. + largePayload bool + }{ + {unicastV6, true, false}, + {unicastV6, true, true}, + {multicastV6, false, false}, + {multicastV6, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(1280) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} |