diff options
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 234 |
1 files changed, 119 insertions, 115 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 3f083928f..41e964cf3 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -16,6 +16,7 @@ package stack import ( "encoding/binary" + "fmt" "sync" "time" @@ -29,7 +30,7 @@ import ( // The connection is created for a packet if it does not exist. Every // connection contains two tuples (original and reply). The tuples are // manipulated if there is a matching NAT rule. The packet is modified by -// looking at the tuples in the Prerouting and Output hooks. +// looking at the tuples in each hook. // // Currently, only TCP tracking is supported. @@ -46,12 +47,14 @@ const ( ) // Manipulation type for the connection. +// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and +// DNAT at the same time. type manipType int const ( manipNone manipType = iota - manipDstPrerouting - manipDstOutput + manipSource + manipDestination ) // tuple holds a connection's identifying and manipulating data in one @@ -108,6 +111,7 @@ type conn struct { reply tuple // manip indicates if the packet should be manipulated. It is immutable. + // TODO(gvisor.dev/issue/5696): Support updating manipulation type. manip manipType // tcbHook indicates if the packet is inbound or outbound to @@ -124,6 +128,18 @@ type conn struct { lastUsed time.Time `state:".(unixTime)"` } +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn +} + // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour @@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { }, nil } -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - func (ct *ConnTrack) init() { ct.mu.Lock() defer ct.mu.Unlock() @@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1 return nil } - // Create a new connection and change the port as per the iptables - // rule. This tuple will be used to manipulate the packet in - // handlePacket. replyTID := tid.reply() replyTID.srcAddr = address replyTID.srcPort = port - var manip manipType - switch hook { - case Prerouting: - manip = manipDstPrerouting - case Output: - manip = manipDstOutput + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil } - conn := newConn(tid, replyTID, manip, hook) + conn = newConn(tid, replyTID, manipDestination, hook) + ct.insertConn(conn) + return conn +} + +func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Input && hook != Postrouting { + return nil + } + + replyTID := tid.reply() + replyTID.dstAddr = address + replyTID.dstPort = port + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil + } + conn = newConn(tid, replyTID, manipSource, hook) ct.insertConn(conn) return conn } @@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Now that we hold the locks, ensure the tuple hasn't been inserted by // another thread. + // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? alreadyInserted := false for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { if other.tupleID == conn.original.tupleID { @@ -343,86 +369,6 @@ func (ct *ConnTrack) insertConn(conn *conn) { } } -// handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. -func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For prerouting redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. - switch dir { - case dirOriginal: - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - case dirReply: - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated - // on inbound packets, so we don't recalculate them. However, we should - // support cases when they are validated, e.g. when we can't offload - // receive checksumming. - - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - -// handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For output redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. For prerouting redirection, we only reach this point - // when replying, so packet sources are modified. - if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - } else { - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - if gso != nil && gso.NeedsCsum { - tcpHeader.SetChecksum(xsum) - } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) - } - - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - // handlePacket will manipulate the port and address of the packet if the // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. @@ -431,7 +377,9 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return false } - if hook != Prerouting && hook != Output { + switch hook { + case Prerouting, Input, Output, Postrouting: + default: return false } @@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } conn, dir := ct.connFor(pkt) - // Connection or Rule not found for the packet. + // Connection not found for the packet. if conn == nil { - return true + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting } + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return false } + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. + + switch hook { + case Prerouting, Output: + if conn.manip == manipDestination { + switch dir { + case dirOriginal: + tcpHeader.SetDestinationPort(conn.reply.srcPort) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + case dirReply: + tcpHeader.SetSourcePort(conn.original.dstPort) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + pkt.NatDone = true + } + case Input, Postrouting: + if conn.manip == manipSource { + switch dir { + case dirOriginal: + tcpHeader.SetSourcePort(conn.reply.dstPort) + netHeader.SetSourceAddress(conn.reply.dstAddr) + case dirReply: + tcpHeader.SetDestinationPort(conn.original.srcPort) + netHeader.SetDestinationAddress(conn.original.srcAddr) + } + pkt.NatDone = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + if !pkt.NatDone { + return false + } + switch hook { - case Prerouting: - handlePacketPrerouting(pkt, conn, dir) - case Output: - handlePacketOutput(pkt, conn, gso, r, dir) + case Prerouting, Input: + case Output, Postrouting: + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + length := uint16(len(tcpHeader) + pkt.Data().Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + if gso != nil && gso.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) } - pkt.NatDone = true // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle @@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ if conn == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip == manipNone { - // Unmanipulated connection. + } else if conn.manip != manipDestination { + // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } |