diff options
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 112 |
1 files changed, 59 insertions, 53 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index ff1845bfb..d4c0359e8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -409,61 +409,45 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - eps.mu.RLock() - - // Determine which transport endpoint or endpoints to deliver this packet to. // If the packet is a UDP broadcast or multicast, then find all matching - // transport endpoints. If the packet is a TCP packet with a non-unicast - // source or destination address, then do nothing further and instruct - // the caller to do the same. - var destEps []*endpointsByNic - switch protocol { - case header.UDPProtocolNumber: - if isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, id) - break - } - - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) + // transport endpoints. + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + eps.mu.RLock() + destEPs := d.findAllEndpointsLocked(eps, id) + eps.mu.RUnlock() + // Fail if we didn't find at least one matching transport endpoint. + if len(destEPs) == 0 { + r.Stats().UDP.UnknownPortErrors.Increment() + return false } - - case header.TCPProtocolNumber: - if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { - // TCP can only be used to communicate between a single - // source and a single destination; the addresses must - // be unicast. - eps.mu.RUnlock() - r.Stats().TCP.InvalidSegmentsReceived.Increment() - return true + // handlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEPs[:len(destEPs)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + return true + } - fallthrough - - default: - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) - } + // If the packet is a TCP packet with a non-unicast source or destination + // address, then do nothing further and instruct the caller to do the same. + if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single source and a + // single destination; the addresses must be unicast. + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true } + eps.mu.RLock() + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() - - // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { - // UDP packet could not be delivered to an unknown destination port. + if ep == nil { if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() } return false } - - // HandlePacket takes ownership of pkt, so each endpoint needs its own - // copy except for the final one. - for _, ep := range destEps[:len(destEps)-1] { - ep.handlePacket(r, id, pkt.Clone()) - } - destEps[len(destEps)-1].handlePacket(r, id, pkt) - + ep.handlePacket(r, id, pkt) return true } @@ -515,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic +// iterEndpointsLocked yields all endpointsByNic in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with the id minus the local address. @@ -527,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with the id minus the remote part. @@ -535,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } +} + +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic + d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) return matchedEPs } @@ -580,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // findEndpointLocked returns the endpoint that most closely matches the given // id. func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { - return matchedEPs[0] - } - return nil + var matchedEP *endpointsByNic + d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { + matchedEP = ep + return false + }) + return matchedEP } // registerRawEndpoint registers the given endpoint with the dispatcher such |