diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-10-27 13:39:24 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-10-27 13:41:53 -0700 |
commit | 3015c0ac67ef7703899e753121efe326dc0cbecd (patch) | |
tree | 58dc69fb3f6fde39bb4fd8d0c18a79906c6bcfa4 /pkg/tcpip/tests/integration | |
parent | 22a6a37079c69129d10abfbdd6fdfdf7a9d4a68d (diff) |
NAT ICMPv4 errors
...so a NAT-ed connection's socket can handle ICMP errors.
Updates #5916.
PiperOrigin-RevId: 405970089
Diffstat (limited to 'pkg/tcpip/tests/integration')
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 293 |
1 files changed, 293 insertions, 0 deletions
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 957a779bf..9e00a6350 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -1779,3 +1779,296 @@ func TestNAT(t *testing.T) { }) } } + +func TestNATICMPError(t *testing.T) { + const srcPort = 1234 + const dstPort = 5432 + + type icmpTypeTest struct { + name string + val uint8 + expectResponse bool + } + + type transportTypeTest struct { + name string + proto tcpip.TransportProtocolNumber + buf buffer.View + checkNATed func(*testing.T, buffer.View) + } + + ipHdr := func(v buffer.View, totalLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) { + ip := header.IPv4(v) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(transProto), + TTL: 64, + SrcAddr: srcAddr, + DstAddr: dstAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + host1Addr tcpip.Address + icmpError func(*testing.T, buffer.View, uint8) buffer.View + decrementTTL func(buffer.View) + checkNATedError func(*testing.T, buffer.View, buffer.View, uint8) + + transportTypes []transportTypeTest + icmpTypes []icmpTypeTest + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(original)), original); n != len(original) { + t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) + } + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmp.SetType(header.ICMPv4Type(icmpType)) + icmp.SetChecksum(0) + icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) + ipHdr(hdr.Prepend(header.IPv4MinimumSize), + totalLen, + header.ICMPv4ProtocolNumber, + utils.Host1IPv4Addr.AddressWithPrefix.Address, + utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, + ) + return hdr.View() + }, + decrementTTL: func(v buffer.View) { + ip := header.IPv4(v) + ip.SetTTL(ip.TTL() - 1) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + }, + checkNATedError: func(t *testing.T, v buffer.View, original buffer.View, icmpType uint8) { + checker.IPv4(t, v, + checker.SrcAddr(utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(utils.Host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Type(icmpType)), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload(original), + ), + ) + }, + transportTypes: []transportTypeTest{ + { + name: "UDP", + proto: header.UDPProtocolNumber, + buf: func() buffer.View { + totalLen := header.IPv4MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLen) + udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udp.SetSourcePort(srcPort) + udp.SetDestinationPort(dstPort) + udp.SetChecksum(0) + udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + utils.Host2IPv4Addr.AddressWithPrefix.Address, + utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + uint16(len(udp)), + ))) + ipHdr(hdr.Prepend(header.IPv4MinimumSize), + totalLen, + header.UDPProtocolNumber, + utils.Host2IPv4Addr.AddressWithPrefix.Address, + utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + ) + return hdr.View() + }(), + checkNATed: func(t *testing.T, v buffer.View) { + checker.IPv4(t, v, + checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), + checker.UDP( + checker.SrcPort(srcPort), + checker.DstPort(dstPort), + ), + ) + }, + }, + { + name: "TCP", + proto: header.TCPProtocolNumber, + buf: func() buffer.View { + totalLen := header.IPv4MinimumSize + header.TCPMinimumSize + hdr := buffer.NewPrependable(totalLen) + tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcp.SetSourcePort(srcPort) + tcp.SetDestinationPort(dstPort) + tcp.SetDataOffset(header.TCPMinimumSize) + tcp.SetChecksum(0) + tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum( + header.TCPProtocolNumber, + utils.Host2IPv4Addr.AddressWithPrefix.Address, + utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + uint16(len(tcp)), + ))) + ipHdr(hdr.Prepend(header.IPv4MinimumSize), + totalLen, + header.TCPProtocolNumber, + utils.Host2IPv4Addr.AddressWithPrefix.Address, + utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + ) + return hdr.View() + }(), + checkNATed: func(t *testing.T, v buffer.View) { + checker.IPv4(t, v, + checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), + checker.TCP( + checker.SrcPort(srcPort), + checker.DstPort(dstPort), + ), + ) + }, + }, + }, + icmpTypes: []icmpTypeTest{ + { + name: "Destination Unreachable", + val: uint8(header.ICMPv4DstUnreachable), + expectResponse: true, + }, + { + name: "Time Exceeded", + val: uint8(header.ICMPv4TimeExceeded), + expectResponse: true, + }, + { + name: "Parameter Problem", + val: uint8(header.ICMPv4ParamProblem), + expectResponse: true, + }, + { + name: "Echo Request", + val: uint8(header.ICMPv4Echo), + expectResponse: false, + }, + { + name: "Echo Reply", + val: uint8(header.ICMPv4EchoReply), + expectResponse: false, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, transportType := range test.transportTypes { + t.Run(transportType.name, func(t *testing.T) { + for _, icmpType := range test.icmpTypes { + t.Run(icmpType.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + + ep1 := channel.New(1, header.IPv6MinimumMTU, "") + ep2 := channel.New(1, header.IPv6MinimumMTU, "") + utils.SetupRouterStack(t, s, ep1, ep2) + + ipv6 := test.netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, + }, + { + Target: &stack.AcceptTarget{}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, + }, + } + + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + + ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(), + })) + + { + pkt, ok := ep1.Read() + if !ok { + t.Fatal("expected to read a packet on ep1") + } + pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) + transportType.checkNATed(t, pktView) + if t.Failed() { + t.FailNow() + } + + ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), + })) + } + + pkt, ok := ep2.Read() + if ok != icmpType.expectResponse { + t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse) + } + if !icmpType.expectResponse { + return + } + test.decrementTTL(transportType.buf) + test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val) + }) + } + }) + } + }) + } +} |