summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/stack/conntrack.go354
-rw-r--r--pkg/tcpip/stack/iptables.go67
-rw-r--r--pkg/tcpip/stack/iptables_targets.go50
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/packet_buffer.go17
-rw-r--r--pkg/tcpip/stack/stack_state_autogen.go58
6 files changed, 249 insertions, 299 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 30545f634..16d295271 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -64,13 +64,21 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
// direction is the direction of the tuple.
direction direction
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
@@ -103,17 +111,23 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
- manip manipType
-
mu sync.RWMutex `state:"nosave"`
+ // Indicates that the connection has been finalized and may handle replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // manip indicates if the packet should be manipulated.
+ //
+ // +checklocks:mu
+ manip manipType
// tcb is TCB control block. It is used to keep track of states
// of tcp connection.
//
@@ -128,17 +142,6 @@ type conn struct {
lastUsed time.Time `state:".(unixTime)"`
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType) *conn {
- conn := conn{
- manip: manip,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
@@ -235,168 +238,180 @@ func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool)
return nil, false
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
+func (ct *ConnTrack) init() {
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+ ct.buckets = make([]bucket, numBuckets)
+}
+
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
netHeader := pkt.Network()
transportHeader, ok := getTransportHeader(pkt)
if !ok {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+ return nil
}
- return tupleID{
+ tid := tupleID{
srcAddr: netHeader.SourceAddress(),
srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
dstPort: transportHeader.DestinationPort(),
transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
- }, nil
-}
+ }
-func (ct *ConnTrack) init() {
- ct.mu.Lock()
- defer ct.mu.Unlock()
- ct.buckets = make([]bucket, numBuckets)
-}
+ bktID := ct.bucket(tid)
+
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := time.Now()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
+ }
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ // Make sure a connection wasn't added between when we last checked the
+ // bucket and acquired the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
+ }
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid, direction: dirOriginal},
+ reply: tuple{tupleID: tid.reply(), direction: dirReply},
+ manip: manipNone,
+ lastUsed: now,
}
- return ct.connForTID(tid)
+ conn.original.conn = conn
+ conn.reply.conn = conn
+
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
bktID := ct.bucket(tid)
- now := time.Now()
ct.mu.RLock()
bkt := &ct.buckets[bktID]
ct.mu.RUnlock()
+ return bkt.connForTID(tid, time.Now())
+}
+
+func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
bkt.mu.RLock()
defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocks:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
for other := bkt.tuples.Front(); other != nil; other = other.Next() {
- if tid == other.tupleID && !other.conn.timedOut(now) {
- return other.conn, other.direction
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
}
}
-
- return nil, dirOriginal
+ return nil
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Prerouting && hook != Output {
- return nil
- }
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil {
+ // Another connection for the reply already exists. We can't do much about
+ // this so we leave the connection cn represents in a state where it can
+ // send packets but its responses will be mapped to some other connection.
+ // This may be okay if the connection only expects to send packets without
+ // any responses.
+ return
}
- conn = newConn(tid, replyTID, manipDestination)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
+ }
}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
-
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- conn = newConn(tid, replyTID, manipSource)
- ct.insertConn(conn)
- return conn
+
+ cn.ct.finalize(cn)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- tupleBktID := ct.bucket(conn.original.tupleID)
- replyBktID := ct.bucket(conn.reply.tupleID)
+// performNAT setups up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method; other
+// other packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
- ct.mu.RLock()
- defer ct.mu.RUnlock()
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
- tupleBkt := &ct.buckets[tupleBktID]
- if tupleBktID == replyBktID {
- // Both tuples are in the same bucket.
- tupleBkt.mu.Lock()
- defer tupleBkt.mu.Unlock()
- insertConn(tupleBkt, tupleBkt, conn)
+ if cn.finalized {
return
}
- // Lock the buckets in the correct order.
- replyBkt := &ct.buckets[replyBktID]
- if tupleBktID < replyBktID {
- tupleBkt.mu.Lock()
- defer tupleBkt.mu.Unlock()
- replyBkt.mu.Lock()
- defer replyBkt.mu.Unlock()
- } else {
- replyBkt.mu.Lock()
- defer replyBkt.mu.Unlock()
- tupleBkt.mu.Lock()
- defer tupleBkt.mu.Unlock()
+ if cn.manip != manipNone {
+ return
}
- insertConn(tupleBkt, replyBkt, conn)
-}
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
-// +checklocks:tupleBkt.mu
-// +checklocks:replyBkt.mu
-func insertConn(tupleBkt *bucket, replyBkt *bucket, conn *conn) {
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := tupleBkt.tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
- }
- }
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
- if !alreadyInserted {
- // Add the tuple to the map.
- tupleBkt.tuples.PushFront(&conn.original)
- replyBkt.tuples.PushFront(&conn.reply)
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ cn.manip = manipDestination
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ cn.manip = manipSource
}
}
-func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Route) {
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
if pkt.NatDone {
return
}
@@ -417,26 +432,35 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Rou
updateSRCFields := false
+ dir := pkt.tuple.direction
+
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
switch hook {
case Prerouting, Output:
if cn.manip == manipDestination && dir == dirOriginal {
- newPort = cn.reply.srcPort
- newAddr = cn.reply.srcAddr
+ id := cn.reply.id()
+ newPort = id.srcPort
+ newAddr = id.srcAddr
pkt.NatDone = true
} else if cn.manip == manipSource && dir == dirReply {
- newPort = cn.original.srcPort
- newAddr = cn.original.srcAddr
+ id := cn.original.id()
+ newPort = id.srcPort
+ newAddr = id.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
if cn.manip == manipSource && dir == dirOriginal {
- newPort = cn.reply.dstPort
- newAddr = cn.reply.dstAddr
+ id := cn.reply.id()
+ newPort = id.dstPort
+ newAddr = id.dstAddr
updateSRCFields = true
pkt.NatDone = true
} else if cn.manip == manipDestination && dir == dirReply {
- newPort = cn.original.dstPort
- newAddr = cn.original.dstAddr
+ id := cn.original.id()
+ newPort = id.dstPort
+ newAddr = id.dstAddr
updateSRCFields = true
pkt.NatDone = true
}
@@ -479,51 +503,12 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Rou
newAddr,
)
- // Update the state of tcb.
- cn.mu.Lock()
- defer cn.mu.Unlock()
-
// Mark the connection as having been used recently so it isn't reaped.
cn.lastUsed = time.Now()
// Update connection state.
cn.updateLocked(pkt, dir)
}
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
-
- switch pkt.TransportProtocolNumber {
- case header.TCPProtocolNumber, header.UDPProtocolNumber:
- default:
- // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
- // connections.
- return
- }
-
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
- }
- conn := newConn(tid, tid.reply(), manipNone)
- ct.insertConn(conn)
- conn.mu.Lock()
- defer conn.mu.Unlock()
- conn.updateLocked(pkt, dirOriginal)
-}
-
// bucket gets the conntrack bucket for a tupleID.
func (ct *ConnTrack) bucket(id tupleID) int {
h := jenkins.Sum32(ct.seed)
@@ -617,7 +602,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t
// To maintain lock order, we can only reap these tuples if the reply
// appears later in the table.
- replyBktID := ct.bucket(tuple.reply())
+ replyBktID := ct.bucket(tuple.id().reply())
if bktID > replyBktID {
return true
}
@@ -658,14 +643,19 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if t.conn.manip != manipDestination {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 1021e484a..5808be685 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -277,8 +277,9 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp
return true
}
- if conn, dir := it.connections.connFor(pkt); conn != nil {
- conn.handlePacket(pkt, hook, dir, nil /* route */)
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, nil /* route */)
}
return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
@@ -297,20 +298,16 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
return true
}
- shouldTrack := true
- if conn, dir := it.connections.connFor(pkt); conn != nil {
- conn.handlePacket(pkt, hook, dir, nil /* route */)
- shouldTrack = false
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, nil /* route */)
}
- if !it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) {
- return false
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
}
-
- // This is the last hook a packet will perform so if the packet's
- // connection is not tracked, we may need to add a no-op entry.
- it.maybeinsertNoopConn(pkt, hook, shouldTrack)
- return true
+ pkt.tuple = nil
+ return ret
}
// CheckForward performs the forward hook on the packet.
@@ -323,7 +320,6 @@ func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string
if it.shouldSkip(pkt.NetworkProtocolNumber) {
return true
}
-
return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
}
@@ -340,8 +336,9 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string)
return true
}
- if conn, dir := it.connections.connFor(pkt); conn != nil {
- conn.handlePacket(pkt, hook, dir, r)
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, r)
}
return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
@@ -360,20 +357,16 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP Addr
return true
}
- shouldTrack := true
- if conn, dir := it.connections.connFor(pkt); conn != nil {
- conn.handlePacket(pkt, hook, dir, r)
- shouldTrack = false
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, r)
}
- if !it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) {
- return false
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
}
-
- // This is the last hook a packet will perform so if the packet's
- // connection is not tracked, we may need to add a no-op entry.
- it.maybeinsertNoopConn(pkt, hook, shouldTrack)
- return true
+ pkt.tuple = nil
+ return ret
}
func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
@@ -426,7 +419,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -445,22 +438,6 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
return true
}
-func (it *IPTables) maybeinsertNoopConn(pkt *PacketBuffer, hook Hook, shouldTrack bool) {
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt)
- }
-}
-
// beforeSave is invoked by stateify.
func (it *IPTables) beforeSave() {
// Ensure the reaper exits cleanly.
@@ -606,7 +583,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, addressEP)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 949c44c9b..8b74677d0 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,7 +79,7 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -97,7 +97,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -154,15 +154,8 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
pkt.NatDone = true
case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- conn.handlePacket(pkt, hook, dirOriginal, r)
+ if t := pkt.tuple; t != nil {
+ t.conn.performNAT(pkt, hook, r, rt.Port, address, true /* dnat */)
}
default:
return RuleDrop, 0
@@ -181,7 +174,7 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) {
+func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -197,30 +190,21 @@ func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint
if port == 0 {
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- if port == 0 {
- port = header.UDP(pkt.TransportHeader().View()).SourcePort()
- }
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
case header.TCPProtocolNumber:
- if port == 0 {
- port = header.TCP(pkt.TransportHeader().View()).SourcePort()
- }
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
}
}
- // Set up conection for matching NAT rule. Only the first packet of the
- // connection comes here. Other packets will be manipulated in connection
- // tracking.
- //
- // Does nothing if the protocol does not support connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, port, address); conn != nil {
- conn.handlePacket(pkt, hook, dirOriginal, r)
+ if t := pkt.tuple; t != nil {
+ t.conn.performNAT(pkt, hook, r, port, address, false /* dnat */)
}
return RuleAccept, 0
}
// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -236,7 +220,7 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- return snatAction(pkt, ct, hook, r, st.Port, st.Addr)
+ return snatAction(pkt, hook, r, st.Port, st.Addr)
}
// MasqueradeTarget modifies the source port/IP in the outgoing packets.
@@ -247,7 +231,7 @@ type MasqueradeTarget struct {
}
// Action implements Target.Action.
-func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -272,7 +256,7 @@ func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook,
address := ep.AddressWithPrefix().Address
ep.DecRef()
- return snatAction(pkt, ct, hook, r, 0 /* port */, address)
+ return snatAction(pkt, hook, r, 0 /* port */, address)
}
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 50f73f173..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -356,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 456b0cf80..888a8bd9d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -143,6 +143,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -302,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -329,13 +332,8 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- newPk.NatDone = pk.NatDone
return newPk
}
@@ -367,12 +365,7 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- newPk.NatDone = pk.NatDone
+ newPk.tuple = pk.tuple
return newPk
}
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
index 6fa6b3a7b..dec8287f9 100644
--- a/pkg/tcpip/stack/stack_state_autogen.go
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -13,9 +13,9 @@ func (t *tuple) StateTypeName() string {
func (t *tuple) StateFields() []string {
return []string{
"tupleEntry",
- "tupleID",
"conn",
"direction",
+ "tupleID",
}
}
@@ -25,9 +25,9 @@ func (t *tuple) beforeSave() {}
func (t *tuple) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.tupleEntry)
- stateSinkObject.Save(1, &t.tupleID)
- stateSinkObject.Save(2, &t.conn)
- stateSinkObject.Save(3, &t.direction)
+ stateSinkObject.Save(1, &t.conn)
+ stateSinkObject.Save(2, &t.direction)
+ stateSinkObject.Save(3, &t.tupleID)
}
func (t *tuple) afterLoad() {}
@@ -35,9 +35,9 @@ func (t *tuple) afterLoad() {}
// +checklocksignore
func (t *tuple) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.tupleEntry)
- stateSourceObject.Load(1, &t.tupleID)
- stateSourceObject.Load(2, &t.conn)
- stateSourceObject.Load(3, &t.direction)
+ stateSourceObject.Load(1, &t.conn)
+ stateSourceObject.Load(2, &t.direction)
+ stateSourceObject.Load(3, &t.tupleID)
}
func (ti *tupleID) StateTypeName() string {
@@ -86,8 +86,10 @@ func (cn *conn) StateTypeName() string {
func (cn *conn) StateFields() []string {
return []string{
+ "ct",
"original",
"reply",
+ "finalized",
"manip",
"tcb",
"lastUsed",
@@ -101,22 +103,26 @@ func (cn *conn) StateSave(stateSinkObject state.Sink) {
cn.beforeSave()
var lastUsedValue unixTime
lastUsedValue = cn.saveLastUsed()
- stateSinkObject.SaveValue(4, lastUsedValue)
- stateSinkObject.Save(0, &cn.original)
- stateSinkObject.Save(1, &cn.reply)
- stateSinkObject.Save(2, &cn.manip)
- stateSinkObject.Save(3, &cn.tcb)
+ stateSinkObject.SaveValue(6, lastUsedValue)
+ stateSinkObject.Save(0, &cn.ct)
+ stateSinkObject.Save(1, &cn.original)
+ stateSinkObject.Save(2, &cn.reply)
+ stateSinkObject.Save(3, &cn.finalized)
+ stateSinkObject.Save(4, &cn.manip)
+ stateSinkObject.Save(5, &cn.tcb)
}
func (cn *conn) afterLoad() {}
// +checklocksignore
func (cn *conn) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(0, &cn.original)
- stateSourceObject.Load(1, &cn.reply)
- stateSourceObject.Load(2, &cn.manip)
- stateSourceObject.Load(3, &cn.tcb)
- stateSourceObject.LoadValue(4, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) })
+ stateSourceObject.Load(0, &cn.ct)
+ stateSourceObject.Load(1, &cn.original)
+ stateSourceObject.Load(2, &cn.reply)
+ stateSourceObject.Load(3, &cn.finalized)
+ stateSourceObject.Load(4, &cn.manip)
+ stateSourceObject.Load(5, &cn.tcb)
+ stateSourceObject.LoadValue(6, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) })
}
func (ct *ConnTrack) StateTypeName() string {
@@ -145,29 +151,29 @@ func (ct *ConnTrack) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(1, &ct.buckets)
}
-func (b *bucket) StateTypeName() string {
+func (bkt *bucket) StateTypeName() string {
return "pkg/tcpip/stack.bucket"
}
-func (b *bucket) StateFields() []string {
+func (bkt *bucket) StateFields() []string {
return []string{
"tuples",
}
}
-func (b *bucket) beforeSave() {}
+func (bkt *bucket) beforeSave() {}
// +checklocksignore
-func (b *bucket) StateSave(stateSinkObject state.Sink) {
- b.beforeSave()
- stateSinkObject.Save(0, &b.tuples)
+func (bkt *bucket) StateSave(stateSinkObject state.Sink) {
+ bkt.beforeSave()
+ stateSinkObject.Save(0, &bkt.tuples)
}
-func (b *bucket) afterLoad() {}
+func (bkt *bucket) afterLoad() {}
// +checklocksignore
-func (b *bucket) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(0, &b.tuples)
+func (bkt *bucket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &bkt.tuples)
}
func (u *unixTime) StateTypeName() string {