summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/conntrack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r--pkg/tcpip/stack/conntrack.go175
1 files changed, 93 insertions, 82 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index c489506bb..1c6060b70 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -119,22 +119,24 @@ type conn struct {
//
// +checklocks:mu
destinationManip bool
+
+ stateMu sync.RWMutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
// of tcp connection.
//
- // +checklocks:mu
+ // +checklocks:stateMu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection.
//
- // +checklocks:mu
+ // +checklocks:stateMu
lastUsed tcpip.MonotonicTime
}
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
- cn.mu.RLock()
- defer cn.mu.RUnlock()
+ cn.stateMu.RLock()
+ defer cn.stateMu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -147,7 +149,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
// update the connection tracking state.
//
-// +checklocks:cn.mu
+// +checklocks:cn.stateMu
func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) {
if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
return
@@ -209,17 +211,41 @@ type bucket struct {
tuples tupleList
}
-func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) {
- switch pkt.tuple.id().transProto {
+// A netAndTransHeadersFunc returns the network and transport headers found
+// in an ICMP payload. The transport layer's payload will not be returned.
+//
+// May panic if the packet does not hold the transport header.
+type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
+
+func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv4(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the buffer is smaller than
+ // the total length specified in the IPv4 header.
+ transHdr := icmpPayload[netHdr.HeaderLength():]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv6(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the IP payload is smaller than
+ // the payload length specified in the IPv6 header.
+ transHdr := icmpPayload[header.IPv6MinimumSize:]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
+ switch transProto {
case header.TCPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.TCP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
+ return netHeader, header.TCP(transHeaderBytes), true
}
case header.UDPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.UDP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
+ return netHeader, header.UDP(transHeaderBytes), true
}
}
return nil, nil, false
@@ -246,7 +272,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic("should have dropped packets with IPv4 options")
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok {
return netHdr, transHdr, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -264,7 +290,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
return netHdr, transHdr, true, true
}
}
@@ -283,34 +309,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro
}
}
-func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
- switch transProto {
- case header.TCPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.TCP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
- case header.UDPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.UDP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
+func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }, true
}
return tupleID{}, false
@@ -349,7 +357,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
return tupleID{}, false, false
}
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
return tid, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -370,7 +378,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
}
// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
return tid, true, true
}
}
@@ -601,14 +609,17 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
// packets are fragmented.
reply := pkt.tuple.reply
- tid, performManip := func() (tupleID, bool) {
- cn.mu.Lock()
- defer cn.mu.Unlock()
- // Mark the connection as having been used recently so it isn't reaped.
- cn.lastUsed = cn.ct.clock.NowMonotonic()
- // Update connection state.
- cn.updateLocked(pkt, reply)
+ cn.stateMu.Lock()
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = cn.ct.clock.NowMonotonic()
+ // Update connection state.
+ cn.updateLocked(pkt, reply)
+ cn.stateMu.Unlock()
+
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
var tuple *tuple
if reply {
@@ -730,9 +741,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused deletes timed out entries from the conntrack map. The rules for
// reaping are:
-// - Most reaping occurs in connFor, which is called on each packet. connFor
-// cleans up the bucket the packet's connection maps to. Thus calls to
-// reapUnused should be fast.
// - Each call to reapUnused traverses a fraction of the conntrack table.
// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
// - After reaping, reapUnused decides when it should next run based on the
@@ -799,45 +807,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// Precondition: ct.mu is read locked and bkt.mu is write locked.
// +checklocksread:ct.mu
// +checklocks:bkt.mu
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
- if !tuple.conn.timedOut(now) {
+func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
+ if !reapingTuple.conn.timedOut(now) {
return false
}
- // To maintain lock order, we can only reap both tuples if the reply appears
- // later in the table.
- replyBktID := ct.bucket(tuple.id().reply())
- tuple.conn.mu.RLock()
- replyTupleInserted := tuple.conn.finalized
- tuple.conn.mu.RUnlock()
- if bktID > replyBktID && replyTupleInserted {
- return true
+ var otherTuple *tuple
+ if reapingTuple.reply {
+ otherTuple = &reapingTuple.conn.original
+ } else {
+ otherTuple = &reapingTuple.conn.reply
}
- // Reap the reply.
- if replyTupleInserted {
- // Don't re-lock if both tuples are in the same bucket.
- if bktID != replyBktID {
- replyBkt := &ct.buckets[replyBktID]
- replyBkt.mu.Lock()
- removeConnFromBucket(replyBkt, tuple)
- replyBkt.mu.Unlock()
- } else {
- removeConnFromBucket(bkt, tuple)
- }
+ otherTupleBktID := ct.bucket(otherTuple.id())
+ reapingTuple.conn.mu.RLock()
+ replyTupleInserted := reapingTuple.conn.finalized
+ reapingTuple.conn.mu.RUnlock()
+
+ // To maintain lock order, we can only reap both tuples if the tuple for the
+ // other direction appears later in the table.
+ if bktID > otherTupleBktID && replyTupleInserted {
+ return true
}
- bkt.tuples.Remove(tuple)
- return true
-}
+ bkt.tuples.Remove(reapingTuple)
+
+ if !replyTupleInserted {
+ // The other tuple is the reply which has not yet been inserted.
+ return true
+ }
-// +checklocks:b.mu
-func removeConnFromBucket(b *bucket, tuple *tuple) {
- if tuple.reply {
- b.tuples.Remove(&tuple.conn.original)
+ // Reap the other connection.
+ if bktID == otherTupleBktID {
+ // Don't re-lock if both tuples are in the same bucket.
+ bkt.tuples.Remove(otherTuple)
} else {
- b.tuples.Remove(&tuple.conn.reply)
+ otherTupleBkt := &ct.buckets[otherTupleBktID]
+ otherTupleBkt.mu.Lock()
+ otherTupleBkt.tuples.Remove(otherTuple)
+ otherTupleBkt.mu.Unlock()
}
+
+ return true
}
func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {