diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-09-22 17:52:43 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-09-22 17:55:49 -0700 |
commit | d8772545113ff941d34a4eae5f43df3f39d3547f (patch) | |
tree | c0e692655feac0fdf33542031a0b829c1893c341 /pkg/tcpip/stack | |
parent | 440fc07f70203caf517c5cbc3dcc3e83b7de3f05 (diff) |
Track UDP connections
This will enable NAT to be performed on UDP packets that are sent
in response to packets sent by the stack.
This will also enable ICMP errors to be properly NAT-ed in response
to UDP packets (#5916).
Updates #5915.
PiperOrigin-RevId: 398373251
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 110 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 46 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 23 |
4 files changed, 99 insertions, 84 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..4fb7e9adb 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // // Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. @@ -209,27 +215,38 @@ type bucket struct { tuples tupleList } +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } + } + + return nil, false +} + // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), + srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, }, nil } @@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } @@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false - } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the @@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true - } + if conn.manip == manipDestination && dir == dirOriginal { + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr + pkt.NatDone = true + } else if conn.manip == manipSource && dir == dirReply { + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - } + if conn.manip == manipSource && dir == dirOriginal { + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + updateSRCFields = true + pkt.NatDone = true + } else if conn.manip == manipDestination && dir == dirReply { + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + updateSRCFields = true pkt.NatDone = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } + if !pkt.NatDone { return false } @@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { fullChecksum := false updatePseudoHeader := false switch hook { - case Prerouting, Input: + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + case Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { fullChecksum = true @@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { rewritePacket( netHeader, - tcpHeader, + transportHeader, updateSRCFields, fullChecksum, updatePseudoHeader, @@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) return false } @@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber, header.UDPProtocolNumber: + default: + // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable + // connections. return } @@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) ct.insertConn(conn) } @@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } conn, _ := ct.connForTID(tid) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..3617b6dd0 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -482,11 +482,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..de5997e9e 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -206,34 +206,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 + port := st.Port + + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + if port == 0 { + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + } + case header.TCPProtocolNumber: + if port == 0 { + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + } } + } - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 + // Set up conection for matching NAT rule. Only the first packet of the + // connection comes here. Other packets will be manipulated in connection + // tracking. + // + // Does nothing if the protocol does not support connection tracking. + if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } return RuleAccept, 0 diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index b9280c2de..bf248ef20 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // tell if a noop connection should be inserted at Input hook. Once conntrack // redefines the manipulation field as mutable, we won't need the special noop // connection. - if pk.NatDone { - newPk.NatDone = true - } + newPk.NatDone = pk.NatDone return newPk } @@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // The returned packet buffer will have the network and transport headers // set if the original packet buffer did. func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { - newPkt := NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: reservedHeaderBytes, Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), IsForwardedPacket: true, @@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu { consumeBytes := pk.NetworkHeader().View().Size() - if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) } - newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } { consumeBytes := pk.TransportHeader().View().Size() - if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) } - newPkt.TransportProtocolNumber = pk.TransportProtocolNumber + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } - return newPkt + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + newPk.NatDone = pk.NatDone + + return newPk } // headerInfo stores metadata about a header in a packet. |