diff options
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 110 |
1 files changed, 63 insertions, 47 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..4fb7e9adb 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // // Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. @@ -209,27 +215,38 @@ type bucket struct { tuples tupleList } +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } + } + + return nil, false +} + // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), + srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, }, nil } @@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } @@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false - } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the @@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true - } + if conn.manip == manipDestination && dir == dirOriginal { + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr + pkt.NatDone = true + } else if conn.manip == manipSource && dir == dirReply { + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - } + if conn.manip == manipSource && dir == dirOriginal { + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + updateSRCFields = true + pkt.NatDone = true + } else if conn.manip == manipDestination && dir == dirReply { + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + updateSRCFields = true pkt.NatDone = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } + if !pkt.NatDone { return false } @@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { fullChecksum := false updatePseudoHeader := false switch hook { - case Prerouting, Input: + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + case Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { fullChecksum = true @@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { rewritePacket( netHeader, - tcpHeader, + transportHeader, updateSRCFields, fullChecksum, updatePseudoHeader, @@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) return false } @@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber, header.UDPProtocolNumber: + default: + // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable + // connections. return } @@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) ct.insertConn(conn) } @@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } conn, _ := ct.connForTID(tid) |