diff options
author | Bruno Dal Bo <brunodalbo@google.com> | 2021-09-22 15:01:15 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-09-22 15:07:05 -0700 |
commit | 586f147cd6f0324328a318324049b2b54e3e7bcd (patch) | |
tree | 7c4775e1ed3a46c7084ad9011914331cbb8885a9 /pkg/tcpip/network/ipv4 | |
parent | 4f67756752002dc72bb64cdecd1fa17746f8217f (diff) |
Do not rate limit ICMP Echos by default
As per https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt,
Linux does not rate limit ICMP Echos by default:
icmp_ratemask - INTEGER
Mask made of ICMP types for which rates are being limited.
Significant bits: IHGFEDCBA9876543210
Default mask: 0000001100000011000 (6168)
Bit definitions (see include/linux/icmp.h):
0 Echo Reply
3 Destination Unreachable *
4 Source Quench *
5 Redirect
8 Echo Request
B Time Exceeded *
C Parameter Problem *
D Timestamp Request
E Timestamp Reply
F Info Request
G Info Reply
H Address Mask Request
I Address Mask Reply
* These are rate limited by default (see default mask above)
Equivalently for ICMPv6.
Lay out the foundation for ICMP rate masks; exposing that configuration will be
addressed later when the need arises (#6521).
Fixes #6519
PiperOrigin-RevId: 398337963
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 89 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 136 |
3 files changed, 202 insertions, 51 deletions
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 2aa38eb98..d51c36f19 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -240,12 +240,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.echoRequest.Increment() - sent := e.stats.icmp.packetsSent - if !e.protocol.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return - } - // DeliverTransportPacket will take ownership of pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will // be modifying fields. @@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } defer r.Release() + sent := e.stats.icmp.packetsSent + if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) { + sent.rateLimited.Increment() + return + } + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -562,13 +562,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. 
@@ -606,6 +599,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) { + switch reason := reason.(type) { + case *icmpReasonPortUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonProtoUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetworkUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0 + case *icmpReasonFragmentationNeeded: + return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0 + case *icmpReasonTTLExceeded: + return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0 + case *icmpReasonParamProblem: + return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType, icmpCode) { + sent.rateLimited.Increment() + return nil + } + // Now work out how much of the triggering packet we should return. 
// As per RFC 1812 Section 4.3.2.3 // @@ -658,44 +680,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonProtoUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetworkUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4NetUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4HostUnreachable) - counter = sent.dstUnreachable - case *icmpReasonFragmentationNeeded: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) - counter = sent.dstUnreachable - case *icmpReasonTTLExceeded: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.timeExceeded - case *icmpReasonParamProblem: - icmpHdr.SetType(header.ICMPv4ParamProblem) - icmpHdr.SetCode(header.ICMPv4UnusedCode) - icmpHdr.SetPointer(reason.pointer) - counter = sent.paramProblem - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetCode(icmpCode) + icmpHdr.SetType(icmpType) + icmpHdr.SetPointer(pointer) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( diff --git a/pkg/tcpip/network/ipv4/ipv4.go 
b/pkg/tcpip/network/ipv4/ipv4.go index 63223bc92..25f5a52e3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -1215,6 +1215,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv4Type]struct{} } // defaultTTL is the current default TTL for the protocol. Only the @@ -1330,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type and code may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool { + // Mimic linux and never rate limit for PMTU discovery. + // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288 + if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded { + return true + } + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { @@ -1409,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) + // Set ICMP rate limiting to Linux defaults. 
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html. + p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{ + header.ICMPv4DstUnreachable: struct{}{}, + header.ICMPv4SrcQuench: struct{}{}, + header.ICMPv4TimeExceeded: struct{}{}, + header.ICMPv4ParamProblem: struct{}{}, + } return p } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index e7b5b3ea2..ef91245d7 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -3373,3 +3373,139 @@ func TestCloseLocking(t *testing.T) { } }() } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := 
header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv4ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.UDPProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e 
*channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} |