diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_state_autogen.go | 6 |
2 files changed, 25 insertions, 38 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index c9a8e72a3..046679f76 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -37,14 +37,6 @@ import ( // Our hash table has 16K buckets. const numBuckets = 1 << 14 -// Direction of the tuple. -type direction int - -const ( - dirOriginal direction = iota - dirReply -) - // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. // @@ -56,8 +48,9 @@ type tuple struct { // conn is the connection tracking entry this tuple belongs to. conn *conn - // direction is the direction of the tuple. - direction direction + // reply is true iff the tuple's direction is opposite that of the first + // packet seen on the connection. + reply bool mu sync.RWMutex `state:"nosave"` // +checklocks:mu @@ -155,7 +148,7 @@ func (cn *conn) timedOut(now time.Time) bool { // // TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. // +checklocks:cn.mu -func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) { +func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) { if pkt.TransportProtocolNumber != header.TCPProtocolNumber { return } @@ -170,13 +163,10 @@ func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) { return } - switch dir { - case dirOriginal: - cn.tcb.UpdateStateOutbound(tcpHeader) - case dirReply: + if reply { cn.tcb.UpdateStateInbound(tcpHeader) - default: - panic(fmt.Sprintf("unhandled dir = %d", dir)) + } else { + cn.tcb.UpdateStateOutbound(tcpHeader) } } @@ -277,8 +267,8 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { // for this new connection. conn := &conn{ ct: ct, - original: tuple{tupleID: tid, direction: dirOriginal}, - reply: tuple{tupleID: tid.reply(), direction: dirReply}, + original: tuple{tupleID: tid}, + reply: tuple{tupleID: tid.reply(), reply: true}, lastUsed: now, } conn.original.conn = conn @@ -458,41 +448,38 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. - dir := pkt.tuple.direction + reply := pkt.tuple.reply tid, performManip := func() (tupleID, bool) { cn.mu.Lock() defer cn.mu.Unlock() var tuple *tuple - switch dir { - case dirOriginal: + if reply { if dnat { - if !cn.destinationManip { + if !cn.sourceManip { return tupleID{}, false } - } else if !cn.sourceManip { + } else if !cn.destinationManip { return tupleID{}, false } - tuple = &cn.reply - case dirReply: + tuple = &cn.original + } else { if dnat { - if !cn.sourceManip { + if !cn.destinationManip { return tupleID{}, false } - } else if !cn.destinationManip { + } else if !cn.sourceManip { return tupleID{}, false } - tuple = &cn.original - default: - panic(fmt.Sprintf("unhandled dir = %d", dir)) + tuple = &cn.reply } // Mark the connection as having been used recently so it isn't reaped. cn.lastUsed = time.Now() // Update connection state. - cn.updateLocked(pkt, dir) + cn.updateLocked(pkt, reply) return tuple.id(), true }() @@ -637,10 +624,10 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t // TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. // +checklocks:b.mu func removeConnFromBucket(b *bucket, tuple *tuple) { - if tuple.direction == dirOriginal { - b.tuples.Remove(&tuple.conn.reply) - } else { + if tuple.reply { b.tuples.Remove(&tuple.conn.original) + } else { + b.tuples.Remove(&tuple.conn.reply) } } diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index 99fc2df69..f1befa422 100644 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -14,7 +14,7 @@ func (t *tuple) StateFields() []string { return []string{ "tupleEntry", "conn", - "direction", + "reply", "tupleID", } } @@ -26,7 +26,7 @@ func (t *tuple) StateSave(stateSinkObject state.Sink) { t.beforeSave() stateSinkObject.Save(0, &t.tupleEntry) stateSinkObject.Save(1, &t.conn) - stateSinkObject.Save(2, &t.direction) + stateSinkObject.Save(2, &t.reply) stateSinkObject.Save(3, &t.tupleID) } @@ -36,7 +36,7 @@ func (t *tuple) afterLoad() {} func (t *tuple) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &t.tupleEntry) stateSourceObject.Load(1, &t.conn) - stateSourceObject.Load(2, &t.direction) + stateSourceObject.Load(2, &t.reply) stateSourceObject.Load(3, &t.tupleID) } |