diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-11-01 18:10:58 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-11-01 18:13:01 -0700 |
commit | 42a08f036f7be69b6c6d1b971911cd1aea611ece (patch) | |
tree | 548f6f9f1d3aed23a59548070ca1c76a8211ffb0 | |
parent | 58017e655399384afed2cedea0e269cb1ad2dd7e (diff) |
Allow partial packets in ICMP errors when NATing
An ICMP error may not hold the full packet that triggered the ICMP
response. As long as the IP header and the transport header is
parsable, we should be able to successfully NAT as that is all that
we need to identify the connection.
PiperOrigin-RevId: 406966048
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 82 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 270 |
2 files changed, 200 insertions, 152 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index c489506bb..7fa657001 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -209,17 +209,41 @@ type bucket struct { tuples tupleList } -func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) { - switch pkt.tuple.id().transProto { +// A netAndTransHeadersFunc returns the network and transport headers found +// in an ICMP payload. The transport layer's payload will not be returned. +// +// May panic if the packet does not hold the transport header. +type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte) + +func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv4(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the buffer is smaller than + // the total length specified in the IPv4 header. + transHdr := icmpPayload[netHdr.HeaderLength():] + return netHdr, transHdr[:minTransHdrLen] +} + +func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv6(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the IP payload is smaller than + // the payload length specified in the IPv6 header. + transHdr := icmpPayload[header.IPv6MinimumSize:] + return netHdr, transHdr[:minTransHdrLen] +} + +func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) { + switch transProto { case header.TCPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.TCP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize) + return netHeader, header.TCP(transHeaderBytes), true } case header.UDPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.UDP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize) + return netHeader, header.UDP(transHeaderBytes), true } } return nil, nil, false @@ -246,7 +270,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic("should have dropped packets with IPv4 options") } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok { return netHdr, transHdr, true, true } case header.ICMPv6ProtocolNumber: @@ -264,7 +288,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto)) } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok { return netHdr, transHdr, true, true } } @@ -283,34 +307,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro } } -func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { - switch transProto { - case header.TCPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.TCP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } - case header.UDPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.UDP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } +func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok { + return tupleID{ + srcAddr: netHdr.DestinationAddress(), + srcPort: transHdr.DestinationPort(), + dstAddr: netHdr.SourceAddress(), + dstPort: transHdr.SourcePort(), + transProto: transProto, + netProto: netProto, + }, true } return tupleID{}, false @@ -349,7 +355,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { return tupleID{}, false, false } - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { return tid, true, true } case header.ICMPv6ProtocolNumber: @@ -370,7 +376,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { } // TODO(https://gvisor.dev/issue/6789): Handle extension headers. - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { return tid, true, true } } diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 7fe3b29d9..b2383576c 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) { } func TestNATICMPError(t *testing.T) { - const srcPort = 1234 - const dstPort = 5432 + const ( + srcPort = 1234 + dstPort = 5432 + dataSize = 4 + ) type icmpTypeTest struct { name string @@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) { netProto: ipv4.ProtocolNumber, host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original) - hdr := buffer.NewPrependable(totalLen) + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)) if n := copy(hdr.Prepend(len(original)), original); n != len(original) { t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) } @@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) { icmp.SetType(header.ICMPv4Type(icmpType)) icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.ICMPv4ProtocolNumber, utils.Host1IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, @@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.UDPMinimumSize - hdr := buffer.NewPrependable(totalLen) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.UDPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.TCPMinimumSize - hdr := buffer.NewPrependable(totalLen) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.TCPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) { Src: utils.Host1IPv6Addr.AddressWithPrefix.Address, Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, })) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), payloadLen, header.ICMPv6ProtocolNumber, utils.Host1IPv6Addr.AddressWithPrefix.Address, @@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.UDPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(udp), header.UDPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.TCPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(tcp), header.TCPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) { }, } + trimTests := []struct { + name string + trimLen int + expectNATedICMP bool + }{ + { + name: "Trim nothing", + trimLen: 0, + expectNATedICMP: true, + }, + { + name: "Trim data", + trimLen: dataSize, + expectNATedICMP: true, + }, + { + name: "Trim data and transport header", + trimLen: dataSize + 1, + expectNATedICMP: false, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { for _, transportType := range test.transportTypes { t.Run(transportType.name, func(t *testing.T) { for _, icmpType := range test.icmpTypes { t.Run(icmpType.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - - ep1 := channel.New(1, header.IPv6MinimumMTU, "") - ep2 := channel.New(1, header.IPv6MinimumMTU, "") - utils.SetupRouterStack(t, s, ep1, ep2) - - ipv6 := test.netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - - table := stack.Table{ - Rules: []stack.Rule{ - // Prerouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, + for _, trimTest := range trimTests { + t.Run(trimTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + + ep1 := channel.New(1, header.IPv6MinimumMTU, "") + ep2 := channel.New(1, header.IPv6MinimumMTU, "") + utils.SetupRouterStack(t, s, ep1, ep2) + + ipv6 := test.netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, + }, + { + Target: &stack.AcceptTarget{}, + }, }, - Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, - }, - { - Target: &stack.AcceptTarget{}, - }, - - // Input - { - Target: &stack.AcceptTarget{}, - }, - - // Forward - { - Target: &stack.AcceptTarget{}, - }, - - // Output - { - Target: &stack.AcceptTarget{}, - }, - - // Postrouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, }, - Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, - }, - { - Target: &stack.AcceptTarget{}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 2, - stack.Forward: 3, - stack.Output: 4, - stack.Postrouting: 5, - }, - } + } - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } - ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(), - })) + buf := transportType.buf - { - pkt, ok := ep1.Read() - if !ok { - t.Fatal("expected to read a packet on ep1") - } - pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) - transportType.checkNATed(t, pktView) - if t.Failed() { - t.FailNow() - } + ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: append(buffer.View(nil), buf...).ToVectorisedView(), + })) - ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), - })) - } + { + pkt, ok := ep1.Read() + if !ok { + t.Fatal("expected to read a packet on ep1") + } + pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) + transportType.checkNATed(t, pktView) + if t.Failed() { + t.FailNow() + } + + pktView = pktView[:len(pktView)-trimTest.trimLen] + buf = buf[:len(buf)-trimTest.trimLen] + + ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), + })) + } - pkt, ok := ep2.Read() - if ok != icmpType.expectResponse { - t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse) - } - if !icmpType.expectResponse { - return + pkt, ok := ep2.Read() + expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP + if ok != expectResponse { + t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse) + } + if !expectResponse { + return + } + test.decrementTTL(buf) + test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val) + }) } - test.decrementTTL(transportType.buf) - test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val) }) } }) |