summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/conntrack.go589
-rw-r--r--pkg/tcpip/stack/iptables.go132
-rw-r--r--pkg/tcpip/stack/iptables_state.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go192
-rw-r--r--pkg/tcpip/stack/iptables_types.go28
-rw-r--r--pkg/tcpip/stack/packet_buffer.go30
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go38
-rw-r--r--pkg/tcpip/stack/stack_test.go11
-rw-r--r--pkg/tcpip/stack/tcp.go6
9 files changed, 538 insertions, 492 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 4fb7e9adb..48f290187 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -45,17 +45,6 @@ const (
dirReply
)
-// Manipulation type for the connection.
-// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
-// DNAT at the same time.
-type manipType int
-
-const (
- manipNone manipType = iota
- manipSource
- manipDestination
-)
-
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
//
@@ -64,13 +53,21 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
// direction is the direction of the tuple.
direction direction
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
@@ -103,50 +100,47 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
- manip manipType
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb. It is immutable.
- tcbHook Hook
-
- // mu protects all mutable state.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // Indicates that the connection has been finalized and may handle replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // sourceManip indicates the packet's source is manipulated.
+ //
+ // +checklocks:mu
+ sourceManip bool
+ // destinationManip indicates the packet's destination is manipulated.
+ //
+ // +checklocks:mu
+ destinationManip bool
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection and is protected by mu.
+ // of tcp connection.
+ //
+ // +checklocks:mu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
- // is updated by each packet on the connection. It is protected by mu.
+ // is updated by each packet on the connection.
//
// TODO(gvisor.dev/issue/5939): do not use the ambient clock.
+ //
+ // +checklocks:mu
lastUsed time.Time `state:".(unixTime)"`
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
const defaultTimeout = 120 * time.Second
- cn.mu.Lock()
- defer cn.mu.Unlock()
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -159,8 +153,9 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:cn.mu
+func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) {
if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
return
}
@@ -172,10 +167,16 @@ func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
// established or not, so the client/server distinction isn't important.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
- } else if hook == cn.tcbHook {
+ return
+ }
+
+ switch dir {
+ case dirOriginal:
cn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
+ case dirReply:
cn.tcb.UpdateStateInbound(tcpHeader)
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
}
}
@@ -200,18 +201,18 @@ type ConnTrack struct {
// It is immutable.
seed uint32
+ mu sync.RWMutex `state:"nosave"`
// mu protects the buckets slice, but not buckets' contents. Only take
// the write lock if you are modifying the slice or saving for S/R.
- mu sync.RWMutex `state:"nosave"`
-
- // buckets is protected by mu.
+ //
+ // +checklocks:mu
buckets []bucket
}
// +stateify savable
type bucket struct {
- // mu protects tuples.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
tuples tupleList
}
@@ -230,241 +231,212 @@ func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool)
return nil, false
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
+func (ct *ConnTrack) init() {
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+ ct.buckets = make([]bucket, numBuckets)
+}
+
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
netHeader := pkt.Network()
transportHeader, ok := getTransportHeader(pkt)
if !ok {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+ return nil
}
- return tupleID{
+ tid := tupleID{
srcAddr: netHeader.SourceAddress(),
srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
dstPort: transportHeader.DestinationPort(),
transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
- }, nil
-}
-
-func (ct *ConnTrack) init() {
- ct.mu.Lock()
- defer ct.mu.Unlock()
- ct.buckets = make([]bucket, numBuckets)
-}
-
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
}
- return ct.connForTID(tid)
-}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
- bucket := ct.bucket(tid)
- now := time.Now()
+ bktID := ct.bucket(tid)
ct.mu.RLock()
- defer ct.mu.RUnlock()
- ct.buckets[bucket].mu.Lock()
- defer ct.buckets[bucket].mu.Unlock()
-
- // Iterate over the tuples in a bucket, cleaning up any unused
- // connections we find.
- for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
- // Clean up any timed-out connections we happen to find.
- if ct.reapTupleLocked(other, bucket, now) {
- // The tuple expired.
- continue
- }
- if tid == other.tupleID {
- return other.conn, other.direction
- }
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := time.Now()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
}
- return nil, dirOriginal
-}
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
+ // Make sure a connection wasn't added between when we last checked the
+ // bucket and acquired the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
}
- if hook != Prerouting && hook != Output {
- return nil
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid, direction: dirOriginal},
+ reply: tuple{tupleID: tid.reply(), direction: dirReply},
+ lastUsed: now,
}
+ conn.original.conn = conn
+ conn.reply.conn = conn
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
+}
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
- }
- conn = newConn(tid, replyTID, manipDestination, hook)
- ct.insertConn(conn)
- return conn
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
+ bktID := ct.bucket(tid)
+
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ return bkt.connForTID(tid, time.Now())
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocks:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
+ for other := bkt.tuples.Front(); other != nil; other = other.Next() {
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
+ }
}
+ return nil
+}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
+
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil {
+ // Another connection for the reply already exists. We can't do much about
+ // this so we leave the connection cn represents in a state where it can
+ // send packets but its responses will be mapped to some other connection.
+ // This may be okay if the connection only expects to send packets without
+ // any responses.
+ return
}
- conn = newConn(tid, replyTID, manipSource, hook)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- // Lock the buckets in the correct order.
- tupleBucket := ct.bucket(conn.original.tupleID)
- replyBucket := ct.bucket(conn.reply.tupleID)
- ct.mu.RLock()
- defer ct.mu.RUnlock()
- if tupleBucket < replyBucket {
- ct.buckets[tupleBucket].mu.Lock()
- ct.buckets[replyBucket].mu.Lock()
- } else if tupleBucket > replyBucket {
- ct.buckets[replyBucket].mu.Lock()
- ct.buckets[tupleBucket].mu.Lock()
- } else {
- // Both tuples are in the same bucket.
- ct.buckets[tupleBucket].mu.Lock()
- }
-
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
}
}
- if !alreadyInserted {
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- // Unlocking can happen in any order.
- ct.buckets[tupleBucket].mu.Unlock()
- if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
- }
+ cn.ct.finalize(cn)
}
-// handlePacket will manipulate the port and address of the packet if the
-// connection exists. Returns whether, after the packet traverses the tables,
-// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
- }
+// performNAT setups up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method; other
+// other packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
- switch hook {
- case Prerouting, Input, Output, Postrouting:
- default:
- return false
- }
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
- transportHeader, ok := getTransportHeader(pkt)
- if !ok {
- return false
+ if cn.finalized {
+ return
}
- conn, dir := ct.connFor(pkt)
- // Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+ if dnat {
+ if cn.destinationManip {
+ return
+ }
+ cn.destinationManip = true
+ } else {
+ if cn.sourceManip {
+ return
+ }
+ cn.sourceManip = true
}
- netHeader := pkt.Network()
-
- // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
- // validated if checksum offloading is off. It may require IP defrag if the
- // packets are fragmented.
-
- var newAddr tcpip.Address
- var newPort uint16
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
- updateSRCFields := false
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ }
+}
- switch hook {
- case Prerouting, Output:
- if conn.manip == manipDestination && dir == dirOriginal {
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- pkt.NatDone = true
- } else if conn.manip == manipSource && dir == dirReply {
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- pkt.NatDone = true
- }
- case Input, Postrouting:
- if conn.manip == manipSource && dir == dirOriginal {
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
- updateSRCFields = true
- pkt.NatDone = true
- } else if conn.manip == manipDestination && dir == dirReply {
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
- updateSRCFields = true
- pkt.NatDone = true
- }
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
+ if pkt.NatDone {
+ return
}
- if !pkt.NatDone {
- return false
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return
}
fullChecksum := false
updatePseudoHeader := false
+ dnat := false
switch hook {
case Prerouting:
// Packet came from outside the stack so it must have a checksum set
// already.
fullChecksum = true
updatePseudoHeader = true
+
+ dnat = true
case Input:
- case Output, Postrouting:
- // Calculate the TCP checksum and set it.
+ case Forward:
+ panic("should not handle packet in the forwarding hook")
+ case Output:
+ dnat = true
+ fallthrough
+ case Postrouting:
if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
@@ -472,62 +444,73 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
updatePseudoHeader = true
}
default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ panic(fmt.Sprintf("unrecognized hook = %d", hook))
}
- rewritePacket(
- netHeader,
- transportHeader,
- updateSRCFields,
- fullChecksum,
- updatePseudoHeader,
- newPort,
- newAddr,
- )
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
- // Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
+ dir := pkt.tuple.direction
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ var tuple *tuple
+ switch dir {
+ case dirOriginal:
+ if dnat {
+ if !cn.destinationManip {
+ return tupleID{}, false
+ }
+ } else if !cn.sourceManip {
+ return tupleID{}, false
+ }
- // Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
- // Update connection state.
- conn.updateLocked(pkt, hook)
+ tuple = &cn.reply
+ case dirReply:
+ if dnat {
+ if !cn.sourceManip {
+ return tupleID{}, false
+ }
+ } else if !cn.destinationManip {
+ return tupleID{}, false
+ }
- return false
-}
+ tuple = &cn.original
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
+ }
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = time.Now()
+ // Update connection state.
+ cn.updateLocked(pkt, dir)
- switch pkt.TransportProtocolNumber {
- case header.TCPProtocolNumber, header.UDPProtocolNumber:
- default:
- // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
- // connections.
+ return tuple.id(), true
+ }()
+ if !performManip {
return
}
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
+ newPort := tid.dstPort
+ newAddr := tid.dstAddr
+ if dnat {
+ newPort = tid.srcPort
+ newAddr = tid.srcAddr
}
- conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(pkt, hook)
- ct.insertConn(conn)
+
+ rewritePacket(
+ pkt.Network(),
+ transportHeader,
+ !dnat,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
+
+ pkt.NatDone = true
}
// bucket gets the conntrack bucket for a tupleID.
@@ -579,14 +562,15 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
defer ct.mu.RUnlock()
for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
idx = (i + start) % len(ct.buckets)
- ct.buckets[idx].mu.Lock()
- for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ bkt := &ct.buckets[idx]
+ bkt.mu.Lock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
checked++
- if ct.reapTupleLocked(tuple, idx, now) {
+ if ct.reapTupleLocked(tuple, idx, bkt, now) {
expired++
}
}
- ct.buckets[idx].mu.Unlock()
+ bkt.mu.Unlock()
}
// We already checked buckets[idx].
idx++
@@ -611,41 +595,45 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions:
-// * ct.mu is locked for reading.
-// * bucket is locked.
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+// Precondition: ct.mu is read locked and bkt.mu is write locked.
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:ct.mu
+// +checklocks:bkt.mu
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool {
if !tuple.conn.timedOut(now) {
return false
}
// To maintain lock order, we can only reap these tuples if the reply
// appears later in the table.
- replyBucket := ct.bucket(tuple.reply())
- if bucket > replyBucket {
+ replyBktID := ct.bucket(tuple.id().reply())
+ if bktID > replyBktID {
return true
}
// Don't re-lock if both tuples are in the same bucket.
- differentBuckets := bucket != replyBucket
- if differentBuckets {
- ct.buckets[replyBucket].mu.Lock()
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
}
// We have the buckets locked and can remove both tuples.
+ bkt.tuples.Remove(tuple)
+ return true
+}
+
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:b.mu
+func removeConnFromBucket(b *bucket, tuple *tuple) {
if tuple.direction == dirOriginal {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ b.tuples.Remove(&tuple.conn.reply)
} else {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
- }
- ct.buckets[bucket].tuples.Remove(tuple)
-
- // Don't re-unlock if both tuples are in the same bucket.
- if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
+ b.tuples.Remove(&tuple.conn.original)
}
-
- return true
}
func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
@@ -659,14 +647,19 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if !t.conn.destinationManip {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 74c9075b4..5808be685 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -271,7 +271,18 @@ const (
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
- return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+ const hook = Prerouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
}
// CheckInput performs the input hook on the packet.
@@ -281,7 +292,22 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
- return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ const hook = Input
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
}
// CheckForward performs the forward hook on the packet.
@@ -291,6 +317,9 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
}
@@ -301,7 +330,18 @@ func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
- return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+ const hook = Output
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
}
// CheckPostrouting performs the postrouting hook on the packet.
@@ -310,8 +350,38 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string)
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
-func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool {
- return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ const hook = Postrouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ // IPTables only supports IPv4/IPv6.
+ return true
+ }
+
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ return !it.modified
}
// check runs pkt through the rules for hook. It returns true when the packet
@@ -320,20 +390,8 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName str
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
- if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
- return true
- }
- // Many users never configure iptables. Spare them the cost of rule
- // traversal if rules have never been set.
it.mu.RLock()
defer it.mu.RUnlock()
- if !it.modified {
- return true
- }
-
- // Packets are manipulated only if connection and matching
- // NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
@@ -361,7 +419,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -377,21 +435,6 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
}
}
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt, hook)
- }
-
- // Every table returned Accept.
return true
}
@@ -431,7 +474,9 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// Precondition: The packets' network and transport header must be set.
func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
- return it.checkPackets(Output, pkts, r, outNicName)
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ })
}
// CheckPostroutingPackets performs the postrouting hook on the packets.
@@ -439,21 +484,16 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa
// Returns a map of packets that must be dropped.
//
// Precondition: The packets' network and transport header must be set.
-func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
- return it.checkPackets(Postrouting, pkts, r, outNicName)
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ })
}
-// checkPackets runs pkts through the rules for hook and returns a map of
-// packets that should not go forward.
-//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-//
-// Precondition: The packets' network and transport header must be set.
-func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok {
+ if ok := f(pkt); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -543,7 +583,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, addressEP)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
index 529e02a07..3d3c39c20 100644
--- a/pkg/tcpip/stack/iptables_state.go
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -26,11 +26,15 @@ type unixTime struct {
// saveLastUsed is invoked by stateify.
func (cn *conn) saveLastUsed() unixTime {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
}
// loadLastUsed is invoked by stateify.
func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
cn.lastUsed = time.Unix(unix.second, unix.nano)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index e8806ebdb..85490e2d4 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,10 +79,49 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
+// DNATTarget modifies the destination port/IP of packets.
+type DNATTarget struct {
+ // The new destination address for packets.
+ //
+ // Immutable.
+ Addr tcpip.Address
+
+ // The new destination port for packets.
+ //
+ // Immutable.
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with.
+ //
+ // Immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Prerouting, Output:
+ case Input, Forward, Postrouting:
+ panic(fmt.Sprintf("%s not supported for DNAT", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */)
+
+}
+
// RedirectTarget redirects the packet to this machine by modifying the
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
@@ -97,7 +136,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -105,16 +144,6 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
rt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
- // Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
- return RuleDrop, 0
- }
-
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
var address tcpip.Address
@@ -132,43 +161,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
panic("redirect target is supported only on output and prerouting hooks")
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
-
- if hook == Output {
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- udpHeader,
- false, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- rt.Port,
- address,
- )
- } else {
- udpHeader.SetDestinationPort(rt.Port)
- }
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
- }
-
- return RuleAccept, 0
+ return natAction(pkt, hook, r, rt.Port, address, true /* dnat */)
}
// SNATTarget modifies the source port/IP in the outgoing packets.
@@ -181,15 +174,7 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
- // Sanity check.
- if st.NetworkProtocol != pkt.NetworkProtocolNumber {
- panic(fmt.Sprintf(
- "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
- st.NetworkProtocol, pkt.NetworkProtocolNumber))
- }
-
+func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -200,6 +185,37 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleDrop, 0
}
+ t := pkt.tuple
+ if t == nil {
+ return RuleDrop, 0
+ }
+
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ default:
+ panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber))
+ }
+ }
+
+ t.conn.performNAT(pkt, hook, r, port, address, dnat)
+ return RuleAccept, 0
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
@@ -208,31 +224,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- port := st.Port
+ return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */)
+}
- if port == 0 {
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- if port == 0 {
- port = header.UDP(pkt.TransportHeader().View()).SourcePort()
- }
- case header.TCPProtocolNumber:
- if port == 0 {
- port = header.TCP(pkt.TransportHeader().View()).SourcePort()
- }
- }
+// MasqueradeTarget modifies the source port/IP in the outgoing packets.
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Set up conection for matching NAT rule. Only the first packet of the
- // connection comes here. Other packets will be manipulated in connection
- // tracking.
- //
- // Does nothing if the protocol does not support connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
}
- return RuleAccept, 0
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
+ return RuleDrop, 0
+ }
+
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */)
}
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 976194124..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -81,17 +81,6 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects v4Tables, v6Tables, and modified.
- mu sync.RWMutex
- // v4Tables and v6tables map tableIDs to tables. They hold builtin
- // tables only, not user tables. mu must be locked for accessing.
- v4Tables [NumTables]Table
- v6Tables [NumTables]Table
- // modified is whether tables have been modified at least once. It is
- // used to elide the iptables performance overhead for workloads that
- // don't utilize iptables.
- modified bool
-
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
@@ -101,6 +90,21 @@ type IPTables struct {
// reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
+
+ mu sync.RWMutex
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables.
+ //
+ // +checklocks:mu
+ v4Tables [NumTables]Table
+ // +checklocks:mu
+ v6Tables [NumTables]Table
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ //
+ // +checklocks:mu
+ modified bool
}
// VisitTargets traverses all the targets of all tables and replaces each with
@@ -352,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index bf248ef20..888a8bd9d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -143,6 +143,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -302,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -329,13 +332,8 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- newPk.NatDone = pk.NatDone
return newPk
}
@@ -367,12 +365,7 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- newPk.NatDone = pk.NatDone
+ newPk.tuple = pk.tuple
return newPk
}
@@ -425,13 +418,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that is additionally consumes the
+// returned bytes. Subsequent PullUp or Consume will not return these bytes.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 87b023445..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
-func TestPacketBufferClone(t *testing.T) {
- data := concatViews(makeView(20), makeView(30), makeView(40))
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- bytesToDelete := 30
- originalSize := data.Size()
-
- clonedPks := []*PacketBuffer{
- pk.Clone(),
- pk.CloneToInbound(),
- }
- pk.Data().DeleteFront(bytesToDelete)
- if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
- t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
- }
- for _, clonedPk := range clonedPks {
- if got := clonedPk.Data().Size(); got != originalSize {
- t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
- }
- }
-}
-
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
@@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index cd4137794..c23e91702 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index dc7289441..a941091b0 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -289,6 +289,12 @@ type TCPSenderState struct {
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
+
+ // RetransmitTS records the timestamp used to detect spurious recovery.
+ RetransmitTS uint32
+
+ // SpuriousRecovery indicates if the sender entered recovery spuriously.
+ SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.