summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/conntrack.go
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2020-06-25 21:20:29 -0700
committergVisor bot <gvisor-bot@google.com>2020-06-25 21:21:57 -0700
commit7fb6cc286fffab2f408a5dbc228e0db706104682 (patch)
treeb34e3fbc648f2b317f74295ef9c4ec3837d65f54 /pkg/tcpip/stack/conntrack.go
parent4069461877d843654d18db74a5962b332f1226aa (diff)
conntrack refactor, no behavior changes
- Split connTrackForPacket into 2 functions instead of switching on flag - Replace hash with struct keys. - Remove prefixes where possible - Remove unused connStatus, timeout - Flatten ConnTrack struct a bit - some intermediate structs had no meaning outside of the context of their parent. - Protect conn.tcb with a mutex - Remove redundant error checking (e.g. when is pkt.NetworkHeader valid) - Clarify that HandlePacket and CreateConnFor are the expected entrypoints for ConnTrack PiperOrigin-RevId: 318407168
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r--pkg/tcpip/stack/conntrack.go405
1 files changed, 151 insertions, 254 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 05bf62788..af9c325ca 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -15,39 +15,29 @@
package stack
import (
- "encoding/binary"
"sync"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
// Connection tracking is used to track and manipulate packets for NAT rules.
-// The connection is created for a packet if it does not exist. Every connection
-// contains two tuples (original and reply). The tuples are manipulated if there
-// is a matching NAT rule. The packet is modified by looking at the tuples in the
-// Prerouting and Output hooks.
+// The connection is created for a packet if it does not exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule. The packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
// Direction of the tuple.
-type ctDirection int
+type direction int
const (
- dirOriginal ctDirection = iota
+ dirOriginal direction = iota
dirReply
)
-// Status of connection.
-// TODO(gvisor.dev/issue/170): Add other states of connection.
-type connStatus int
-
-const (
- connNew connStatus = iota
- connEstablished
-)
-
// Manipulation type for the connection.
type manipType int
@@ -56,250 +46,171 @@ const (
manipDstOutput
)
-// connTrackMutable is the manipulatable part of the tuple.
-type connTrackMutable struct {
- // addr is source address of the tuple.
- addr tcpip.Address
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+type tuple struct {
+ tupleID
- // port is source port of the tuple.
- port uint16
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
- // protocol is network layer protocol.
- protocol tcpip.NetworkProtocolNumber
+ // direction is the direction of the tuple.
+ direction direction
}
-// connTrackImmutable is the non-manipulatable part of the tuple.
-type connTrackImmutable struct {
- // addr is destination address of the tuple.
- addr tcpip.Address
-
- // direction is direction (original or reply) of the tuple.
- direction ctDirection
-
- // port is destination port of the tuple.
- port uint16
-
- // protocol is transport layer protocol.
- protocol tcpip.TransportProtocolNumber
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
}
-// connTrackTuple represents the tuple which is created from the
-// packet.
-type connTrackTuple struct {
- // dst is non-manipulatable part of the tuple.
- dst connTrackImmutable
-
- // src is manipulatable part of the tuple.
- src connTrackMutable
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
}
-// connTrackTupleHolder is the container of tuple and connection.
-type ConnTrackTupleHolder struct {
- // conn is pointer to the connection tracking entry.
- conn *connTrack
-
- // tuple is original or reply tuple.
- tuple connTrackTuple
-}
+// conn is a tracked connection.
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
-// connTrack is the connection.
-type connTrack struct {
- // originalTupleHolder contains tuple in original direction.
- originalTupleHolder ConnTrackTupleHolder
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
- // replyTupleHolder contains tuple in reply direction.
- replyTupleHolder ConnTrackTupleHolder
+ // manip indicates if the packet should be manipulated. It is immutable.
+ manip manipType
- // status indicates connection is new or established.
- status connStatus
+ // tcbHook indicates if the packet is inbound or outbound to
+ // update the state of tcb. It is immutable.
+ tcbHook Hook
- // timeout indicates the time connection should be active.
- timeout time.Duration
-
- // manip indicates if the packet should be manipulated.
- manip manipType
+ // mu protects tcb.
+ mu sync.Mutex
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection.
+ // of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb.
- tcbHook Hook
}
-// ConnTrackTable contains a map of all existing connections created for
-// NAT rules.
-type ConnTrackTable struct {
- // connMu protects connTrackTable.
- connMu sync.RWMutex
-
- // connTrackTable maintains a map of tuples needed for connection tracking
- // for iptables NAT rules. The key for the map is an integer calculated
- // using seed, source address, destination address, source port and
- // destination port.
- CtMap map[uint32]ConnTrackTupleHolder
-
- // seed is a one-time random value initialized at stack startup
- // and is used in calculation of hash key for connection tracking
- // table.
- Seed uint32
-}
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket and createConnFor.
+type ConnTrack struct {
+ // mu protects conns.
+ mu sync.RWMutex
-// packetToTuple converts packet to a tuple in original direction.
-func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
- var tuple connTrackTuple
+ // conns maintains a map of tuples needed for connection tracking for
+ // iptables NAT rules. It is protected by mu.
+ conns map[tupleID]tuple
+}
- netHeader := header.IPv4(pkt.NetworkHeader)
+// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
+// TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
+ netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
- tuple.src.addr = netHeader.SourceAddress()
- tuple.src.port = tcpHeader.SourcePort()
- tuple.src.protocol = header.IPv4ProtocolNumber
-
- tuple.dst.addr = netHeader.DestinationAddress()
- tuple.dst.port = tcpHeader.DestinationPort()
- tuple.dst.protocol = netHeader.TransportProtocol()
-
- return tuple, nil
-}
-
-// getReplyTuple creates reply tuple for the given tuple.
-func getReplyTuple(tuple connTrackTuple) connTrackTuple {
- var replyTuple connTrackTuple
- replyTuple.src.addr = tuple.dst.addr
- replyTuple.src.port = tuple.dst.port
- replyTuple.src.protocol = tuple.src.protocol
- replyTuple.dst.addr = tuple.src.addr
- replyTuple.dst.port = tuple.src.port
- replyTuple.dst.protocol = tuple.dst.protocol
- replyTuple.dst.direction = dirReply
-
- return replyTuple
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
}
-// makeNewConn creates new connection.
-func makeNewConn(tuple, replyTuple connTrackTuple) connTrack {
- var conn connTrack
- conn.status = connNew
- conn.originalTupleHolder.tuple = tuple
- conn.originalTupleHolder.conn = &conn
- conn.replyTupleHolder.tuple = replyTuple
- conn.replyTupleHolder.conn = &conn
-
- return conn
-}
-
-// getTupleHash returns hash of the tuple. The fields used for
-// generating hash are seed (generated once for stack), source address,
-// destination address, source port and destination ports.
-func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
- h := jenkins.Sum32(ct.Seed)
- h.Write([]byte(tuple.src.addr))
- h.Write([]byte(tuple.dst.addr))
- portBuf := make([]byte, 2)
- binary.LittleEndian.PutUint16(portBuf, tuple.src.port)
- h.Write([]byte(portBuf))
- binary.LittleEndian.PutUint16(portBuf, tuple.dst.port)
- h.Write([]byte(portBuf))
-
- return h.Sum32()
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
}
-// connTrackForPacket returns connTrack for packet.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
-// transport protocols.
-func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- var dir ctDirection
- tuple, err := packetToTuple(pkt, hook)
+// connFor gets the conn for pkt if it exists, or returns nil
+// if it does not. It returns an error when pkt does not contain a valid TCP
+// header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
if err != nil {
- return nil, dir
+ return nil, dirOriginal
}
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
-
- connTrackTable := ct.CtMap
- hash := ct.getTupleHash(tuple)
-
- var conn *connTrack
- switch createConn {
- case true:
- // If connection does not exist for the hash, create a new
- // connection.
- replyTuple := getReplyTuple(tuple)
- replyHash := ct.getTupleHash(replyTuple)
- newConn := makeNewConn(tuple, replyTuple)
- conn = &newConn
-
- // Add tupleHolders to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked list.
- ct.CtMap[hash] = conn.originalTupleHolder
- ct.CtMap[replyHash] = conn.replyTupleHolder
- default:
- tupleHolder, ok := connTrackTable[hash]
- if !ok {
- return nil, dir
- }
-
- // If this is the reply of new connection, set the connection
- // status as ESTABLISHED.
- conn = tupleHolder.conn
- if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply {
- conn.status = connEstablished
- }
- if tupleHolder.conn == nil {
- panic("tupleHolder has null connection tracking entry")
- }
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
- dir = tupleHolder.tuple.dst.direction
+ tuple, ok := ct.conns[tid]
+ if !ok {
+ return nil, dirOriginal
}
- return conn, dir
+ return tuple.conn, tuple.direction
}
-// SetNatInfo will manipulate the tuples according to iptables NAT rules.
-func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) {
- // Get the connection. Connection is always created before this
- // function is called.
- conn, _ := ct.connTrackForPacket(pkt, hook, false)
- if conn == nil {
- panic("connection should be created to manipulate tuples.")
+// createConnFor creates a new conn for pkt.
+func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
}
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
-
- // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to
- // support changing of address for Prerouting.
- // Change the port as per the iptables rule. This tuple will be used
- // to manipulate the packet in HandlePacket.
- conn.replyTupleHolder.tuple.src.addr = rt.MinIP
- conn.replyTupleHolder.tuple.src.port = rt.MinPort
- newHash := ct.getTupleHash(conn.replyTupleHolder.tuple)
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
+ }
+ conn := newConn(tid, replyTID, manip, hook)
// Add the changed tuple to the map.
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
- ct.CtMap[newHash] = conn.replyTupleHolder
- if hook == Output {
- conn.replyTupleHolder.conn.manip = manipDstOutput
- }
+ // TODO(gvisor.dev/issue/170): Need to support collisions using linked
+ // list.
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+ ct.conns[tid] = conn.original
+ ct.conns[replyTID] = conn.reply
- // Delete the old tuple.
- delete(ct.CtMap, replyHash)
+ return conn
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook..
-func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) {
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -308,13 +219,13 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection)
// modified.
switch dir {
case dirOriginal:
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
netHeader.SetChecksum(0)
@@ -322,7 +233,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection)
}
// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) {
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -331,13 +242,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
// modified. For prerouting redirection, we only reach this point
// when replying, so packet sources are modified.
if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
} else {
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
// Calculate the TCP checksum and set it.
@@ -356,9 +267,9 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
-// HandlePacket will manipulate the port and address of the packet if the
+// handlePacket will manipulate the port and address of the packet if the
// connection exists.
-func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
if pkt.NatDone {
return
}
@@ -367,21 +278,9 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
return
}
- conn, dir := ct.connTrackForPacket(pkt, hook, false)
- // Connection or Rule not found for the packet.
+ conn, dir := ct.connFor(pkt)
if conn == nil {
- return
- }
-
- netHeader := header.IPv4(pkt.NetworkHeader)
- // TODO(gvisor.dev/issue/170): Need to support for other transport
- // protocols as well.
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
+ // Connection not found for the packet or the packet is invalid.
return
}
@@ -396,7 +295,10 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
// other tcp states.
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
var st tcpconntrack.Result
+ tcpHeader := header.TCP(pkt.TransportHeader)
if conn.tcb.IsEmpty() {
conn.tcb.Init(tcpHeader)
conn.tcbHook = hook
@@ -409,26 +311,21 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
}
}
- // Delete conntrack if tcp connection is closed.
+ // Delete conn if tcp connection is closed.
if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConnTrack(conn)
+ ct.deleteConn(conn)
}
}
-// deleteConnTrack deletes the connection.
-func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) {
+// deleteConn deletes the connection.
+func (ct *ConnTrack) deleteConn(conn *conn) {
if conn == nil {
return
}
- tuple := conn.originalTupleHolder.tuple
- hash := ct.getTupleHash(tuple)
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
-
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
- delete(ct.CtMap, hash)
- delete(ct.CtMap, replyHash)
+ delete(ct.conns, conn.original.tupleID)
+ delete(ct.conns, conn.reply.tupleID)
}