From 08f1d96168ace77ff105da76f384aa0997a21e2f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 12 Oct 2021 15:29:57 -0700 Subject: Separate DNAT and SNAT manip states This change also refactors the conntrack packet handling code to not perform the actual rewriting of the packet while holding the lock. This change prepares for a followup CL that adds support for twice-NAT. Updates #5696. PiperOrigin-RevId: 402671685 --- pkg/tcpip/stack/conntrack.go | 162 ++++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 79 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 16d295271..48f290187 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -45,17 +45,6 @@ const ( dirReply ) -// Manipulation type for the connection. -// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and -// DNAT at the same time. -type manipType int - -const ( - manipNone manipType = iota - manipSource - manipDestination -) - // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. // @@ -124,10 +113,14 @@ type conn struct { // // +checklocks:mu finalized bool - // manip indicates if the packet should be manipulated. + // sourceManip indicates the packet's source is manipulated. // // +checklocks:mu - manip manipType + sourceManip bool + // destinationManip indicates the packet's destination is manipulated. + // + // +checklocks:mu + destinationManip bool // tcb is TCB control block. It is used to keep track of states // of tcp connection. // @@ -286,7 +279,6 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { ct: ct, original: tuple{tupleID: tid, direction: dirOriginal}, reply: tuple{tupleID: tid.reply(), direction: dirReply}, - manip: manipNone, lastUsed: now, } conn.original.conn = conn @@ -393,8 +385,16 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) return } - if cn.manip != manipNone { - return + if dnat { + if cn.destinationManip { + return + } + cn.destinationManip = true + } else { + if cn.sourceManip { + return + } + cn.sourceManip = true } cn.reply.mu.Lock() @@ -403,11 +403,9 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) if dnat { cn.reply.tupleID.srcAddr = address cn.reply.tupleID.srcPort = port - cn.manip = manipDestination } else { cn.reply.tupleID.dstAddr = address cn.reply.tupleID.dstPort = port - cn.manip = manipSource } } @@ -421,68 +419,24 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { return } - netHeader := pkt.Network() - - // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be - // validated if checksum offloading is off. It may require IP defrag if the - // packets are fragmented. - - var newAddr tcpip.Address - var newPort uint16 - - updateSRCFields := false - - dir := pkt.tuple.direction - - cn.mu.Lock() - defer cn.mu.Unlock() - - switch hook { - case Prerouting, Output: - if cn.manip == manipDestination && dir == dirOriginal { - id := cn.reply.id() - newPort = id.srcPort - newAddr = id.srcAddr - pkt.NatDone = true - } else if cn.manip == manipSource && dir == dirReply { - id := cn.original.id() - newPort = id.srcPort - newAddr = id.srcAddr - pkt.NatDone = true - } - case Input, Postrouting: - if cn.manip == manipSource && dir == dirOriginal { - id := cn.reply.id() - newPort = id.dstPort - newAddr = id.dstAddr - updateSRCFields = true - pkt.NatDone = true - } else if cn.manip == manipDestination && dir == dirReply { - id := cn.original.id() - newPort = id.dstPort - newAddr = id.dstAddr - updateSRCFields = true - pkt.NatDone = true - } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) - } - - if !pkt.NatDone { - return - } - fullChecksum := false updatePseudoHeader := false + dnat := false switch hook { case Prerouting: // Packet came from outside the stack so it must have a checksum set // already. fullChecksum = true updatePseudoHeader = true + + dnat = true case Input: - case Output, Postrouting: - // Calculate the TCP checksum and set it. + case Forward: + panic("should not handle packet in the forwarding hook") + case Output: + dnat = true + fallthrough + case Postrouting: if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { @@ -490,23 +444,73 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { updatePseudoHeader = true } default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) + panic(fmt.Sprintf("unrecognized hook = %d", hook)) + } + + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. + + dir := pkt.tuple.direction + tid, performManip := func() (tupleID, bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + var tuple *tuple + switch dir { + case dirOriginal: + if dnat { + if !cn.destinationManip { + return tupleID{}, false + } + } else if !cn.sourceManip { + return tupleID{}, false + } + + tuple = &cn.reply + case dirReply: + if dnat { + if !cn.sourceManip { + return tupleID{}, false + } + } else if !cn.destinationManip { + return tupleID{}, false + } + + tuple = &cn.original + default: + panic(fmt.Sprintf("unhandled dir = %d", dir)) + } + + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = time.Now() + // Update connection state. + cn.updateLocked(pkt, dir) + + return tuple.id(), true + }() + if !performManip { + return + } + + newPort := tid.dstPort + newAddr := tid.dstAddr + if dnat { + newPort = tid.srcPort + newAddr = tid.srcAddr } rewritePacket( - netHeader, + pkt.Network(), transportHeader, - updateSRCFields, + !dnat, fullChecksum, updatePseudoHeader, newPort, newAddr, ) - // Mark the connection as having been used recently so it isn't reaped. - cn.lastUsed = time.Now() - // Update connection state. - cn.updateLocked(pkt, dir) + pkt.NatDone = true } // bucket gets the conntrack bucket for a tupleID. @@ -651,7 +655,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ t.conn.mu.RLock() defer t.conn.mu.RUnlock() - if t.conn.manip != manipDestination { + if !t.conn.destinationManip { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } -- cgit v1.2.3