summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-11-01 18:10:58 -0700
committergVisor bot <gvisor-bot@google.com>2021-11-01 18:13:01 -0700
commit42a08f036f7be69b6c6d1b971911cd1aea611ece (patch)
tree548f6f9f1d3aed23a59548070ca1c76a8211ffb0
parent58017e655399384afed2cedea0e269cb1ad2dd7e (diff)
Allow partial packets in ICMP errors when NATing
An ICMP error may not hold the full packet that triggered the ICMP response. As long as the IP header and the transport header is parsable, we should be able to successfully NAT as that is all that we need to identify the connection. PiperOrigin-RevId: 406966048
-rw-r--r--pkg/tcpip/stack/conntrack.go82
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go270
2 files changed, 200 insertions, 152 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index c489506bb..7fa657001 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -209,17 +209,41 @@ type bucket struct {
tuples tupleList
}
-func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) {
- switch pkt.tuple.id().transProto {
+// A netAndTransHeadersFunc returns the network and transport headers found
+// in an ICMP payload. The transport layer's payload will not be returned.
+//
+// May panic if the packet does not hold the transport header.
+type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
+
+func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv4(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the buffer is smaller than
+ // the total length specified in the IPv4 header.
+ transHdr := icmpPayload[netHdr.HeaderLength():]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv6(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the IP payload is smaller than
+ // the payload length specified in the IPv6 header.
+ transHdr := icmpPayload[header.IPv6MinimumSize:]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
+ switch transProto {
case header.TCPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.TCP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
+ return netHeader, header.TCP(transHeaderBytes), true
}
case header.UDPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.UDP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
+ return netHeader, header.UDP(transHeaderBytes), true
}
}
return nil, nil, false
@@ -246,7 +270,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic("should have dropped packets with IPv4 options")
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok {
return netHdr, transHdr, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -264,7 +288,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
return netHdr, transHdr, true, true
}
}
@@ -283,34 +307,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro
}
}
-func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
- switch transProto {
- case header.TCPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.TCP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
- case header.UDPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.UDP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
+func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }, true
}
return tupleID{}, false
@@ -349,7 +355,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
return tupleID{}, false, false
}
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
return tid, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -370,7 +376,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
}
// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
return tid, true, true
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 7fe3b29d9..b2383576c 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) {
}
func TestNATICMPError(t *testing.T) {
- const srcPort = 1234
- const dstPort = 5432
+ const (
+ srcPort = 1234
+ dstPort = 5432
+ dataSize = 4
+ )
type icmpTypeTest struct {
name string
@@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) {
netProto: ipv4.ProtocolNumber,
host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)
- hdr := buffer.NewPrependable(totalLen)
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original))
if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
}
@@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) {
icmp.SetType(header.ICMPv4Type(icmpType))
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.ICMPv4ProtocolNumber,
utils.Host1IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
@@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.UDPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.UDPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.TCPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.TCPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) {
Src: utils.Host1IPv6Addr.AddressWithPrefix.Address,
Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
}))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
payloadLen,
header.ICMPv6ProtocolNumber,
utils.Host1IPv6Addr.AddressWithPrefix.Address,
@@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.UDPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(udp),
header.UDPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.TCPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(tcp),
header.TCPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) {
},
}
+ trimTests := []struct {
+ name string
+ trimLen int
+ expectNATedICMP bool
+ }{
+ {
+ name: "Trim nothing",
+ trimLen: 0,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data",
+ trimLen: dataSize,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data and transport header",
+ trimLen: dataSize + 1,
+ expectNATedICMP: false,
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, transportType := range test.transportTypes {
t.Run(transportType.name, func(t *testing.T) {
for _, icmpType := range test.icmpTypes {
t.Run(icmpType.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
-
- ep1 := channel.New(1, header.IPv6MinimumMTU, "")
- ep2 := channel.New(1, header.IPv6MinimumMTU, "")
- utils.SetupRouterStack(t, s, ep1, ep2)
-
- ipv6 := test.netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
-
- table := stack.Table{
- Rules: []stack.Rule{
- // Prerouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
+ for _, trimTest := range trimTests {
+ t.Run(trimTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+
+ ep1 := channel.New(1, header.IPv6MinimumMTU, "")
+ ep2 := channel.New(1, header.IPv6MinimumMTU, "")
+ utils.SetupRouterStack(t, s, ep1, ep2)
+
+ ipv6 := test.netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
},
- Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
- },
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Input
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Forward
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Output
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Postrouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
},
- Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
- },
- {
- Target: &stack.AcceptTarget{},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 2,
- stack.Forward: 3,
- stack.Output: 4,
- stack.Postrouting: 5,
- },
- }
+ }
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
- ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(),
- }))
+ buf := transportType.buf
- {
- pkt, ok := ep1.Read()
- if !ok {
- t.Fatal("expected to read a packet on ep1")
- }
- pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
- transportType.checkNATed(t, pktView)
- if t.Failed() {
- t.FailNow()
- }
+ ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: append(buffer.View(nil), buf...).ToVectorisedView(),
+ }))
- ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
- }))
- }
+ {
+ pkt, ok := ep1.Read()
+ if !ok {
+ t.Fatal("expected to read a packet on ep1")
+ }
+ pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
+ transportType.checkNATed(t, pktView)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ pktView = pktView[:len(pktView)-trimTest.trimLen]
+ buf = buf[:len(buf)-trimTest.trimLen]
+
+ ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
+ }))
+ }
- pkt, ok := ep2.Read()
- if ok != icmpType.expectResponse {
- t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse)
- }
- if !icmpType.expectResponse {
- return
+ pkt, ok := ep2.Read()
+ expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP
+ if ok != expectResponse {
+ t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse)
+ }
+ if !expectResponse {
+ return
+ }
+ test.decrementTTL(buf)
+ test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val)
+ })
}
- test.decrementTTL(transportType.buf)
- test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val)
})
}
})