From 6e83c4b751c60652247d0ebbe559261352b2131f Mon Sep 17 00:00:00 2001
From: Ghanan Gowripalan <ghanan@google.com>
Date: Fri, 1 Oct 2021 13:35:04 -0700
Subject: Drop conn.tcbHook

...as the packet's direction gives us the information that tcbHook is
used to derive.

PiperOrigin-RevId: 400280102
---
 pkg/tcpip/stack/conntrack.go | 31 ++++++++++++++++---------------
 pkg/tcpip/stack/iptables.go  |  2 +-
 2 files changed, 17 insertions(+), 16 deletions(-)

(limited to 'pkg/tcpip/stack')

diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 2145a8496..bd47f734f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -113,10 +113,6 @@ type conn struct {
 	// TODO(gvisor.dev/issue/5696): Support updating manipulation type.
 	manip manipType
 
-	// tcbHook indicates if the packet is inbound or outbound to
-	// update the state of tcb. It is immutable.
-	tcbHook Hook
-
 	mu sync.Mutex `state:"nosave"`
 	// tcb is TCB control block. It is used to keep track of states
 	// of tcp connection.
@@ -133,10 +129,9 @@ type conn struct {
 }
 
 // newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+func newConn(orig, reply tupleID, manip manipType) *conn {
 	conn := conn{
 		manip:    manip,
-		tcbHook:  hook,
 		lastUsed: time.Now(),
 	}
 	conn.original = tuple{conn: &conn, tupleID: orig}
@@ -164,7 +159,7 @@ func (cn *conn) timedOut(now time.Time) bool {
 //
 // TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
 // +checklocks:cn.mu
-func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) {
 	if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
 		return
 	}
@@ -176,10 +171,16 @@ func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
 	// established or not, so the client/server distinction isn't important.
 	if cn.tcb.IsEmpty() {
 		cn.tcb.Init(tcpHeader)
-	} else if hook == cn.tcbHook {
+		return
+	}
+
+	switch dir {
+	case dirOriginal:
 		cn.tcb.UpdateStateOutbound(tcpHeader)
-	} else {
+	case dirReply:
 		cn.tcb.UpdateStateInbound(tcpHeader)
+	default:
+		panic(fmt.Sprintf("unhandled dir = %d", dir))
 	}
 }
 
@@ -318,7 +319,7 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
 		// TODO(gvisor.dev/issue/5696): Support updating an existing connection.
 		return nil
 	}
-	conn = newConn(tid, replyTID, manipDestination, hook)
+	conn = newConn(tid, replyTID, manipDestination)
 	ct.insertConn(conn)
 	return conn
 }
@@ -342,7 +343,7 @@ func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, a
 		// TODO(gvisor.dev/issue/5696): Support updating an existing connection.
 		return nil
 	}
-	conn = newConn(tid, replyTID, manipSource, hook)
+	conn = newConn(tid, replyTID, manipSource)
 	ct.insertConn(conn)
 	return conn
 }
@@ -515,7 +516,7 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Rou
 	// Mark the connection as having been used recently so it isn't reaped.
 	cn.lastUsed = time.Now()
 	// Update connection state.
-	cn.updateLocked(pkt, hook)
+	cn.updateLocked(pkt, dir)
 }
 
 // maybeInsertNoop tries to insert a no-op connection entry to keep connections
@@ -524,7 +525,7 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Rou
 //
 // This should be called after traversing iptables rules only, to ensure that
 // pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer) {
 	// If there were a rule applying to this packet, it would be marked
 	// with NatDone.
 	if pkt.NatDone {
@@ -546,11 +547,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
 	if err != nil {
 		return
 	}
-	conn := newConn(tid, tid.reply(), manipNone, hook)
+	conn := newConn(tid, tid.reply(), manipNone)
 	ct.insertConn(conn)
 	conn.mu.Lock()
 	defer conn.mu.Unlock()
-	conn.updateLocked(pkt, hook)
+	conn.updateLocked(pkt, dirOriginal)
 }
 
 // bucket gets the conntrack bucket for a tupleID.
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index dcba7eba6..a20bef3c5 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -388,7 +388,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
 	// binding is created: this usually does not map the packet, but exists
 	// to ensure we don't map another stream over an existing one."
 	if shouldTrack {
-		it.connections.maybeInsertNoop(pkt, hook)
+		it.connections.maybeInsertNoop(pkt)
 	}
 
 	// Every table returned Accept.
-- 
cgit v1.2.3