summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
authorBruno Dal Bo <brunodalbo@google.com>2021-09-22 15:01:15 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-22 15:07:05 -0700
commit586f147cd6f0324328a318324049b2b54e3e7bcd (patch)
tree7c4775e1ed3a46c7084ad9011914331cbb8885a9 /pkg/tcpip/network/ipv4
parent4f67756752002dc72bb64cdecd1fa17746f8217f (diff)
Do not rate limit ICMP Echos by default
As per https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt linux does not limit ICMP Echos by default. icmp_ratemask - INTEGER Mask made of ICMP types for which rates are being limited. Significant bits: IHGFEDCBA9876543210 Default mask: 0000001100000011000 (6168) Bit definitions (see include/linux/icmp.h): 0 Echo Reply 3 Destination Unreachable * 4 Source Quench * 5 Redirect 8 Echo Request B Time Exceeded * C Parameter Problem * D Timestamp Request E Timestamp Reply F Info Request G Info Reply H Address Mask Request I Address Mask Reply * These are rate limited by default (see default mask above) Equivalently for ICMPv6. Lay out foundation for ICMP rate masks, exposing that configuration will be addressed later when the need arises (#6521). Fixes #6519 PiperOrigin-RevId: 398337963
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go89
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go28
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go136
3 files changed, 202 insertions, 51 deletions
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 2aa38eb98..d51c36f19 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -240,12 +240,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()
- sent := e.stats.icmp.packetsSent
- if !e.protocol.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return
- }
-
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
// be modifying fields.
@@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
defer r.Release()
+ sent := e.stats.icmp.packetsSent
+ if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
+ sent.rateLimited.Increment()
+ return
+ }
+
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -562,13 +562,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
@@ -606,6 +599,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonProtoUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetworkUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonFragmentationNeeded:
+ return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
+ case *icmpReasonTTLExceeded:
+ return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
+ case *icmpReasonParamProblem:
+ return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType, icmpCode) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
@@ -658,44 +680,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonProtoUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetworkUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4NetUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4HostUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonFragmentationNeeded:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
- counter = sent.dstUnreachable
- case *icmpReasonTTLExceeded:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.timeExceeded
- case *icmpReasonParamProblem:
- icmpHdr.SetType(header.ICMPv4ParamProblem)
- icmpHdr.SetCode(header.ICMPv4UnusedCode)
- icmpHdr.SetPointer(reason.pointer)
- counter = sent.paramProblem
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetPointer(pointer)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 63223bc92..25f5a52e3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -1215,6 +1215,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv4Type]struct{}
}
// defaultTTL is the current default TTL for the protocol. Only the
@@ -1330,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type and code may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool {
+ // Mimic linux and never rate limit for PMTU discovery.
+ // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288
+ if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded {
+ return true
+ }
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1409,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
+ // Set ICMP rate limiting to Linux defaults.
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html.
+ p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{
+ header.ICMPv4DstUnreachable: struct{}{},
+ header.ICMPv4SrcQuench: struct{}{},
+ header.ICMPv4TimeExceeded: struct{}{},
+ header.ICMPv4ParamProblem: struct{}{},
+ }
return p
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index e7b5b3ea2..ef91245d7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -3373,3 +3373,139 @@ func TestCloseLocking(t *testing.T) {
}
}()
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv4ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.UDPProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}