summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/conntrack.go82
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go270
2 files changed, 200 insertions, 152 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index c489506bb..7fa657001 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -209,17 +209,41 @@ type bucket struct {
tuples tupleList
}
-func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) {
- switch pkt.tuple.id().transProto {
+// A netAndTransHeadersFunc returns the network and transport headers found
+// in an ICMP payload. The transport layer's payload will not be returned.
+//
+// May panic if the packet does not hold the transport header.
+type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
+
+func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv4(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the buffer is smaller than
+ // the total length specified in the IPv4 header.
+ transHdr := icmpPayload[netHdr.HeaderLength():]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv6(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the IP payload is smaller than
+ // the payload length specified in the IPv6 header.
+ transHdr := icmpPayload[header.IPv6MinimumSize:]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
+ switch transProto {
case header.TCPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.TCP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
+ return netHeader, header.TCP(transHeaderBytes), true
}
case header.UDPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.UDP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
+ return netHeader, header.UDP(transHeaderBytes), true
}
}
return nil, nil, false
@@ -246,7 +270,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic("should have dropped packets with IPv4 options")
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok {
return netHdr, transHdr, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -264,7 +288,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
return netHdr, transHdr, true, true
}
}
@@ -283,34 +307,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro
}
}
-func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
- switch transProto {
- case header.TCPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.TCP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
- case header.UDPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.UDP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
+func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }, true
}
return tupleID{}, false
@@ -349,7 +355,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
return tupleID{}, false, false
}
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
return tid, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -370,7 +376,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
}
// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
return tid, true, true
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 7fe3b29d9..b2383576c 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) {
}
func TestNATICMPError(t *testing.T) {
- const srcPort = 1234
- const dstPort = 5432
+ const (
+ srcPort = 1234
+ dstPort = 5432
+ dataSize = 4
+ )
type icmpTypeTest struct {
name string
@@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) {
netProto: ipv4.ProtocolNumber,
host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)
- hdr := buffer.NewPrependable(totalLen)
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original))
if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
}
@@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) {
icmp.SetType(header.ICMPv4Type(icmpType))
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.ICMPv4ProtocolNumber,
utils.Host1IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
@@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.UDPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.UDPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.TCPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.TCPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) {
Src: utils.Host1IPv6Addr.AddressWithPrefix.Address,
Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
}))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
payloadLen,
header.ICMPv6ProtocolNumber,
utils.Host1IPv6Addr.AddressWithPrefix.Address,
@@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.UDPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(udp),
header.UDPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.TCPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(tcp),
header.TCPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) {
},
}
+ trimTests := []struct {
+ name string
+ trimLen int
+ expectNATedICMP bool
+ }{
+ {
+ name: "Trim nothing",
+ trimLen: 0,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data",
+ trimLen: dataSize,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data and transport header",
+ trimLen: dataSize + 1,
+ expectNATedICMP: false,
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, transportType := range test.transportTypes {
t.Run(transportType.name, func(t *testing.T) {
for _, icmpType := range test.icmpTypes {
t.Run(icmpType.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
-
- ep1 := channel.New(1, header.IPv6MinimumMTU, "")
- ep2 := channel.New(1, header.IPv6MinimumMTU, "")
- utils.SetupRouterStack(t, s, ep1, ep2)
-
- ipv6 := test.netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
-
- table := stack.Table{
- Rules: []stack.Rule{
- // Prerouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
+ for _, trimTest := range trimTests {
+ t.Run(trimTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+
+ ep1 := channel.New(1, header.IPv6MinimumMTU, "")
+ ep2 := channel.New(1, header.IPv6MinimumMTU, "")
+ utils.SetupRouterStack(t, s, ep1, ep2)
+
+ ipv6 := test.netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
},
- Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
- },
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Input
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Forward
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Output
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Postrouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
},
- Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
- },
- {
- Target: &stack.AcceptTarget{},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 2,
- stack.Forward: 3,
- stack.Output: 4,
- stack.Postrouting: 5,
- },
- }
+ }
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
- ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(),
- }))
+ buf := transportType.buf
- {
- pkt, ok := ep1.Read()
- if !ok {
- t.Fatal("expected to read a packet on ep1")
- }
- pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
- transportType.checkNATed(t, pktView)
- if t.Failed() {
- t.FailNow()
- }
+ ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: append(buffer.View(nil), buf...).ToVectorisedView(),
+ }))
- ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
- }))
- }
+ {
+ pkt, ok := ep1.Read()
+ if !ok {
+ t.Fatal("expected to read a packet on ep1")
+ }
+ pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
+ transportType.checkNATed(t, pktView)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ pktView = pktView[:len(pktView)-trimTest.trimLen]
+ buf = buf[:len(buf)-trimTest.trimLen]
+
+ ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
+ }))
+ }
- pkt, ok := ep2.Read()
- if ok != icmpType.expectResponse {
- t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse)
- }
- if !icmpType.expectResponse {
- return
+ pkt, ok := ep2.Read()
+ expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP
+ if ok != expectResponse {
+ t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse)
+ }
+ if !expectResponse {
+ return
+ }
+ test.decrementTTL(buf)
+ test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val)
+ })
}
- test.decrementTTL(transportType.buf)
- test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val)
})
}
})