diff options
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 156 |
1 files changed, 137 insertions, 19 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index a3f403855..4a28be585 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -209,19 +209,120 @@ type bucket struct { tuples tupleList } -func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { +func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.ChecksummableTransport, isICMPError bool, ok bool) { switch pkt.TransportProtocolNumber { case header.TCPProtocolNumber: if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { - return tcpHeader, true + return pkt.Network(), tcpHeader, false, true } case header.UDPProtocolNumber: if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { - return udpHeader, true + return pkt.Network(), udpHeader, false, true + } + case header.ICMPv4ProtocolNumber: + h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) + if !ok { + panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize)) + } + + ipv4 := header.IPv4(h) + if ipv4.HeaderLength() > header.IPv4MinimumSize { + // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options. + panic("should have dropped packets with IPv4 options") + } + + switch pkt.tuple.id().transProto { + case header.TCPProtocolNumber: + // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options. + netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize) + if !ok { + return nil, nil, false, false + } + netHeader := header.IPv4(netAndTransHeader) + return netHeader, header.TCP(netHeader.Payload()), true, true + case header.UDPProtocolNumber: + // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options. + netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.UDPMinimumSize) + if !ok { + return nil, nil, false, false + } + netHeader := header.IPv4(netAndTransHeader) + return netHeader, header.UDP(netHeader.Payload()), true, true } } - return nil, false + return nil, nil, false, false +} + +func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID { + return tupleID{ + srcAddr: netHdr.SourceAddress(), + srcPort: transHdr.SourcePort(), + dstAddr: netHdr.DestinationAddress(), + dstPort: transHdr.DestinationPort(), + transProto: transProto, + netProto: netProto, + } +} + +func getTupleIDForPacketInICMPError(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID { + return tupleID{ + srcAddr: netHdr.DestinationAddress(), + srcPort: transHdr.DestinationPort(), + dstAddr: netHdr.SourceAddress(), + dstPort: transHdr.SourcePort(), + transProto: transProto, + netProto: netProto, + } +} + +func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if transHeader := header.TCP(pkt.TransportHeader().View()); len(transHeader) >= header.TCPMinimumSize { + return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true + } + case header.UDPProtocolNumber: + if transHeader := header.UDP(pkt.TransportHeader().View()); len(transHeader) >= header.UDPMinimumSize { + return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true + } + case header.ICMPv4ProtocolNumber: + icmp := header.ICMPv4(pkt.TransportHeader().View()) + if len(icmp) < header.ICMPv4MinimumSize { + return tupleID{}, false, false + } + + switch icmp.Type() { + case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem: + default: + return tupleID{}, false, false + } + + h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) + if !ok { + return tupleID{}, false, false + } + + ipv4 := header.IPv4(h) + if ipv4.HeaderLength() > header.IPv4MinimumSize { + // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options. + return tupleID{}, false, false + } + switch ipv4.TransportProtocol() { + case header.TCPProtocolNumber: + if netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize); ok { + netHdr := header.IPv4(netAndTransHeader) + return getTupleIDForPacketInICMPError(netHdr, header.IPv4ProtocolNumber, header.TCP(netHdr.Payload()), header.TCPProtocolNumber), true, true + } + case header.UDPProtocolNumber: + if netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.UDPMinimumSize); ok { + netHdr := header.IPv4(netAndTransHeader) + return getTupleIDForPacketInICMPError(netHdr, header.IPv4ProtocolNumber, header.UDP(netHdr.Payload()), header.UDPProtocolNumber), true, true + } + } + } + + return tupleID{}, false, false } func (ct *ConnTrack) init() { @@ -231,21 +332,11 @@ func (ct *ConnTrack) init() { } func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { - netHeader := pkt.Network() - transportHeader, ok := getTransportHeader(pkt) + tid, isICMPError, ok := getTupleID(pkt) if !ok { return nil } - tid := tupleID{ - srcAddr: netHeader.SourceAddress(), - srcPort: transportHeader.SourcePort(), - dstAddr: netHeader.DestinationAddress(), - dstPort: transportHeader.DestinationPort(), - transProto: pkt.TransportProtocolNumber, - netProto: pkt.NetworkProtocolNumber, - } - bktID := ct.bucket(tid) ct.mu.RLock() @@ -257,6 +348,11 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { return t } + if isICMPError { + // Do not create a noop entry in response to an ICMP error. + return nil + } + bkt.mu.Lock() defer bkt.mu.Unlock() @@ -407,7 +503,7 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) // // Returns true if the packet can skip the NAT table. func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { - transportHeader, ok := getTransportHeader(pkt) + netHdr, transHdr, isICMPError, ok := getHeaders(pkt) if !ok { return false } @@ -498,9 +594,9 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { } rewritePacket( - pkt.Network(), - transportHeader, - !dnat, + netHdr, + transHdr, + !dnat != isICMPError, fullChecksum, updatePseudoHeader, newPort, @@ -508,6 +604,28 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { ) *natDone = true + + if !isICMPError { + return true + } + + // We performed NAT on (erroneous) packet that triggered an ICMP response, but + // not the ICMP packet itself. + switch pkt.TransportProtocolNumber { + case header.ICMPv4ProtocolNumber: + icmp := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum. + icmp.SetChecksum(0) + icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().AsRange().Checksum())) + + network := header.IPv4(pkt.NetworkHeader().View()) + if dnat { + network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr) + } else { + network.SetSourceAddressWithChecksumUpdate(tid.dstAddr) + } + } + return true } |