diff options
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 82 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 270 |
2 files changed, 200 insertions, 152 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index c489506bb..7fa657001 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -209,17 +209,41 @@ type bucket struct { tuples tupleList } -func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) { - switch pkt.tuple.id().transProto { +// A netAndTransHeadersFunc returns the network and transport headers found +// in an ICMP payload. The transport layer's payload will not be returned. +// +// May panic if the packet does not hold the transport header. +type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte) + +func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv4(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the buffer is smaller than + // the total length specified in the IPv4 header. + transHdr := icmpPayload[netHdr.HeaderLength():] + return netHdr, transHdr[:minTransHdrLen] +} + +func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv6(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the IP payload is smaller than + // the payload length specified in the IPv6 header. + transHdr := icmpPayload[header.IPv6MinimumSize:] + return netHdr, transHdr[:minTransHdrLen] +} + +func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) { + switch transProto { case header.TCPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.TCP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize) + return netHeader, header.TCP(transHeaderBytes), true } case header.UDPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.UDP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize) + return netHeader, header.UDP(transHeaderBytes), true } } return nil, nil, false @@ -246,7 +270,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic("should have dropped packets with IPv4 options") } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok { return netHdr, transHdr, true, true } case header.ICMPv6ProtocolNumber: @@ -264,7 +288,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto)) } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok { return netHdr, transHdr, true, true } } @@ -283,34 +307,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro } } -func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { - switch transProto { - case header.TCPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.TCP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } - case header.UDPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.UDP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } +func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok { + return tupleID{ + srcAddr: netHdr.DestinationAddress(), + srcPort: transHdr.DestinationPort(), + dstAddr: netHdr.SourceAddress(), + dstPort: transHdr.SourcePort(), + transProto: transProto, + netProto: netProto, + }, true } return tupleID{}, false @@ -349,7 +355,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { return tupleID{}, false, false } - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { return tid, true, true } case header.ICMPv6ProtocolNumber: @@ -370,7 +376,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { } // TODO(https://gvisor.dev/issue/6789): Handle extension headers. - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { return tid, true, true } } diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 7fe3b29d9..b2383576c 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) { } func TestNATICMPError(t *testing.T) { - const srcPort = 1234 - const dstPort = 5432 + const ( + srcPort = 1234 + dstPort = 5432 + dataSize = 4 + ) type icmpTypeTest struct { name string @@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) { netProto: ipv4.ProtocolNumber, host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original) - hdr := buffer.NewPrependable(totalLen) + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)) if n := copy(hdr.Prepend(len(original)), original); n != len(original) { t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) } @@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) { icmp.SetType(header.ICMPv4Type(icmpType)) icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.ICMPv4ProtocolNumber, utils.Host1IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, @@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.UDPMinimumSize - hdr := buffer.NewPrependable(totalLen) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.UDPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.TCPMinimumSize - hdr := buffer.NewPrependable(totalLen) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.TCPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) { Src: utils.Host1IPv6Addr.AddressWithPrefix.Address, Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, })) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), payloadLen, header.ICMPv6ProtocolNumber, utils.Host1IPv6Addr.AddressWithPrefix.Address, @@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.UDPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(udp), header.UDPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.TCPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(tcp), header.TCPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) { }, } + trimTests := []struct { + name string + trimLen int + expectNATedICMP bool + }{ + { + name: "Trim nothing", + trimLen: 0, + expectNATedICMP: true, + }, + { + name: "Trim data", + trimLen: dataSize, + expectNATedICMP: true, + }, + { + name: "Trim data and transport header", + trimLen: dataSize + 1, + expectNATedICMP: false, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { for _, transportType := range test.transportTypes { t.Run(transportType.name, func(t *testing.T) { for _, icmpType := range test.icmpTypes { t.Run(icmpType.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - - ep1 := channel.New(1, header.IPv6MinimumMTU, "") - ep2 := channel.New(1, header.IPv6MinimumMTU, "") - utils.SetupRouterStack(t, s, ep1, ep2) - - ipv6 := test.netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - - table := stack.Table{ - Rules: []stack.Rule{ - // Prerouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, + for _, trimTest := range trimTests { + t.Run(trimTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + + ep1 := channel.New(1, header.IPv6MinimumMTU, "") + ep2 := channel.New(1, header.IPv6MinimumMTU, "") + utils.SetupRouterStack(t, s, ep1, ep2) + + ipv6 := test.netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, + }, + { + Target: &stack.AcceptTarget{}, + }, }, - Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, - }, - { - Target: &stack.AcceptTarget{}, - }, - - // Input - { - Target: &stack.AcceptTarget{}, - }, - - // Forward - { - Target: &stack.AcceptTarget{}, - }, - - // Output - { - Target: &stack.AcceptTarget{}, - }, - - // Postrouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, }, - Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, - }, - { - Target: &stack.AcceptTarget{}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 2, - stack.Forward: 3, - stack.Output: 4, - stack.Postrouting: 5, - }, - } + } - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } - ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(), - })) + buf := transportType.buf - { - pkt, ok := ep1.Read() - if !ok { - t.Fatal("expected to read a packet on ep1") - } - pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) - transportType.checkNATed(t, pktView) - if t.Failed() { - t.FailNow() - } + ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: append(buffer.View(nil), buf...).ToVectorisedView(), + })) - ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), - })) - } + { + pkt, ok := ep1.Read() + if !ok { + t.Fatal("expected to read a packet on ep1") + } + pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) + transportType.checkNATed(t, pktView) + if t.Failed() { + t.FailNow() + } + + pktView = pktView[:len(pktView)-trimTest.trimLen] + buf = buf[:len(buf)-trimTest.trimLen] + + ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), + })) + } - pkt, ok := ep2.Read() - if ok != icmpType.expectResponse { - t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse) - } - if !icmpType.expectResponse { - return + pkt, ok := ep2.Read() + expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP + if ok != expectResponse { + t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse) + } + if !expectResponse { + return + } + test.decrementTTL(buf) + test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val) + }) } - test.decrementTTL(transportType.buf) - test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val) }) } }) |