summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/transport_demuxer.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go33
1 files changed, 31 insertions, 2 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 594570216..cb805522b 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -103,7 +103,6 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
return
}
-
// multiPortEndpoints are guaranteed to have at least one element.
selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
@@ -507,10 +506,40 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
if ep, ok := eps.endpoints[nid]; ok {
matchedEPs = append(matchedEPs, ep)
}
-
return matchedEPs
}
+// findTransportEndpoint find a single endpoint that most closely matches the provided id.
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ epsByNic := d.findEndpointLocked(eps, id)
+ // Fail if we didn't find one.
+ if epsByNic == nil {
+ eps.mu.RUnlock()
+ return nil
+ }
+
+ epsByNic.mu.RLock()
+ eps.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return nil
+ }
+ }
+
+ ep := selectEndpoint(id, mpep, epsByNic.seed)
+ epsByNic.mu.RUnlock()
+ return ep
+}
+
// findEndpointLocked returns the endpoint that most closely matches the given
// id.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {