diff options
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 63 |
1 files changed, 33 insertions, 30 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { |