summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/conntrack.go110
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go46
-rw-r--r--pkg/tcpip/stack/packet_buffer.go23
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
5 files changed, 100 insertions, 85 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..4fb7e9adb 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
@@ -209,27 +215,38 @@ type bucket struct {
tuples tupleList
}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
+ }
+
+ return nil, false
+}
+
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
//
// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
return tupleID{
srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
+ srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return false
}
@@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
}
netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
- }
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
@@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
- }
+ if conn.manip == manipDestination && dir == dirOriginal {
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
+ pkt.NatDone = true
+ } else if conn.manip == manipSource && dir == dirReply {
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- }
+ if conn.manip == manipSource && dir == dirOriginal {
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+ updateSRCFields = true
+ pkt.NatDone = true
+ } else if conn.manip == manipDestination && dir == dirReply {
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+ updateSRCFields = true
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
+
if !pkt.NatDone {
return false
}
@@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
fullChecksum := false
updatePseudoHeader := false
switch hook {
- case Prerouting, Input:
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+ case Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
fullChecksum = true
@@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
rewritePacket(
netHeader,
- tcpHeader,
+ transportHeader,
updateSRCFields,
fullChecksum,
updatePseudoHeader,
@@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
return false
}
@@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber, header.UDPProtocolNumber:
+ default:
+ // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
+ // connections.
return
}
@@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
ct.insertConn(conn)
}
@@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
conn, _ := ct.connForTID(tid)
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..3617b6dd0 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -482,11 +482,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..de5997e9e 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -206,34 +206,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
+ port := st.Port
+
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ if port == 0 {
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ }
+ case header.TCPProtocolNumber:
+ if port == 0 {
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ }
}
+ }
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
+ // Set up conection for matching NAT rule. Only the first packet of the
+ // connection comes here. Other packets will be manipulated in connection
+ // tracking.
+ //
+ // Does nothing if the protocol does not support connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, r)
}
return RuleAccept, 0
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index b9280c2de..bf248ef20 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// tell if a noop connection should be inserted at Input hook. Once conntrack
// redefines the manipulation field as mutable, we won't need the special noop
// connection.
- if pk.NatDone {
- newPk.NatDone = true
- }
+ newPk.NatDone = pk.NatDone
return newPk
}
@@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// The returned packet buffer will have the network and transport headers
// set if the original packet buffer did.
func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
- newPkt := NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: reservedHeaderBytes,
Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
IsForwardedPacket: true,
@@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
{
consumeBytes := pk.NetworkHeader().View().Size()
- if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed {
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
}
- newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
{
consumeBytes := pk.TransportHeader().View().Size()
- if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed {
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
}
- newPkt.TransportProtocolNumber = pk.TransportProtocolNumber
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
- return newPkt
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
+ // maintain this flag in the packet. Currently conntrack needs this flag to
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ newPk.NatDone = pk.NatDone
+
+ return newPk
}
// headerInfo stores metadata about a header in a packet.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a3002abf3..407ab2664 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2066,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err