diff options
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 175 |
1 files changed, 93 insertions, 82 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index c489506bb..1c6060b70 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -119,22 +119,24 @@ type conn struct { // // +checklocks:mu destinationManip bool + + stateMu sync.RWMutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection. // - // +checklocks:mu + // +checklocks:stateMu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. // - // +checklocks:mu + // +checklocks:stateMu lastUsed tcpip.MonotonicTime } // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { - cn.mu.RLock() - defer cn.mu.RUnlock() + cn.stateMu.RLock() + defer cn.stateMu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -147,7 +149,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { // update the connection tracking state. // -// +checklocks:cn.mu +// +checklocks:cn.stateMu func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) { if pkt.TransportProtocolNumber != header.TCPProtocolNumber { return @@ -209,17 +211,41 @@ type bucket struct { tuples tupleList } -func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) { - switch pkt.tuple.id().transProto { +// A netAndTransHeadersFunc returns the network and transport headers found +// in an ICMP payload. The transport layer's payload will not be returned. +// +// May panic if the packet does not hold the transport header. +type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte) + +func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv4(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the buffer is smaller than + // the total length specified in the IPv4 header. + transHdr := icmpPayload[netHdr.HeaderLength():] + return netHdr, transHdr[:minTransHdrLen] +} + +func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv6(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the IP payload is smaller than + // the payload length specified in the IPv6 header. + transHdr := icmpPayload[header.IPv6MinimumSize:] + return netHdr, transHdr[:minTransHdrLen] +} + +func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) { + switch transProto { case header.TCPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.TCP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize) + return netHeader, header.TCP(transHeaderBytes), true } case header.UDPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.UDP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize) + return netHeader, header.UDP(transHeaderBytes), true } } return nil, nil, false @@ -246,7 +272,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic("should have dropped packets with IPv4 options") } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok { return netHdr, transHdr, true, true } case header.ICMPv6ProtocolNumber: @@ -264,7 +290,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto)) } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok { return netHdr, transHdr, true, true } } @@ -283,34 +309,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro } } -func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { - switch transProto { - case header.TCPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.TCP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } - case header.UDPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.UDP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } +func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok { + return tupleID{ + srcAddr: netHdr.DestinationAddress(), + srcPort: transHdr.DestinationPort(), + dstAddr: netHdr.SourceAddress(), + dstPort: transHdr.SourcePort(), + transProto: transProto, + netProto: netProto, + }, true } return tupleID{}, false @@ -349,7 +357,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { return tupleID{}, false, false } - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { return tid, true, true } case header.ICMPv6ProtocolNumber: @@ -370,7 +378,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { } // TODO(https://gvisor.dev/issue/6789): Handle extension headers. - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { return tid, true, true } } @@ -601,14 +609,17 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { // packets are fragmented. reply := pkt.tuple.reply - tid, performManip := func() (tupleID, bool) { - cn.mu.Lock() - defer cn.mu.Unlock() - // Mark the connection as having been used recently so it isn't reaped. - cn.lastUsed = cn.ct.clock.NowMonotonic() - // Update connection state. - cn.updateLocked(pkt, reply) + cn.stateMu.Lock() + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = cn.ct.clock.NowMonotonic() + // Update connection state. + cn.updateLocked(pkt, reply) + cn.stateMu.Unlock() + + tid, performManip := func() (tupleID, bool) { + cn.mu.RLock() + defer cn.mu.RUnlock() var tuple *tuple if reply { @@ -730,9 +741,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused deletes timed out entries from the conntrack map. The rules for // reaping are: -// - Most reaping occurs in connFor, which is called on each packet. connFor -// cleans up the bucket the packet's connection maps to. Thus calls to -// reapUnused should be fast. // - Each call to reapUnused traverses a fraction of the conntrack table. // Specifically, it traverses len(ct.buckets)/fractionPerReaping. // - After reaping, reapUnused decides when it should next run based on the @@ -799,45 +807,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // Precondition: ct.mu is read locked and bkt.mu is write locked. // +checklocksread:ct.mu // +checklocks:bkt.mu -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { - if !tuple.conn.timedOut(now) { +func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { + if !reapingTuple.conn.timedOut(now) { return false } - // To maintain lock order, we can only reap both tuples if the reply appears - // later in the table. - replyBktID := ct.bucket(tuple.id().reply()) - tuple.conn.mu.RLock() - replyTupleInserted := tuple.conn.finalized - tuple.conn.mu.RUnlock() - if bktID > replyBktID && replyTupleInserted { - return true + var otherTuple *tuple + if reapingTuple.reply { + otherTuple = &reapingTuple.conn.original + } else { + otherTuple = &reapingTuple.conn.reply } - // Reap the reply. - if replyTupleInserted { - // Don't re-lock if both tuples are in the same bucket. - if bktID != replyBktID { - replyBkt := &ct.buckets[replyBktID] - replyBkt.mu.Lock() - removeConnFromBucket(replyBkt, tuple) - replyBkt.mu.Unlock() - } else { - removeConnFromBucket(bkt, tuple) - } + otherTupleBktID := ct.bucket(otherTuple.id()) + reapingTuple.conn.mu.RLock() + replyTupleInserted := reapingTuple.conn.finalized + reapingTuple.conn.mu.RUnlock() + + // To maintain lock order, we can only reap both tuples if the tuple for the + // other direction appears later in the table. + if bktID > otherTupleBktID && replyTupleInserted { + return true } - bkt.tuples.Remove(tuple) - return true -} + bkt.tuples.Remove(reapingTuple) + + if !replyTupleInserted { + // The other tuple is the reply which has not yet been inserted. + return true + } -// +checklocks:b.mu -func removeConnFromBucket(b *bucket, tuple *tuple) { - if tuple.reply { - b.tuples.Remove(&tuple.conn.original) + // Reap the other connection. + if bktID == otherTupleBktID { + // Don't re-lock if both tuples are in the same bucket. + bkt.tuples.Remove(otherTuple) } else { - b.tuples.Remove(&tuple.conn.reply) + otherTupleBkt := &ct.buckets[otherTupleBktID] + otherTupleBkt.mu.Lock() + otherTupleBkt.tuples.Remove(otherTuple) + otherTupleBkt.mu.Unlock() } + + return true } func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { |