summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/conntrack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r--pkg/tcpip/stack/conntrack.go777
1 files changed, 464 insertions, 313 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 7d1ede1f2..7dd344b4f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -20,376 +20,383 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
// Connection tracking is used to track and manipulate packets for NAT rules.
-// The connection is created for a packet if it does not exist. Every connection
-// contains two tuples (original and reply). The tuples are manipulated if there
-// is a matching NAT rule. The packet is modified by looking at the tuples in the
-// Prerouting and Output hooks.
+// The connection is created for a packet if it does not exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule. The packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
+
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
// Direction of the tuple.
-type ctDirection int
+type direction int
const (
- dirOriginal ctDirection = iota
+ dirOriginal direction = iota
dirReply
)
-// Status of connection.
-// TODO(gvisor.dev/issue/170): Add other states of connection.
-type connStatus int
-
-const (
- connNew connStatus = iota
- connEstablished
-)
-
// Manipulation type for the connection.
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
-// connTrackMutable is the manipulatable part of the tuple.
-type connTrackMutable struct {
- // addr is source address of the tuple.
- addr tcpip.Address
-
- // port is source port of the tuple.
- port uint16
-
- // protocol is network layer protocol.
- protocol tcpip.NetworkProtocolNumber
-}
-
-// connTrackImmutable is the non-manipulatable part of the tuple.
-type connTrackImmutable struct {
- // addr is destination address of the tuple.
- addr tcpip.Address
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+//
+// +stateify savable
+type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
- // direction is direction (original or reply) of the tuple.
- direction ctDirection
+ tupleID
- // port is destination port of the tuple.
- port uint16
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
- // protocol is transport layer protocol.
- protocol tcpip.TransportProtocolNumber
+ // direction is the direction of the tuple.
+ direction direction
}
-// connTrackTuple represents the tuple which is created from the
-// packet.
-type connTrackTuple struct {
- // dst is non-manipulatable part of the tuple.
- dst connTrackImmutable
-
- // src is manipulatable part of the tuple.
- src connTrackMutable
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
}
-// connTrackTupleHolder is the container of tuple and connection.
-type ConnTrackTupleHolder struct {
- // conn is pointer to the connection tracking entry.
- conn *connTrack
-
- // tuple is original or reply tuple.
- tuple connTrackTuple
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
}
-// connTrack is the connection.
-type connTrack struct {
- // originalTupleHolder contains tuple in original direction.
- originalTupleHolder ConnTrackTupleHolder
-
- // replyTupleHolder contains tuple in reply direction.
- replyTupleHolder ConnTrackTupleHolder
-
- // status indicates connection is new or established.
- status connStatus
+// conn is a tracked connection.
+//
+// +stateify savable
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
- // timeout indicates the time connection should be active.
- timeout time.Duration
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
- // manip indicates if the packet should be manipulated.
+ // manip indicates if the packet should be manipulated. It is immutable.
manip manipType
- // tcb is TCB control block. It is used to keep track of states
- // of tcp connection.
- tcb tcpconntrack.TCB
-
// tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb.
+ // update the state of tcb. It is immutable.
tcbHook Hook
-}
-// ConnTrackTable contains a map of all existing connections created for
-// NAT rules.
-type ConnTrackTable struct {
- // connMu protects connTrackTable.
- connMu sync.RWMutex
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
+ // tcb is TCB control block. It is used to keep track of states
+ // of tcp connection and is protected by mu.
+ tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
- // connTrackTable maintains a map of tuples needed for connection tracking
- // for iptables NAT rules. The key for the map is an integer calculated
- // using seed, source address, destination address, source port and
- // destination port.
- CtMap map[uint32]ConnTrackTupleHolder
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
- // seed is a one-time random value initialized at stack startup
- // and is used in calculation of hash key for connection tracking
- // table.
- Seed uint32
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
}
-// parseHeaders sets headers in the packet.
-func parseHeaders(pkt *PacketBuffer) {
- newPkt := pkt.Clone()
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
+type ConnTrack struct {
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
- // Set network header.
- hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- netHeader := header.IPv4(hdr)
- newPkt.NetworkHeader = hdr
- length := int(netHeader.HeaderLength())
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- // Set transport header.
- switch protocol := netHeader.TransportProtocol(); protocol {
- case header.UDPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
- }
- case header.TCPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
- }
- }
- pkt.NetworkHeader = newPkt.NetworkHeader
- pkt.TransportHeader = newPkt.TransportHeader
+ // buckets is protected by mu.
+ buckets []bucket
}
-// packetToTuple converts packet to a tuple in original direction.
-func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
- var tuple connTrackTuple
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
+}
- netHeader := header.IPv4(pkt.NetworkHeader)
+// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
+// TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tuple, tcpip.ErrUnknownProtocol
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
- return tuple, tcpip.ErrUnknownProtocol
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
- tuple.src.addr = netHeader.SourceAddress()
- tuple.src.port = tcpHeader.SourcePort()
- tuple.src.protocol = header.IPv4ProtocolNumber
-
- tuple.dst.addr = netHeader.DestinationAddress()
- tuple.dst.port = tcpHeader.DestinationPort()
- tuple.dst.protocol = netHeader.TransportProtocol()
-
- return tuple, nil
-}
-
-// getReplyTuple creates reply tuple for the given tuple.
-func getReplyTuple(tuple connTrackTuple) connTrackTuple {
- var replyTuple connTrackTuple
- replyTuple.src.addr = tuple.dst.addr
- replyTuple.src.port = tuple.dst.port
- replyTuple.src.protocol = tuple.src.protocol
- replyTuple.dst.addr = tuple.src.addr
- replyTuple.dst.port = tuple.src.port
- replyTuple.dst.protocol = tuple.dst.protocol
- replyTuple.dst.direction = dirReply
-
- return replyTuple
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
}
-// makeNewConn creates new connection.
-func makeNewConn(tuple, replyTuple connTrackTuple) connTrack {
- var conn connTrack
- conn.status = connNew
- conn.originalTupleHolder.tuple = tuple
- conn.originalTupleHolder.conn = &conn
- conn.replyTupleHolder.tuple = replyTuple
- conn.replyTupleHolder.conn = &conn
-
- return conn
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
}
-// getTupleHash returns hash of the tuple. The fields used for
-// generating hash are seed (generated once for stack), source address,
-// destination address, source port and destination ports.
-func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
- h := jenkins.Sum32(ct.Seed)
- h.Write([]byte(tuple.src.addr))
- h.Write([]byte(tuple.dst.addr))
- portBuf := make([]byte, 2)
- binary.LittleEndian.PutUint16(portBuf, tuple.src.port)
- h.Write([]byte(portBuf))
- binary.LittleEndian.PutUint16(portBuf, tuple.dst.port)
- h.Write([]byte(portBuf))
-
- return h.Sum32()
+// connFor gets the conn for pkt if it exists, or returns nil
+// if it does not. It returns an error when pkt does not contain a valid TCP
+// header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil, dirOriginal
+ }
+ return ct.connForTID(tid)
}
-// connTrackForPacket returns connTrack for packet.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
-// transport protocols.
-func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- if hook == Prerouting {
- // Headers will not be set in Prerouting.
- // TODO(gvisor.dev/issue/170): Change this after parsing headers
- // code is added.
- parseHeaders(pkt)
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
}
- var dir ctDirection
- tuple, err := packetToTuple(*pkt, hook)
- if err != nil {
- return nil, dir
- }
-
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
-
- connTrackTable := ct.CtMap
- hash := ct.getTupleHash(tuple)
-
- var conn *connTrack
- switch createConn {
- case true:
- // If connection does not exist for the hash, create a new
- // connection.
- replyTuple := getReplyTuple(tuple)
- replyHash := ct.getTupleHash(replyTuple)
- newConn := makeNewConn(tuple, replyTuple)
- conn = &newConn
-
- // Add tupleHolders to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked list.
- ct.CtMap[hash] = conn.originalTupleHolder
- ct.CtMap[replyHash] = conn.replyTupleHolder
- default:
- tupleHolder, ok := connTrackTable[hash]
- if !ok {
- return nil, dir
- }
+ return nil, dirOriginal
+}
- // If this is the reply of new connection, set the connection
- // status as ESTABLISHED.
- conn = tupleHolder.conn
- if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply {
- conn.status = connEstablished
- }
- if tupleHolder.conn == nil {
- panic("tupleHolder has null connection tracking entry")
- }
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
+ }
- dir = tupleHolder.tuple.dst.direction
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
}
- return conn, dir
+ conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
}
-// SetNatInfo will manipulate the tuples according to iptables NAT rules.
-func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) {
- // Get the connection. Connection is always created before this
- // function is called.
- conn, _ := ct.connTrackForPacket(pkt, hook, false)
- if conn == nil {
- panic("connection should be created to manipulate tuples.")
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
}
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
- // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to
- // support changing of address for Prerouting.
-
- // Change the port as per the iptables rule. This tuple will be used
- // to manipulate the packet in HandlePacket.
- conn.replyTupleHolder.tuple.src.addr = rt.MinIP
- conn.replyTupleHolder.tuple.src.port = rt.MinPort
- newHash := ct.getTupleHash(conn.replyTupleHolder.tuple)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
- // Add the changed tuple to the map.
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
- ct.CtMap[newHash] = conn.replyTupleHolder
- if hook == Output {
- conn.replyTupleHolder.conn.manip = manipDstOutput
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
}
- // Delete the old tuple.
- delete(ct.CtMap, replyHash)
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook..
-func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) {
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
// have their destinations modified and replies have their sources
// modified.
switch dir {
case dirOriginal:
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) {
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
// have their destinations modified and replies have their sources
// modified. For prerouting redirection, we only reach this point
// when replying, so packet sources are modified.
if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
} else {
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- hdr := &pkt.Header
- length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -402,33 +409,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
-// HandlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// handlePacket will manipulate the port and address of the packet if the
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
}
- conn, dir := ct.connTrackForPacket(pkt, hook, false)
- // Connection or Rule not found for the packet.
- if conn == nil {
- return
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
- netHeader := header.IPv4(pkt.NetworkHeader)
- // TODO(gvisor.dev/issue/170): Need to support for other transport
- // protocols as well.
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return
+ conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
+ if conn == nil {
+ return true
}
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
- return
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ return false
}
switch hook {
@@ -442,39 +448,184 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
// other tcp states.
- var st tcpconntrack.Result
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
- }
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
}
- // Delete conntrack if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConnTrack(conn)
+ // We only track TCP connections.
+ if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ return
}
-}
-// deleteConnTrack deletes the connection.
-func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) {
- if conn == nil {
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ ct.insertConn(conn)
+}
- tuple := conn.originalTupleHolder.tuple
- hash := ct.getTupleHash(tuple)
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
+ }
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
+
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions: ct.mu is locked for reading and bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
+
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
+ }
+ ct.buckets[bucket].tuples.Remove(tuple)
+
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
+
+ return true
+}
+
+func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ // Lookup the connection. The reply's original destination
+ // describes the original address.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: header.IPv4ProtocolNumber,
+ }
+ conn, _ := ct.connForTID(tid)
+ if conn == nil {
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
+ }
- delete(ct.CtMap, hash)
- delete(ct.CtMap, replyHash)
+ return conn.original.dstAddr, conn.original.dstPort, nil
}