summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/conntrack.go162
-rw-r--r--pkg/tcpip/stack/stack_state_autogen.go17
2 files changed, 93 insertions, 86 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 16d295271..48f290187 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -45,17 +45,6 @@ const (
dirReply
)
-// Manipulation type for the connection.
-// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
-// DNAT at the same time.
-type manipType int
-
-const (
- manipNone manipType = iota
- manipSource
- manipDestination
-)
-
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
//
@@ -124,10 +113,14 @@ type conn struct {
//
// +checklocks:mu
finalized bool
- // manip indicates if the packet should be manipulated.
+ // sourceManip indicates the packet's source is manipulated.
//
// +checklocks:mu
- manip manipType
+ sourceManip bool
+ // destinationManip indicates the packet's destination is manipulated.
+ //
+ // +checklocks:mu
+ destinationManip bool
// tcb is TCB control block. It is used to keep track of states
// of tcp connection.
//
@@ -286,7 +279,6 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
ct: ct,
original: tuple{tupleID: tid, direction: dirOriginal},
reply: tuple{tupleID: tid.reply(), direction: dirReply},
- manip: manipNone,
lastUsed: now,
}
conn.original.conn = conn
@@ -393,8 +385,16 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool)
return
}
- if cn.manip != manipNone {
- return
+ if dnat {
+ if cn.destinationManip {
+ return
+ }
+ cn.destinationManip = true
+ } else {
+ if cn.sourceManip {
+ return
+ }
+ cn.sourceManip = true
}
cn.reply.mu.Lock()
@@ -403,11 +403,9 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool)
if dnat {
cn.reply.tupleID.srcAddr = address
cn.reply.tupleID.srcPort = port
- cn.manip = manipDestination
} else {
cn.reply.tupleID.dstAddr = address
cn.reply.tupleID.dstPort = port
- cn.manip = manipSource
}
}
@@ -421,68 +419,24 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
return
}
- netHeader := pkt.Network()
-
- // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
- // validated if checksum offloading is off. It may require IP defrag if the
- // packets are fragmented.
-
- var newAddr tcpip.Address
- var newPort uint16
-
- updateSRCFields := false
-
- dir := pkt.tuple.direction
-
- cn.mu.Lock()
- defer cn.mu.Unlock()
-
- switch hook {
- case Prerouting, Output:
- if cn.manip == manipDestination && dir == dirOriginal {
- id := cn.reply.id()
- newPort = id.srcPort
- newAddr = id.srcAddr
- pkt.NatDone = true
- } else if cn.manip == manipSource && dir == dirReply {
- id := cn.original.id()
- newPort = id.srcPort
- newAddr = id.srcAddr
- pkt.NatDone = true
- }
- case Input, Postrouting:
- if cn.manip == manipSource && dir == dirOriginal {
- id := cn.reply.id()
- newPort = id.dstPort
- newAddr = id.dstAddr
- updateSRCFields = true
- pkt.NatDone = true
- } else if cn.manip == manipDestination && dir == dirReply {
- id := cn.original.id()
- newPort = id.dstPort
- newAddr = id.dstAddr
- updateSRCFields = true
- pkt.NatDone = true
- }
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
- }
-
- if !pkt.NatDone {
- return
- }
-
fullChecksum := false
updatePseudoHeader := false
+ dnat := false
switch hook {
case Prerouting:
// Packet came from outside the stack so it must have a checksum set
// already.
fullChecksum = true
updatePseudoHeader = true
+
+ dnat = true
case Input:
- case Output, Postrouting:
- // Calculate the TCP checksum and set it.
+ case Forward:
+ panic("should not handle packet in the forwarding hook")
+ case Output:
+ dnat = true
+ fallthrough
+ case Postrouting:
if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
@@ -490,23 +444,73 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
updatePseudoHeader = true
}
default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ panic(fmt.Sprintf("unrecognized hook = %d", hook))
+ }
+
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
+
+ dir := pkt.tuple.direction
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ var tuple *tuple
+ switch dir {
+ case dirOriginal:
+ if dnat {
+ if !cn.destinationManip {
+ return tupleID{}, false
+ }
+ } else if !cn.sourceManip {
+ return tupleID{}, false
+ }
+
+ tuple = &cn.reply
+ case dirReply:
+ if dnat {
+ if !cn.sourceManip {
+ return tupleID{}, false
+ }
+ } else if !cn.destinationManip {
+ return tupleID{}, false
+ }
+
+ tuple = &cn.original
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
+ }
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = time.Now()
+ // Update connection state.
+ cn.updateLocked(pkt, dir)
+
+ return tuple.id(), true
+ }()
+ if !performManip {
+ return
+ }
+
+ newPort := tid.dstPort
+ newAddr := tid.dstAddr
+ if dnat {
+ newPort = tid.srcPort
+ newAddr = tid.srcAddr
}
rewritePacket(
- netHeader,
+ pkt.Network(),
transportHeader,
- updateSRCFields,
+ !dnat,
fullChecksum,
updatePseudoHeader,
newPort,
newAddr,
)
- // Mark the connection as having been used recently so it isn't reaped.
- cn.lastUsed = time.Now()
- // Update connection state.
- cn.updateLocked(pkt, dir)
+ pkt.NatDone = true
}
// bucket gets the conntrack bucket for a tupleID.
@@ -651,7 +655,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
t.conn.mu.RLock()
defer t.conn.mu.RUnlock()
- if t.conn.manip != manipDestination {
+ if !t.conn.destinationManip {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
index 2d0966fd2..99fc2df69 100644
--- a/pkg/tcpip/stack/stack_state_autogen.go
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -90,7 +90,8 @@ func (cn *conn) StateFields() []string {
"original",
"reply",
"finalized",
- "manip",
+ "sourceManip",
+ "destinationManip",
"tcb",
"lastUsed",
}
@@ -103,13 +104,14 @@ func (cn *conn) StateSave(stateSinkObject state.Sink) {
cn.beforeSave()
var lastUsedValue unixTime
lastUsedValue = cn.saveLastUsed()
- stateSinkObject.SaveValue(6, lastUsedValue)
+ stateSinkObject.SaveValue(7, lastUsedValue)
stateSinkObject.Save(0, &cn.ct)
stateSinkObject.Save(1, &cn.original)
stateSinkObject.Save(2, &cn.reply)
stateSinkObject.Save(3, &cn.finalized)
- stateSinkObject.Save(4, &cn.manip)
- stateSinkObject.Save(5, &cn.tcb)
+ stateSinkObject.Save(4, &cn.sourceManip)
+ stateSinkObject.Save(5, &cn.destinationManip)
+ stateSinkObject.Save(6, &cn.tcb)
}
func (cn *conn) afterLoad() {}
@@ -120,9 +122,10 @@ func (cn *conn) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(1, &cn.original)
stateSourceObject.Load(2, &cn.reply)
stateSourceObject.Load(3, &cn.finalized)
- stateSourceObject.Load(4, &cn.manip)
- stateSourceObject.Load(5, &cn.tcb)
- stateSourceObject.LoadValue(6, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) })
+ stateSourceObject.Load(4, &cn.sourceManip)
+ stateSourceObject.Load(5, &cn.destinationManip)
+ stateSourceObject.Load(6, &cn.tcb)
+ stateSourceObject.LoadValue(7, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) })
}
func (ct *ConnTrack) StateTypeName() string {