summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/conntrack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r--pkg/tcpip/stack/conntrack.go462
1 files changed, 386 insertions, 76 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index af9c325ca..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -15,9 +15,12 @@
package stack
import (
+ "encoding/binary"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
@@ -30,6 +33,10 @@ import (
//
// Currently, only TCP tracking is supported.
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
+
// Direction of the tuple.
type direction int
@@ -42,13 +49,19 @@ const (
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
+//
+// +stateify savable
type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
+
tupleID
// conn is the connection tracking entry this tuple belongs to.
@@ -61,6 +74,8 @@ type tuple struct {
// tupleID uniquely identifies a connection in one direction. It currently
// contains enough information to distinguish between any TCP or UDP
// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
type tupleID struct {
srcAddr tcpip.Address
srcPort uint16
@@ -83,6 +98,8 @@ func (ti tupleID) reply() tupleID {
}
// conn is a tracked connection.
+//
+// +stateify savable
type conn struct {
// original is the tuple in original direction. It is immutable.
original tuple
@@ -97,36 +114,98 @@ type conn struct {
// update the state of tcb. It is immutable.
tcbHook Hook
- // mu protects tcb.
- mu sync.Mutex
-
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
// of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
+
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
+
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
}
// ConnTrack tracks all connections created for NAT rules. Most users are
-// expected to only call handlePacket and createConnFor.
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
type ConnTrack struct {
- // mu protects conns.
- mu sync.RWMutex
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
- // conns maintains a map of tuples needed for connection tracking for
- // iptables NAT rules. It is protected by mu.
- conns map[tupleID]tuple
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
+
+ // buckets is protected by mu.
+ buckets []bucket
+}
+
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
}
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader)
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
}
@@ -136,15 +215,16 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
// newConn creates new connection.
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
conn := conn{
- manip: manip,
- tcbHook: hook,
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
}
conn.original = tuple{conn: &conn, tupleID: orig}
conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
@@ -161,19 +241,35 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
if err != nil {
return nil, dirOriginal
}
+ return ct.connForTID(tid)
+}
- ct.mu.Lock()
- defer ct.mu.Unlock()
-
- tuple, ok := ct.conns[tid]
- if !ok {
- return nil, dirOriginal
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
}
- return tuple.conn, tuple.direction
+
+ return nil, dirOriginal
}
-// createConnFor creates a new conn for pkt.
-func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -186,8 +282,8 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.MinIP
- replyTID.srcPort = rt.MinPort
+ replyTID.srcAddr = rt.Addr
+ replyTID.srcPort = rt.Port
var manip manipType
switch hook {
case Prerouting:
@@ -196,23 +292,61 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
manip = manipDstOutput
}
conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
- // Add the changed tuple to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked
- // list.
- ct.mu.Lock()
- defer ct.mu.Unlock()
- ct.conns[tid] = conn.original
- ct.conns[replyTID] = conn.reply
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
+ }
- return conn
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
+
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := pkt.Network()
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
// have their destinations modified and replies have their sources
@@ -228,14 +362,28 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader.SetSourceAddress(conn.original.dstAddr)
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := pkt.Network()
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
// have their destinations modified and replies have their sources
@@ -253,8 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- hdr := &pkt.Header
- length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -263,25 +410,39 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
if conn == nil {
- // Connection not found for the packet or the packet is invalid.
- return
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ return false
}
switch hook {
@@ -297,35 +458,184 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
- var st tcpconntrack.Result
- tcpHeader := header.TCP(pkt.TransportHeader)
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
+ }
+
+ // We only track TCP connections.
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return
+ }
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ ct.insertConn(conn)
+}
+
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
+
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
}
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
+
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions:
+// * ct.mu is locked for reading.
+// * bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
+
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
}
+ ct.buckets[bucket].tuples.Remove(tuple)
- // Delete conn if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConn(conn)
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
}
+
+ return true
}
-// deleteConn deletes the connection.
-func (ct *ConnTrack) deleteConn(conn *conn) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+ // Lookup the connection. The reply's original destination
+ // describes the original address.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: netProto,
+ }
+ conn, _ := ct.connForTID(tid)
if conn == nil {
- return
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
}
- ct.mu.Lock()
- defer ct.mu.Unlock()
-
- delete(ct.conns, conn.original.tupleID)
- delete(ct.conns, conn.reply.tupleID)
+ return conn.original.dstAddr, conn.original.dstPort, nil
}