diff options
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 4 |
2 files changed, 36 insertions, 33 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 4fb7e9adb..b7cb54b1d 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -388,28 +388,33 @@ func (ct *ConnTrack) insertConn(conn *conn) { // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false - } - switch hook { case Prerouting, Input, Output, Postrouting: default: return false } - transportHeader, ok := getTransportHeader(pkt) - if !ok { + if conn, dir := ct.connFor(pkt); conn != nil { + conn.handlePacket(pkt, hook, dir, r) return false } - conn, dir := ct.connFor(pkt) // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting + // + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting +} + +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Route) { + if pkt.NatDone { + return + } + + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return } netHeader := pkt.Network() @@ -425,24 +430,24 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination && dir == dirOriginal { - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr + if cn.manip == manipDestination && dir == dirOriginal { + newPort = cn.reply.srcPort + newAddr = cn.reply.srcAddr pkt.NatDone = true - } else if conn.manip == manipSource && dir == dirReply { - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr + } else if cn.manip == manipSource && dir == dirReply { + newPort = cn.original.srcPort + newAddr = cn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource && dir == dirOriginal { - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr + if cn.manip == manipSource && dir == dirOriginal { + newPort = cn.reply.dstPort + newAddr = cn.reply.dstAddr updateSRCFields = true pkt.NatDone = true - } else if conn.manip == manipDestination && dir == dirReply { - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr + } else if cn.manip == manipDestination && dir == dirReply { + newPort = cn.original.dstPort + newAddr = cn.original.dstAddr updateSRCFields = true pkt.NatDone = true } @@ -451,7 +456,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } if !pkt.NatDone { - return false + return } fullChecksum := false @@ -486,15 +491,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { ) // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() + cn.mu.Lock() + defer cn.mu.Unlock() // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() + cn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(pkt, hook) - - return false + cn.updateLocked(pkt, hook) } // maybeInsertNoop tries to insert a no-op connection entry to keep connections diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8310645bf..949c44c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -162,7 +162,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // packet of the connection comes here. Other packets will be // manipulated in connection tracking. if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) + conn.handlePacket(pkt, hook, dirOriginal, r) } default: return RuleDrop, 0 @@ -213,7 +213,7 @@ func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint // // Does nothing if the protocol does not support connection tracking. if conn := ct.insertSNATConn(pkt, hook, port, address); conn != nil { - ct.handlePacket(pkt, hook, r) + conn.handlePacket(pkt, hook, dirOriginal, r) } return RuleAccept, 0 |