diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 162 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_state_autogen.go | 17 |
2 files changed, 93 insertions, 86 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 16d295271..48f290187 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -45,17 +45,6 @@ const ( dirReply ) -// Manipulation type for the connection. -// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and -// DNAT at the same time. -type manipType int - -const ( - manipNone manipType = iota - manipSource - manipDestination -) - // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. // @@ -124,10 +113,14 @@ type conn struct { // // +checklocks:mu finalized bool - // manip indicates if the packet should be manipulated. + // sourceManip indicates the packet's source is manipulated. // // +checklocks:mu - manip manipType + sourceManip bool + // destinationManip indicates the packet's destination is manipulated. + // + // +checklocks:mu + destinationManip bool // tcb is TCB control block. It is used to keep track of states // of tcp connection. // @@ -286,7 +279,6 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { ct: ct, original: tuple{tupleID: tid, direction: dirOriginal}, reply: tuple{tupleID: tid.reply(), direction: dirReply}, - manip: manipNone, lastUsed: now, } conn.original.conn = conn @@ -393,8 +385,16 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) return } - if cn.manip != manipNone { - return + if dnat { + if cn.destinationManip { + return + } + cn.destinationManip = true + } else { + if cn.sourceManip { + return + } + cn.sourceManip = true } cn.reply.mu.Lock() @@ -403,11 +403,9 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) if dnat { cn.reply.tupleID.srcAddr = address cn.reply.tupleID.srcPort = port - cn.manip = manipDestination } else { cn.reply.tupleID.dstAddr = address cn.reply.tupleID.dstPort = port - cn.manip = manipSource } } @@ -421,68 +419,24 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { return } - netHeader := pkt.Network() - - // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be - // validated if checksum offloading is off. It may require IP defrag if the - // packets are fragmented. - - var newAddr tcpip.Address - var newPort uint16 - - updateSRCFields := false - - dir := pkt.tuple.direction - - cn.mu.Lock() - defer cn.mu.Unlock() - - switch hook { - case Prerouting, Output: - if cn.manip == manipDestination && dir == dirOriginal { - id := cn.reply.id() - newPort = id.srcPort - newAddr = id.srcAddr - pkt.NatDone = true - } else if cn.manip == manipSource && dir == dirReply { - id := cn.original.id() - newPort = id.srcPort - newAddr = id.srcAddr - pkt.NatDone = true - } - case Input, Postrouting: - if cn.manip == manipSource && dir == dirOriginal { - id := cn.reply.id() - newPort = id.dstPort - newAddr = id.dstAddr - updateSRCFields = true - pkt.NatDone = true - } else if cn.manip == manipDestination && dir == dirReply { - id := cn.original.id() - newPort = id.dstPort - newAddr = id.dstAddr - updateSRCFields = true - pkt.NatDone = true - } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) - } - - if !pkt.NatDone { - return - } - fullChecksum := false updatePseudoHeader := false + dnat := false switch hook { case Prerouting: // Packet came from outside the stack so it must have a checksum set // already. fullChecksum = true updatePseudoHeader = true + + dnat = true case Input: - case Output, Postrouting: - // Calculate the TCP checksum and set it. + case Forward: + panic("should not handle packet in the forwarding hook") + case Output: + dnat = true + fallthrough + case Postrouting: if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { @@ -490,23 +444,73 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { updatePseudoHeader = true } default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) + panic(fmt.Sprintf("unrecognized hook = %d", hook)) + } + + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. + + dir := pkt.tuple.direction + tid, performManip := func() (tupleID, bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + var tuple *tuple + switch dir { + case dirOriginal: + if dnat { + if !cn.destinationManip { + return tupleID{}, false + } + } else if !cn.sourceManip { + return tupleID{}, false + } + + tuple = &cn.reply + case dirReply: + if dnat { + if !cn.sourceManip { + return tupleID{}, false + } + } else if !cn.destinationManip { + return tupleID{}, false + } + + tuple = &cn.original + default: + panic(fmt.Sprintf("unhandled dir = %d", dir)) + } + + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = time.Now() + // Update connection state. + cn.updateLocked(pkt, dir) + + return tuple.id(), true + }() + if !performManip { + return + } + + newPort := tid.dstPort + newAddr := tid.dstAddr + if dnat { + newPort = tid.srcPort + newAddr = tid.srcAddr } rewritePacket( - netHeader, + pkt.Network(), transportHeader, - updateSRCFields, + !dnat, fullChecksum, updatePseudoHeader, newPort, newAddr, ) - // Mark the connection as having been used recently so it isn't reaped. - cn.lastUsed = time.Now() - // Update connection state. - cn.updateLocked(pkt, dir) + pkt.NatDone = true } // bucket gets the conntrack bucket for a tupleID. @@ -651,7 +655,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ t.conn.mu.RLock() defer t.conn.mu.RUnlock() - if t.conn.manip != manipDestination { + if !t.conn.destinationManip { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index 2d0966fd2..99fc2df69 100644 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -90,7 +90,8 @@ func (cn *conn) StateFields() []string { "original", "reply", "finalized", - "manip", + "sourceManip", + "destinationManip", "tcb", "lastUsed", } @@ -103,13 +104,14 @@ func (cn *conn) StateSave(stateSinkObject state.Sink) { cn.beforeSave() var lastUsedValue unixTime lastUsedValue = cn.saveLastUsed() - stateSinkObject.SaveValue(6, lastUsedValue) + stateSinkObject.SaveValue(7, lastUsedValue) stateSinkObject.Save(0, &cn.ct) stateSinkObject.Save(1, &cn.original) stateSinkObject.Save(2, &cn.reply) stateSinkObject.Save(3, &cn.finalized) - stateSinkObject.Save(4, &cn.manip) - stateSinkObject.Save(5, &cn.tcb) + stateSinkObject.Save(4, &cn.sourceManip) + stateSinkObject.Save(5, &cn.destinationManip) + stateSinkObject.Save(6, &cn.tcb) } func (cn *conn) afterLoad() {} @@ -120,9 +122,10 @@ func (cn *conn) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(1, &cn.original) stateSourceObject.Load(2, &cn.reply) stateSourceObject.Load(3, &cn.finalized) - stateSourceObject.Load(4, &cn.manip) - stateSourceObject.Load(5, &cn.tcb) - stateSourceObject.LoadValue(6, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) + stateSourceObject.Load(4, &cn.sourceManip) + stateSourceObject.Load(5, &cn.destinationManip) + stateSourceObject.Load(6, &cn.tcb) + stateSourceObject.LoadValue(7, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) } func (ct *ConnTrack) StateTypeName() string { |