diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 354 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 50 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_state_autogen.go | 58 |
6 files changed, 249 insertions, 299 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 30545f634..16d295271 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -64,13 +64,21 @@ type tuple struct { // tupleEntry is used to build an intrusive list of tuples. tupleEntry - tupleID - // conn is the connection tracking entry this tuple belongs to. conn *conn // direction is the direction of the tuple. direction direction + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + tupleID tupleID +} + +func (t *tuple) id() tupleID { + t.mu.RLock() + defer t.mu.RUnlock() + return t.tupleID } // tupleID uniquely identifies a connection in one direction. It currently @@ -103,17 +111,23 @@ func (ti tupleID) reply() tupleID { // // +stateify savable type conn struct { + ct *ConnTrack + // original is the tuple in original direction. It is immutable. original tuple - // reply is the tuple in reply direction. It is immutable. + // reply is the tuple in reply direction. reply tuple - // manip indicates if the packet should be manipulated. It is immutable. - // TODO(gvisor.dev/issue/5696): Support updating manipulation type. - manip manipType - mu sync.RWMutex `state:"nosave"` + // Indicates that the connection has been finalized and may handle replies. + // + // +checklocks:mu + finalized bool + // manip indicates if the packet should be manipulated. + // + // +checklocks:mu + manip manipType // tcb is TCB control block. It is used to keep track of states // of tcp connection. // @@ -128,17 +142,6 @@ type conn struct { lastUsed time.Time `state:".(unixTime)"` } -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType) *conn { - conn := conn{ - manip: manip, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour @@ -235,168 +238,180 @@ func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) return nil, false } -// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid -// TCP header. -// -// Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { +func (ct *ConnTrack) init() { + ct.mu.Lock() + defer ct.mu.Unlock() + ct.buckets = make([]bucket, numBuckets) +} + +func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { netHeader := pkt.Network() transportHeader, ok := getTransportHeader(pkt) if !ok { - return tupleID{}, &tcpip.ErrUnknownProtocol{} + return nil } - return tupleID{ + tid := tupleID{ srcAddr: netHeader.SourceAddress(), srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), dstPort: transportHeader.DestinationPort(), transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, - }, nil -} + } -func (ct *ConnTrack) init() { - ct.mu.Lock() - defer ct.mu.Unlock() - ct.buckets = make([]bucket, numBuckets) -} + bktID := ct.bucket(tid) + + ct.mu.RLock() + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + now := time.Now() + if t := bkt.connForTID(tid, now); t != nil { + return t + } -// connFor gets the conn for pkt if it exists, or returns nil -// if it does not. It returns an error when pkt does not contain a valid TCP -// header. -// TODO(gvisor.dev/issue/6168): Support UDP. -func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil, dirOriginal + bkt.mu.Lock() + defer bkt.mu.Unlock() + + // Make sure a connection wasn't added between when we last checked the + // bucket and acquired the bucket's write lock. + if t := bkt.connForTIDRLocked(tid, now); t != nil { + return t + } + + // This is the first packet we're seeing for the connection. Create an entry + // for this new connection. + conn := &conn{ + ct: ct, + original: tuple{tupleID: tid, direction: dirOriginal}, + reply: tuple{tupleID: tid.reply(), direction: dirReply}, + manip: manipNone, + lastUsed: now, } - return ct.connForTID(tid) + conn.original.conn = conn + conn.reply.conn = conn + + // For now, we only map an entry for the packet's original tuple as NAT may be + // performed on this connection. Until the packet goes through all the hooks + // and its final address/port is known, we cannot know what the response + // packet's addresses/ports will look like. + // + // This is okay because the destination cannot send its response until it + // receives the packet; the packet will only be received once all the hooks + // have been performed. + // + // See (*conn).finalize. + bkt.tuples.PushFront(&conn.original) + return &conn.original } -func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { +func (ct *ConnTrack) connForTID(tid tupleID) *tuple { bktID := ct.bucket(tid) - now := time.Now() ct.mu.RLock() bkt := &ct.buckets[bktID] ct.mu.RUnlock() + return bkt.connForTID(tid, time.Now()) +} + +func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple { bkt.mu.RLock() defer bkt.mu.RUnlock() + return bkt.connForTIDRLocked(tid, now) +} + +// +checklocks:bkt.mu +func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple { for other := bkt.tuples.Front(); other != nil; other = other.Next() { - if tid == other.tupleID && !other.conn.timedOut(now) { - return other.conn, other.direction + if tid == other.id() && !other.conn.timedOut(now) { + return other } } - - return nil, dirOriginal + return nil } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil - } - if hook != Prerouting && hook != Output { - return nil - } +func (ct *ConnTrack) finalize(cn *conn) { + tid := cn.reply.id() + id := ct.bucket(tid) - replyTID := tid.reply() - replyTID.srcAddr = address - replyTID.srcPort = port + ct.mu.RLock() + bkt := &ct.buckets[id] + ct.mu.RUnlock() - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil + bkt.mu.Lock() + defer bkt.mu.Unlock() + + if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil { + // Another connection for the reply already exists. We can't do much about + // this so we leave the connection cn represents in a state where it can + // send packets but its responses will be mapped to some other connection. + // This may be okay if the connection only expects to send packets without + // any responses. + return } - conn = newConn(tid, replyTID, manipDestination) - ct.insertConn(conn) - return conn + + bkt.tuples.PushFront(&cn.reply) } -func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil - } - if hook != Input && hook != Postrouting { - return nil +func (cn *conn) finalize() { + { + cn.mu.RLock() + finalized := cn.finalized + cn.mu.RUnlock() + if finalized { + return + } } - replyTID := tid.reply() - replyTID.dstAddr = address - replyTID.dstPort = port - - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil + cn.mu.Lock() + finalized := cn.finalized + cn.finalized = true + cn.mu.Unlock() + if finalized { + return } - conn = newConn(tid, replyTID, manipSource) - ct.insertConn(conn) - return conn + + cn.ct.finalize(cn) } -// insertConn inserts conn into the appropriate table bucket. -func (ct *ConnTrack) insertConn(conn *conn) { - tupleBktID := ct.bucket(conn.original.tupleID) - replyBktID := ct.bucket(conn.reply.tupleID) +// performNAT setups up the connection for the specified NAT. +// +// Generally, only the first packet of a connection reaches this method; other +// other packets will be manipulated without needing to modify the connection. +func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) { + cn.performNATIfNoop(port, address, dnat) + cn.handlePacket(pkt, hook, r) +} - ct.mu.RLock() - defer ct.mu.RUnlock() +func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) { + cn.mu.Lock() + defer cn.mu.Unlock() - tupleBkt := &ct.buckets[tupleBktID] - if tupleBktID == replyBktID { - // Both tuples are in the same bucket. - tupleBkt.mu.Lock() - defer tupleBkt.mu.Unlock() - insertConn(tupleBkt, tupleBkt, conn) + if cn.finalized { return } - // Lock the buckets in the correct order. - replyBkt := &ct.buckets[replyBktID] - if tupleBktID < replyBktID { - tupleBkt.mu.Lock() - defer tupleBkt.mu.Unlock() - replyBkt.mu.Lock() - defer replyBkt.mu.Unlock() - } else { - replyBkt.mu.Lock() - defer replyBkt.mu.Unlock() - tupleBkt.mu.Lock() - defer tupleBkt.mu.Unlock() + if cn.manip != manipNone { + return } - insertConn(tupleBkt, replyBkt, conn) -} -// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. -// +checklocks:tupleBkt.mu -// +checklocks:replyBkt.mu -func insertConn(tupleBkt *bucket, replyBkt *bucket, conn *conn) { - // Now that we hold the locks, ensure the tuple hasn't been inserted by - // another thread. - // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? - alreadyInserted := false - for other := tupleBkt.tuples.Front(); other != nil; other = other.Next() { - if other.tupleID == conn.original.tupleID { - alreadyInserted = true - break - } - } + cn.reply.mu.Lock() + defer cn.reply.mu.Unlock() - if !alreadyInserted { - // Add the tuple to the map. - tupleBkt.tuples.PushFront(&conn.original) - replyBkt.tuples.PushFront(&conn.reply) + if dnat { + cn.reply.tupleID.srcAddr = address + cn.reply.tupleID.srcPort = port + cn.manip = manipDestination + } else { + cn.reply.tupleID.dstAddr = address + cn.reply.tupleID.dstPort = port + cn.manip = manipSource } } -func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Route) { +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { if pkt.NatDone { return } @@ -417,26 +432,35 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Rou updateSRCFields := false + dir := pkt.tuple.direction + + cn.mu.Lock() + defer cn.mu.Unlock() + switch hook { case Prerouting, Output: if cn.manip == manipDestination && dir == dirOriginal { - newPort = cn.reply.srcPort - newAddr = cn.reply.srcAddr + id := cn.reply.id() + newPort = id.srcPort + newAddr = id.srcAddr pkt.NatDone = true } else if cn.manip == manipSource && dir == dirReply { - newPort = cn.original.srcPort - newAddr = cn.original.srcAddr + id := cn.original.id() + newPort = id.srcPort + newAddr = id.srcAddr pkt.NatDone = true } case Input, Postrouting: if cn.manip == manipSource && dir == dirOriginal { - newPort = cn.reply.dstPort - newAddr = cn.reply.dstAddr + id := cn.reply.id() + newPort = id.dstPort + newAddr = id.dstAddr updateSRCFields = true pkt.NatDone = true } else if cn.manip == manipDestination && dir == dirReply { - newPort = cn.original.dstPort - newAddr = cn.original.dstAddr + id := cn.original.id() + newPort = id.dstPort + newAddr = id.dstAddr updateSRCFields = true pkt.NatDone = true } @@ -479,51 +503,12 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Rou newAddr, ) - // Update the state of tcb. - cn.mu.Lock() - defer cn.mu.Unlock() - // Mark the connection as having been used recently so it isn't reaped. cn.lastUsed = time.Now() // Update connection state. cn.updateLocked(pkt, dir) } -// maybeInsertNoop tries to insert a no-op connection entry to keep connections -// from getting clobbered when replies arrive. It only inserts if there isn't -// already a connection for pkt. -// -// This should be called after traversing iptables rules only, to ensure that -// pkt.NatDone is set correctly. -func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer) { - // If there were a rule applying to this packet, it would be marked - // with NatDone. - if pkt.NatDone { - return - } - - switch pkt.TransportProtocolNumber { - case header.TCPProtocolNumber, header.UDPProtocolNumber: - default: - // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable - // connections. - return - } - - // This is the first packet we're seeing for the TCP connection. Insert - // the noop entry (an identity mapping) so that the response doesn't - // get NATed, breaking the connection. - tid, err := packetToTupleID(pkt) - if err != nil { - return - } - conn := newConn(tid, tid.reply(), manipNone) - ct.insertConn(conn) - conn.mu.Lock() - defer conn.mu.Unlock() - conn.updateLocked(pkt, dirOriginal) -} - // bucket gets the conntrack bucket for a tupleID. func (ct *ConnTrack) bucket(id tupleID) int { h := jenkins.Sum32(ct.seed) @@ -617,7 +602,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t // To maintain lock order, we can only reap these tuples if the reply // appears later in the table. - replyBktID := ct.bucket(tuple.reply()) + replyBktID := ct.bucket(tuple.id().reply()) if bktID > replyBktID { return true } @@ -658,14 +643,19 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ transProto: transProto, netProto: netProto, } - conn, _ := ct.connForTID(tid) - if conn == nil { + t := ct.connForTID(tid) + if t == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip != manipDestination { + } + + t.conn.mu.RLock() + defer t.conn.mu.RUnlock() + if t.conn.manip != manipDestination { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } - return conn.original.dstAddr, conn.original.dstPort, nil + id := t.conn.original.id() + return id.dstAddr, id.dstPort, nil } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 1021e484a..5808be685 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -277,8 +277,9 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp return true } - if conn, dir := it.connections.connFor(pkt); conn != nil { - conn.handlePacket(pkt, hook, dir, nil /* route */) + if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { + pkt.tuple = t + t.conn.handlePacket(pkt, hook, nil /* route */) } return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) @@ -297,20 +298,16 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { return true } - shouldTrack := true - if conn, dir := it.connections.connFor(pkt); conn != nil { - conn.handlePacket(pkt, hook, dir, nil /* route */) - shouldTrack = false + if t := pkt.tuple; t != nil { + t.conn.handlePacket(pkt, hook, nil /* route */) } - if !it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) { - return false + ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) + if t := pkt.tuple; t != nil { + t.conn.finalize() } - - // This is the last hook a packet will perform so if the packet's - // connection is not tracked, we may need to add a no-op entry. - it.maybeinsertNoopConn(pkt, hook, shouldTrack) - return true + pkt.tuple = nil + return ret } // CheckForward performs the forward hook on the packet. @@ -323,7 +320,6 @@ func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string if it.shouldSkip(pkt.NetworkProtocolNumber) { return true } - return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) } @@ -340,8 +336,9 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) return true } - if conn, dir := it.connections.connFor(pkt); conn != nil { - conn.handlePacket(pkt, hook, dir, r) + if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { + pkt.tuple = t + t.conn.handlePacket(pkt, hook, r) } return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) @@ -360,20 +357,16 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP Addr return true } - shouldTrack := true - if conn, dir := it.connections.connFor(pkt); conn != nil { - conn.handlePacket(pkt, hook, dir, r) - shouldTrack = false + if t := pkt.tuple; t != nil { + t.conn.handlePacket(pkt, hook, r) } - if !it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) { - return false + ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) + if t := pkt.tuple; t != nil { + t.conn.finalize() } - - // This is the last hook a packet will perform so if the packet's - // connection is not tracked, we may need to add a no-op entry. - it.maybeinsertNoopConn(pkt, hook, shouldTrack) - return true + pkt.tuple = nil + return ret } func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool { @@ -426,7 +419,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v { + switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -445,22 +438,6 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr return true } -func (it *IPTables) maybeinsertNoopConn(pkt *PacketBuffer, hook Hook, shouldTrack bool) { - // If this connection should be tracked, try to add an entry for it. If - // traversing the nat table didn't end in adding an entry, - // maybeInsertNoop will add a no-op entry for the connection. This is - // needeed when establishing connections so that the SYN/ACK reply to an - // outgoing SYN is delivered to the correct endpoint rather than being - // redirected by a prerouting rule. - // - // From the iptables documentation: "If there is no rule, a `null' - // binding is created: this usually does not map the packet, but exists - // to ensure we don't map another stream over an existing one." - if shouldTrack { - it.connections.maybeInsertNoop(pkt) - } -} - // beforeSave is invoked by stateify. func (it *IPTables) beforeSave() { // Ensure the reaper exits cleanly. @@ -606,7 +583,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, addressEP) + return rule.Target.Action(pkt, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 949c44c9b..8b74677d0 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } @@ -97,7 +97,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -154,15 +154,8 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r pkt.NatDone = true case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } - - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - conn.handlePacket(pkt, hook, dirOriginal, r) + if t := pkt.tuple; t != nil { + t.conn.performNAT(pkt, hook, r, rt.Port, address, true /* dnat */) } default: return RuleDrop, 0 @@ -181,7 +174,7 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) { +func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -197,30 +190,21 @@ func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint if port == 0 { switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - if port == 0 { - port = header.UDP(pkt.TransportHeader().View()).SourcePort() - } + port = header.UDP(pkt.TransportHeader().View()).SourcePort() case header.TCPProtocolNumber: - if port == 0 { - port = header.TCP(pkt.TransportHeader().View()).SourcePort() - } + port = header.TCP(pkt.TransportHeader().View()).SourcePort() } } - // Set up conection for matching NAT rule. Only the first packet of the - // connection comes here. Other packets will be manipulated in connection - // tracking. - // - // Does nothing if the protocol does not support connection tracking. - if conn := ct.insertSNATConn(pkt, hook, port, address); conn != nil { - conn.handlePacket(pkt, hook, dirOriginal, r) + if t := pkt.tuple; t != nil { + t.conn.performNAT(pkt, hook, r, port, address, false /* dnat */) } return RuleAccept, 0 } // Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { +func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if st.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -236,7 +220,7 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - return snatAction(pkt, ct, hook, r, st.Port, st.Addr) + return snatAction(pkt, hook, r, st.Port, st.Addr) } // MasqueradeTarget modifies the source port/IP in the outgoing packets. @@ -247,7 +231,7 @@ type MasqueradeTarget struct { } // Action implements Target.Action. -func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if mt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -272,7 +256,7 @@ func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, address := ep.AddressWithPrefix().Address ep.DecRef() - return snatAction(pkt, ct, hook, r, 0 /* port */, address) + return snatAction(pkt, hook, r, 0 /* port */, address) } func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 50f73f173..b22024667 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -356,5 +356,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) + Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 456b0cf80..888a8bd9d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -143,6 +143,8 @@ type PacketBuffer struct { // NetworkPacketInfo holds an incoming packet's network-layer information. NetworkPacketInfo NetworkPacketInfo + + tuple *tuple } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -302,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NICID: pk.NICID, RXTransportChecksumValidated: pk.RXTransportChecksumValidated, NetworkPacketInfo: pk.NetworkPacketInfo, + tuple: pk.tuple, } } @@ -329,13 +332,8 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), + tuple: pk.tuple, } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - newPk.NatDone = pk.NatDone return newPk } @@ -367,12 +365,7 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu newPk.TransportProtocolNumber = pk.TransportProtocolNumber } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - newPk.NatDone = pk.NatDone + newPk.tuple = pk.tuple return newPk } diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index 6fa6b3a7b..dec8287f9 100644 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -13,9 +13,9 @@ func (t *tuple) StateTypeName() string { func (t *tuple) StateFields() []string { return []string{ "tupleEntry", - "tupleID", "conn", "direction", + "tupleID", } } @@ -25,9 +25,9 @@ func (t *tuple) beforeSave() {} func (t *tuple) StateSave(stateSinkObject state.Sink) { t.beforeSave() stateSinkObject.Save(0, &t.tupleEntry) - stateSinkObject.Save(1, &t.tupleID) - stateSinkObject.Save(2, &t.conn) - stateSinkObject.Save(3, &t.direction) + stateSinkObject.Save(1, &t.conn) + stateSinkObject.Save(2, &t.direction) + stateSinkObject.Save(3, &t.tupleID) } func (t *tuple) afterLoad() {} @@ -35,9 +35,9 @@ func (t *tuple) afterLoad() {} // +checklocksignore func (t *tuple) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &t.tupleEntry) - stateSourceObject.Load(1, &t.tupleID) - stateSourceObject.Load(2, &t.conn) - stateSourceObject.Load(3, &t.direction) + stateSourceObject.Load(1, &t.conn) + stateSourceObject.Load(2, &t.direction) + stateSourceObject.Load(3, &t.tupleID) } func (ti *tupleID) StateTypeName() string { @@ -86,8 +86,10 @@ func (cn *conn) StateTypeName() string { func (cn *conn) StateFields() []string { return []string{ + "ct", "original", "reply", + "finalized", "manip", "tcb", "lastUsed", @@ -101,22 +103,26 @@ func (cn *conn) StateSave(stateSinkObject state.Sink) { cn.beforeSave() var lastUsedValue unixTime lastUsedValue = cn.saveLastUsed() - stateSinkObject.SaveValue(4, lastUsedValue) - stateSinkObject.Save(0, &cn.original) - stateSinkObject.Save(1, &cn.reply) - stateSinkObject.Save(2, &cn.manip) - stateSinkObject.Save(3, &cn.tcb) + stateSinkObject.SaveValue(6, lastUsedValue) + stateSinkObject.Save(0, &cn.ct) + stateSinkObject.Save(1, &cn.original) + stateSinkObject.Save(2, &cn.reply) + stateSinkObject.Save(3, &cn.finalized) + stateSinkObject.Save(4, &cn.manip) + stateSinkObject.Save(5, &cn.tcb) } func (cn *conn) afterLoad() {} // +checklocksignore func (cn *conn) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &cn.original) - stateSourceObject.Load(1, &cn.reply) - stateSourceObject.Load(2, &cn.manip) - stateSourceObject.Load(3, &cn.tcb) - stateSourceObject.LoadValue(4, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) + stateSourceObject.Load(0, &cn.ct) + stateSourceObject.Load(1, &cn.original) + stateSourceObject.Load(2, &cn.reply) + stateSourceObject.Load(3, &cn.finalized) + stateSourceObject.Load(4, &cn.manip) + stateSourceObject.Load(5, &cn.tcb) + stateSourceObject.LoadValue(6, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) } func (ct *ConnTrack) StateTypeName() string { @@ -145,29 +151,29 @@ func (ct *ConnTrack) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(1, &ct.buckets) } -func (b *bucket) StateTypeName() string { +func (bkt *bucket) StateTypeName() string { return "pkg/tcpip/stack.bucket" } -func (b *bucket) StateFields() []string { +func (bkt *bucket) StateFields() []string { return []string{ "tuples", } } -func (b *bucket) beforeSave() {} +func (bkt *bucket) beforeSave() {} // +checklocksignore -func (b *bucket) StateSave(stateSinkObject state.Sink) { - b.beforeSave() - stateSinkObject.Save(0, &b.tuples) +func (bkt *bucket) StateSave(stateSinkObject state.Sink) { + bkt.beforeSave() + stateSinkObject.Save(0, &bkt.tuples) } -func (b *bucket) afterLoad() {} +func (bkt *bucket) afterLoad() {} // +checklocksignore -func (b *bucket) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &b.tuples) +func (bkt *bucket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &bkt.tuples) } func (u *unixTime) StateTypeName() string { |