diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 589 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 132 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_state.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 192 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/tcp.go | 6 |
9 files changed, 538 insertions, 492 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 4fb7e9adb..48f290187 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -45,17 +45,6 @@ const ( dirReply ) -// Manipulation type for the connection. -// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and -// DNAT at the same time. -type manipType int - -const ( - manipNone manipType = iota - manipSource - manipDestination -) - // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. // @@ -64,13 +53,21 @@ type tuple struct { // tupleEntry is used to build an intrusive list of tuples. tupleEntry - tupleID - // conn is the connection tracking entry this tuple belongs to. conn *conn // direction is the direction of the tuple. direction direction + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + tupleID tupleID +} + +func (t *tuple) id() tupleID { + t.mu.RLock() + defer t.mu.RUnlock() + return t.tupleID } // tupleID uniquely identifies a connection in one direction. It currently @@ -103,50 +100,47 @@ func (ti tupleID) reply() tupleID { // // +stateify savable type conn struct { + ct *ConnTrack + // original is the tuple in original direction. It is immutable. original tuple - // reply is the tuple in reply direction. It is immutable. + // reply is the tuple in reply direction. reply tuple - // manip indicates if the packet should be manipulated. It is immutable. - // TODO(gvisor.dev/issue/5696): Support updating manipulation type. - manip manipType - - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. It is immutable. - tcbHook Hook - - // mu protects all mutable state. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // Indicates that the connection has been finalized and may handle replies. + // + // +checklocks:mu + finalized bool + // sourceManip indicates the packet's source is manipulated. + // + // +checklocks:mu + sourceManip bool + // destinationManip indicates the packet's destination is manipulated. + // + // +checklocks:mu + destinationManip bool // tcb is TCB control block. It is used to keep track of states - // of tcp connection and is protected by mu. + // of tcp connection. + // + // +checklocks:mu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and - // is updated by each packet on the connection. It is protected by mu. + // is updated by each packet on the connection. // // TODO(gvisor.dev/issue/5939): do not use the ambient clock. + // + // +checklocks:mu lastUsed time.Time `state:".(unixTime)"` } -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour const defaultTimeout = 120 * time.Second - cn.mu.Lock() - defer cn.mu.Unlock() + cn.mu.RLock() + defer cn.mu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -159,8 +153,9 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // -// Precondition: cn.mu must be held. -func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { +// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. +// +checklocks:cn.mu +func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) { if pkt.TransportProtocolNumber != header.TCPProtocolNumber { return } @@ -172,10 +167,16 @@ func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { // established or not, so the client/server distinction isn't important. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) - } else if hook == cn.tcbHook { + return + } + + switch dir { + case dirOriginal: cn.tcb.UpdateStateOutbound(tcpHeader) - } else { + case dirReply: cn.tcb.UpdateStateInbound(tcpHeader) + default: + panic(fmt.Sprintf("unhandled dir = %d", dir)) } } @@ -200,18 +201,18 @@ type ConnTrack struct { // It is immutable. seed uint32 + mu sync.RWMutex `state:"nosave"` // mu protects the buckets slice, but not buckets' contents. Only take // the write lock if you are modifying the slice or saving for S/R. - mu sync.RWMutex `state:"nosave"` - - // buckets is protected by mu. + // + // +checklocks:mu buckets []bucket } // +stateify savable type bucket struct { - // mu protects tuples. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu tuples tupleList } @@ -230,241 +231,212 @@ func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) return nil, false } -// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid -// TCP header. -// -// Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { +func (ct *ConnTrack) init() { + ct.mu.Lock() + defer ct.mu.Unlock() + ct.buckets = make([]bucket, numBuckets) +} + +func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { netHeader := pkt.Network() transportHeader, ok := getTransportHeader(pkt) if !ok { - return tupleID{}, &tcpip.ErrUnknownProtocol{} + return nil } - return tupleID{ + tid := tupleID{ srcAddr: netHeader.SourceAddress(), srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), dstPort: transportHeader.DestinationPort(), transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, - }, nil -} - -func (ct *ConnTrack) init() { - ct.mu.Lock() - defer ct.mu.Unlock() - ct.buckets = make([]bucket, numBuckets) -} - -// connFor gets the conn for pkt if it exists, or returns nil -// if it does not. It returns an error when pkt does not contain a valid TCP -// header. -// TODO(gvisor.dev/issue/6168): Support UDP. -func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil, dirOriginal } - return ct.connForTID(tid) -} -func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { - bucket := ct.bucket(tid) - now := time.Now() + bktID := ct.bucket(tid) ct.mu.RLock() - defer ct.mu.RUnlock() - ct.buckets[bucket].mu.Lock() - defer ct.buckets[bucket].mu.Unlock() - - // Iterate over the tuples in a bucket, cleaning up any unused - // connections we find. - for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { - // Clean up any timed-out connections we happen to find. - if ct.reapTupleLocked(other, bucket, now) { - // The tuple expired. - continue - } - if tid == other.tupleID { - return other.conn, other.direction - } + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + now := time.Now() + if t := bkt.connForTID(tid, now); t != nil { + return t } - return nil, dirOriginal -} + bkt.mu.Lock() + defer bkt.mu.Unlock() -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil + // Make sure a connection wasn't added between when we last checked the + // bucket and acquired the bucket's write lock. + if t := bkt.connForTIDRLocked(tid, now); t != nil { + return t } - if hook != Prerouting && hook != Output { - return nil + + // This is the first packet we're seeing for the connection. Create an entry + // for this new connection. + conn := &conn{ + ct: ct, + original: tuple{tupleID: tid, direction: dirOriginal}, + reply: tuple{tupleID: tid.reply(), direction: dirReply}, + lastUsed: now, } + conn.original.conn = conn + conn.reply.conn = conn - replyTID := tid.reply() - replyTID.srcAddr = address - replyTID.srcPort = port + // For now, we only map an entry for the packet's original tuple as NAT may be + // performed on this connection. Until the packet goes through all the hooks + // and its final address/port is known, we cannot know what the response + // packet's addresses/ports will look like. + // + // This is okay because the destination cannot send its response until it + // receives the packet; the packet will only be received once all the hooks + // have been performed. + // + // See (*conn).finalize. + bkt.tuples.PushFront(&conn.original) + return &conn.original +} - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil - } - conn = newConn(tid, replyTID, manipDestination, hook) - ct.insertConn(conn) - return conn +func (ct *ConnTrack) connForTID(tid tupleID) *tuple { + bktID := ct.bucket(tid) + + ct.mu.RLock() + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + return bkt.connForTID(tid, time.Now()) } -func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil - } - if hook != Input && hook != Postrouting { - return nil +func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple { + bkt.mu.RLock() + defer bkt.mu.RUnlock() + return bkt.connForTIDRLocked(tid, now) +} + +// +checklocks:bkt.mu +func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple { + for other := bkt.tuples.Front(); other != nil; other = other.Next() { + if tid == other.id() && !other.conn.timedOut(now) { + return other + } } + return nil +} - replyTID := tid.reply() - replyTID.dstAddr = address - replyTID.dstPort = port +func (ct *ConnTrack) finalize(cn *conn) { + tid := cn.reply.id() + id := ct.bucket(tid) - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil + ct.mu.RLock() + bkt := &ct.buckets[id] + ct.mu.RUnlock() + + bkt.mu.Lock() + defer bkt.mu.Unlock() + + if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil { + // Another connection for the reply already exists. We can't do much about + // this so we leave the connection cn represents in a state where it can + // send packets but its responses will be mapped to some other connection. + // This may be okay if the connection only expects to send packets without + // any responses. + return } - conn = newConn(tid, replyTID, manipSource, hook) - ct.insertConn(conn) - return conn + + bkt.tuples.PushFront(&cn.reply) } -// insertConn inserts conn into the appropriate table bucket. -func (ct *ConnTrack) insertConn(conn *conn) { - // Lock the buckets in the correct order. - tupleBucket := ct.bucket(conn.original.tupleID) - replyBucket := ct.bucket(conn.reply.tupleID) - ct.mu.RLock() - defer ct.mu.RUnlock() - if tupleBucket < replyBucket { - ct.buckets[tupleBucket].mu.Lock() - ct.buckets[replyBucket].mu.Lock() - } else if tupleBucket > replyBucket { - ct.buckets[replyBucket].mu.Lock() - ct.buckets[tupleBucket].mu.Lock() - } else { - // Both tuples are in the same bucket. - ct.buckets[tupleBucket].mu.Lock() - } - - // Now that we hold the locks, ensure the tuple hasn't been inserted by - // another thread. - // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? - alreadyInserted := false - for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { - if other.tupleID == conn.original.tupleID { - alreadyInserted = true - break +func (cn *conn) finalize() { + { + cn.mu.RLock() + finalized := cn.finalized + cn.mu.RUnlock() + if finalized { + return } } - if !alreadyInserted { - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + cn.mu.Lock() + finalized := cn.finalized + cn.finalized = true + cn.mu.Unlock() + if finalized { + return } - // Unlocking can happen in any order. - ct.buckets[tupleBucket].mu.Unlock() - if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce - } + cn.ct.finalize(cn) } -// handlePacket will manipulate the port and address of the packet if the -// connection exists. Returns whether, after the packet traverses the tables, -// it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false - } +// performNAT setups up the connection for the specified NAT. +// +// Generally, only the first packet of a connection reaches this method; other +// other packets will be manipulated without needing to modify the connection. +func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) { + cn.performNATIfNoop(port, address, dnat) + cn.handlePacket(pkt, hook, r) +} - switch hook { - case Prerouting, Input, Output, Postrouting: - default: - return false - } +func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) { + cn.mu.Lock() + defer cn.mu.Unlock() - transportHeader, ok := getTransportHeader(pkt) - if !ok { - return false + if cn.finalized { + return } - conn, dir := ct.connFor(pkt) - // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting + if dnat { + if cn.destinationManip { + return + } + cn.destinationManip = true + } else { + if cn.sourceManip { + return + } + cn.sourceManip = true } - netHeader := pkt.Network() - - // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be - // validated if checksum offloading is off. It may require IP defrag if the - // packets are fragmented. - - var newAddr tcpip.Address - var newPort uint16 + cn.reply.mu.Lock() + defer cn.reply.mu.Unlock() - updateSRCFields := false + if dnat { + cn.reply.tupleID.srcAddr = address + cn.reply.tupleID.srcPort = port + } else { + cn.reply.tupleID.dstAddr = address + cn.reply.tupleID.dstPort = port + } +} - switch hook { - case Prerouting, Output: - if conn.manip == manipDestination && dir == dirOriginal { - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - pkt.NatDone = true - } else if conn.manip == manipSource && dir == dirReply { - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - pkt.NatDone = true - } - case Input, Postrouting: - if conn.manip == manipSource && dir == dirOriginal { - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - updateSRCFields = true - pkt.NatDone = true - } else if conn.manip == manipDestination && dir == dirReply { - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - updateSRCFields = true - pkt.NatDone = true - } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { + if pkt.NatDone { + return } - if !pkt.NatDone { - return false + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return } fullChecksum := false updatePseudoHeader := false + dnat := false switch hook { case Prerouting: // Packet came from outside the stack so it must have a checksum set // already. fullChecksum = true updatePseudoHeader = true + + dnat = true case Input: - case Output, Postrouting: - // Calculate the TCP checksum and set it. + case Forward: + panic("should not handle packet in the forwarding hook") + case Output: + dnat = true + fallthrough + case Postrouting: if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { @@ -472,62 +444,73 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { updatePseudoHeader = true } default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) + panic(fmt.Sprintf("unrecognized hook = %d", hook)) } - rewritePacket( - netHeader, - transportHeader, - updateSRCFields, - fullChecksum, - updatePseudoHeader, - newPort, - newAddr, - ) + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. - // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() + dir := pkt.tuple.direction + tid, performManip := func() (tupleID, bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + var tuple *tuple + switch dir { + case dirOriginal: + if dnat { + if !cn.destinationManip { + return tupleID{}, false + } + } else if !cn.sourceManip { + return tupleID{}, false + } - // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() - // Update connection state. - conn.updateLocked(pkt, hook) + tuple = &cn.reply + case dirReply: + if dnat { + if !cn.sourceManip { + return tupleID{}, false + } + } else if !cn.destinationManip { + return tupleID{}, false + } - return false -} + tuple = &cn.original + default: + panic(fmt.Sprintf("unhandled dir = %d", dir)) + } -// maybeInsertNoop tries to insert a no-op connection entry to keep connections -// from getting clobbered when replies arrive. It only inserts if there isn't -// already a connection for pkt. -// -// This should be called after traversing iptables rules only, to ensure that -// pkt.NatDone is set correctly. -func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { - // If there were a rule applying to this packet, it would be marked - // with NatDone. - if pkt.NatDone { - return - } + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = time.Now() + // Update connection state. + cn.updateLocked(pkt, dir) - switch pkt.TransportProtocolNumber { - case header.TCPProtocolNumber, header.UDPProtocolNumber: - default: - // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable - // connections. + return tuple.id(), true + }() + if !performManip { return } - // This is the first packet we're seeing for the TCP connection. Insert - // the noop entry (an identity mapping) so that the response doesn't - // get NATed, breaking the connection. - tid, err := packetToTupleID(pkt) - if err != nil { - return + newPort := tid.dstPort + newAddr := tid.dstAddr + if dnat { + newPort = tid.srcPort + newAddr = tid.srcAddr } - conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(pkt, hook) - ct.insertConn(conn) + + rewritePacket( + pkt.Network(), + transportHeader, + !dnat, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) + + pkt.NatDone = true } // bucket gets the conntrack bucket for a tupleID. @@ -579,14 +562,15 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim defer ct.mu.RUnlock() for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { idx = (i + start) % len(ct.buckets) - ct.buckets[idx].mu.Lock() - for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + bkt := &ct.buckets[idx] + bkt.mu.Lock() + for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() { checked++ - if ct.reapTupleLocked(tuple, idx, now) { + if ct.reapTupleLocked(tuple, idx, bkt, now) { expired++ } } - ct.buckets[idx].mu.Unlock() + bkt.mu.Unlock() } // We already checked buckets[idx]. idx++ @@ -611,41 +595,45 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // reapTupleLocked tries to remove tuple and its reply from the table. It // returns whether the tuple's connection has timed out. // -// Preconditions: -// * ct.mu is locked for reading. -// * bucket is locked. -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { +// Precondition: ct.mu is read locked and bkt.mu is write locked. +// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. +// +checklocks:ct.mu +// +checklocks:bkt.mu +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool { if !tuple.conn.timedOut(now) { return false } // To maintain lock order, we can only reap these tuples if the reply // appears later in the table. - replyBucket := ct.bucket(tuple.reply()) - if bucket > replyBucket { + replyBktID := ct.bucket(tuple.id().reply()) + if bktID > replyBktID { return true } // Don't re-lock if both tuples are in the same bucket. - differentBuckets := bucket != replyBucket - if differentBuckets { - ct.buckets[replyBucket].mu.Lock() + if bktID != replyBktID { + replyBkt := &ct.buckets[replyBktID] + replyBkt.mu.Lock() + removeConnFromBucket(replyBkt, tuple) + replyBkt.mu.Unlock() + } else { + removeConnFromBucket(bkt, tuple) } // We have the buckets locked and can remove both tuples. + bkt.tuples.Remove(tuple) + return true +} + +// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. +// +checklocks:b.mu +func removeConnFromBucket(b *bucket, tuple *tuple) { if tuple.direction == dirOriginal { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + b.tuples.Remove(&tuple.conn.reply) } else { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) - } - ct.buckets[bucket].tuples.Remove(tuple) - - // Don't re-unlock if both tuples are in the same bucket. - if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce + b.tuples.Remove(&tuple.conn.original) } - - return true } func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { @@ -659,14 +647,19 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ transProto: transProto, netProto: netProto, } - conn, _ := ct.connForTID(tid) - if conn == nil { + t := ct.connForTID(tid) + if t == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip != manipDestination { + } + + t.conn.mu.RLock() + defer t.conn.mu.RUnlock() + if !t.conn.destinationManip { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } - return conn.original.dstAddr, conn.original.dstPort, nil + id := t.conn.original.id() + return id.dstAddr, id.dstPort, nil } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 74c9075b4..5808be685 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -271,7 +271,18 @@ const ( // // Precondition: The packet's network and transport header must be set. func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { - return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) + const hook = Prerouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { + pkt.tuple = t + t.conn.handlePacket(pkt, hook, nil /* route */) + } + + return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) } // CheckInput performs the input hook on the packet. @@ -281,7 +292,22 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp // // Precondition: The packet's network and transport header must be set. func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { - return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) + const hook = Input + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := pkt.tuple; t != nil { + t.conn.handlePacket(pkt, hook, nil /* route */) + } + + ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret } // CheckForward performs the forward hook on the packet. @@ -291,6 +317,9 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { // // Precondition: The packet's network and transport header must be set. func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) } @@ -301,7 +330,18 @@ func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string // // Precondition: The packet's network and transport header must be set. func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { - return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) + const hook = Output + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { + pkt.tuple = t + t.conn.handlePacket(pkt, hook, r) + } + + return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) } // CheckPostrouting performs the postrouting hook on the packet. @@ -310,8 +350,38 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) // must be dropped if false is returned. // // Precondition: The packet's network and transport header must be set. -func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool { - return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool { + const hook = Postrouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := pkt.tuple; t != nil { + t.conn.handlePacket(pkt, hook, r) + } + + ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool { + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + // IPTables only supports IPv4/IPv6. + return true + } + + it.mu.RLock() + defer it.mu.RUnlock() + // Many users never configure iptables. Spare them the cost of rule + // traversal if rules have never been set. + return !it.modified } // check runs pkt through the rules for hook. It returns true when the packet @@ -320,20 +390,8 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName str // // Precondition: The packet's network and transport header must be set. func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { - if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { - return true - } - // Many users never configure iptables. Spare them the cost of rule - // traversal if rules have never been set. it.mu.RLock() defer it.mu.RUnlock() - if !it.modified { - return true - } - - // Packets are manipulated only if connection and matching - // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] @@ -361,7 +419,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v { + switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -377,21 +435,6 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr } } - // If this connection should be tracked, try to add an entry for it. If - // traversing the nat table didn't end in adding an entry, - // maybeInsertNoop will add a no-op entry for the connection. This is - // needeed when establishing connections so that the SYN/ACK reply to an - // outgoing SYN is delivered to the correct endpoint rather than being - // redirected by a prerouting rule. - // - // From the iptables documentation: "If there is no rule, a `null' - // binding is created: this usually does not map the packet, but exists - // to ensure we don't map another stream over an existing one." - if shouldTrack { - it.connections.maybeInsertNoop(pkt, hook) - } - - // Every table returned Accept. return true } @@ -431,7 +474,9 @@ func (it *IPTables) startReaper(interval time.Duration) { // // Precondition: The packets' network and transport header must be set. func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { - return it.checkPackets(Output, pkts, r, outNicName) + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckOutput(pkt, r, outNicName) + }) } // CheckPostroutingPackets performs the postrouting hook on the packets. @@ -439,21 +484,16 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa // Returns a map of packets that must be dropped. // // Precondition: The packets' network and transport header must be set. -func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { - return it.checkPackets(Postrouting, pkts, r, outNicName) +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckPostrouting(pkt, r, addressEP, outNicName) + }) } -// checkPackets runs pkts through the rules for hook and returns a map of -// packets that should not go forward. -// -// NOTE: unlike the Check API the returned map contains packets that should be -// dropped. -// -// Precondition: The packets' network and transport header must be set. -func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok { + if ok := f(pkt); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -543,7 +583,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, addressEP) + return rule.Target.Action(pkt, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go index 529e02a07..3d3c39c20 100644 --- a/pkg/tcpip/stack/iptables_state.go +++ b/pkg/tcpip/stack/iptables_state.go @@ -26,11 +26,15 @@ type unixTime struct { // saveLastUsed is invoked by stateify. func (cn *conn) saveLastUsed() unixTime { + cn.mu.Lock() + defer cn.mu.Unlock() return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} } // loadLastUsed is invoked by stateify. func (cn *conn) loadLastUsed(unix unixTime) { + cn.mu.Lock() + defer cn.mu.Unlock() cn.lastUsed = time.Unix(unix.second, unix.nano) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index e8806ebdb..85490e2d4 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,10 +79,49 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } +// DNATTarget modifies the destination port/IP of packets. +type DNATTarget struct { + // The new destination address for packets. + // + // Immutable. + Addr tcpip.Address + + // The new destination port for packets. + // + // Immutable. + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. + // + // Immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Prerouting, Output: + case Input, Forward, Postrouting: + panic(fmt.Sprintf("%s not supported for DNAT", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */) + +} + // RedirectTarget redirects the packet to this machine by modifying the // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than @@ -97,7 +136,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -105,16 +144,6 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r rt.NetworkProtocol, pkt.NetworkProtocolNumber)) } - // Packet is already manipulated. - if pkt.NatDone { - return RuleAccept, 0 - } - - // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { - return RuleDrop, 0 - } - // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. var address tcpip.Address @@ -132,43 +161,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - - if hook == Output { - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - udpHeader, - false, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - rt.Port, - address, - ) - } else { - udpHeader.SetDestinationPort(rt.Port) - } - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } - - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 - } - - return RuleAccept, 0 + return natAction(pkt, hook, r, rt.Port, address, true /* dnat */) } // SNATTarget modifies the source port/IP in the outgoing packets. @@ -181,15 +174,7 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { - // Sanity check. - if st.NetworkProtocol != pkt.NetworkProtocolNumber { - panic(fmt.Sprintf( - "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", - st.NetworkProtocol, pkt.NetworkProtocolNumber)) - } - +func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -200,6 +185,37 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleDrop, 0 } + t := pkt.tuple + if t == nil { + return RuleDrop, 0 + } + + // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a + // different port. + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + case header.TCPProtocolNumber: + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + default: + panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber)) + } + } + + t.conn.performNAT(pkt, hook, r, port, address, dnat) + return RuleAccept, 0 +} + +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + switch hook { case Postrouting, Input: case Prerouting, Output, Forward: @@ -208,31 +224,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - port := st.Port + return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */) +} - if port == 0 { - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - if port == 0 { - port = header.UDP(pkt.TransportHeader().View()).SourcePort() - } - case header.TCPProtocolNumber: - if port == 0 { - port = header.TCP(pkt.TransportHeader().View()).SourcePort() - } - } +// MasqueradeTarget modifies the source port/IP in the outgoing packets. +type MasqueradeTarget struct { + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if mt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + mt.NetworkProtocol, pkt.NetworkProtocolNumber)) } - // Set up conection for matching NAT rule. Only the first packet of the - // connection comes here. Other packets will be manipulated in connection - // tracking. - // - // Does nothing if the protocol does not support connection tracking. - if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) + switch hook { + case Postrouting: + case Prerouting, Input, Forward, Output: + panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) } - return RuleAccept, 0 + // addressEP is expected to be set for the postrouting hook. + ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */) + if ep == nil { + // No address exists that we can use as a source address. + return RuleDrop, 0 + } + + address := ep.AddressWithPrefix().Address + ep.DecRef() + return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */) } func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 976194124..b22024667 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -81,17 +81,6 @@ const ( // // +stateify savable type IPTables struct { - // mu protects v4Tables, v6Tables, and modified. - mu sync.RWMutex - // v4Tables and v6tables map tableIDs to tables. They hold builtin - // tables only, not user tables. mu must be locked for accessing. - v4Tables [NumTables]Table - v6Tables [NumTables]Table - // modified is whether tables have been modified at least once. It is - // used to elide the iptables performance overhead for workloads that - // don't utilize iptables. - modified bool - // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. @@ -101,6 +90,21 @@ type IPTables struct { // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} + + mu sync.RWMutex + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. + // + // +checklocks:mu + v4Tables [NumTables]Table + // +checklocks:mu + v6Tables [NumTables]Table + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + // + // +checklocks:mu + modified bool } // VisitTargets traverses all the targets of all tables and replaces each with @@ -352,5 +356,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) + Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index bf248ef20..888a8bd9d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -143,6 +143,8 @@ type PacketBuffer struct { // NetworkPacketInfo holds an incoming packet's network-layer information. NetworkPacketInfo NetworkPacketInfo + + tuple *tuple } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -302,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NICID: pk.NICID, RXTransportChecksumValidated: pk.RXTransportChecksumValidated, NetworkPacketInfo: pk.NetworkPacketInfo, + tuple: pk.tuple, } } @@ -329,13 +332,8 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), + tuple: pk.tuple, } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - newPk.NatDone = pk.NatDone return newPk } @@ -367,12 +365,7 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu newPk.TransportProtocolNumber = pk.TransportProtocolNumber } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - newPk.NatDone = pk.NatDone + newPk.tuple = pk.tuple return newPk } @@ -425,13 +418,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// DeleteFront removes count from the beginning of d. It panics if count > -// d.Size(). All backing storage references after the front of the d are -// invalidated. -func (d PacketData) DeleteFront(count int) { - if !d.pk.buf.Remove(d.pk.dataOffset(), count) { - panic("count > d.Size()") +// Consume is the same as PullUp except that is additionally consumes the +// returned bytes. Subsequent PullUp or Consume will not return these bytes. +func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) { + v, ok := d.PullUp(size) + if ok { + d.pk.consumed += size } + return v, ok } // CapLength reduces d to at most length bytes. diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 87b023445..c376ed1a1 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) { } } -func TestPacketBufferClone(t *testing.T) { - data := concatViews(makeView(20), makeView(30), makeView(40)) - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - bytesToDelete := 30 - originalSize := data.Size() - - clonedPks := []*PacketBuffer{ - pk.Clone(), - pk.CloneToInbound(), - } - pk.Data().DeleteFront(bytesToDelete) - if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { - t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) - } - for _, clonedPk := range clonedPks { - if got := clonedPk.Data().Size(); got != originalSize { - t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) - } - } -} - func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string @@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) { } }) - // DeleteFront + // Consume. for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) + v, ok := pkt.Data().Consume(n) + if !ok { + t.Fatalf("Consume failed") + } + if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) { + t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want) + } checkData(t, pkt, []byte(tc.data)[n:]) }) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index cd4137794..c23e91702 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().Consume(fakeNetHeaderLen) if !ok { return } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), // Nothing checks the error. nil, /* transport error */ pkt, diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index dc7289441..a941091b0 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -289,6 +289,12 @@ type TCPSenderState struct { // RACKState holds the state related to RACK loss detection algorithm. RACKState TCPRACKState + + // RetransmitTS records the timestamp used to detect spurious recovery. + RetransmitTS uint32 + + // SpuriousRecovery indicates if the sender entered recovery spuriously. + SpuriousRecovery bool } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. |