summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/transport_demuxer.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go63
1 files changed, 33 insertions, 30 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 35e5b1a2e..f183ec6e4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isInboundMulticastOrBroadcast(r) {
- mpep.handlePacketAll(r, id, pkt)
+ if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
+ mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
- queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
}
- transEP.HandlePacket(r, id, pkt)
+ transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
+ stack *Stack
+
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
@@ -262,11 +264,12 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+ QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
+ stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
@@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
} else {
- endpoint.HandlePacket(r, id, pkt.Clone())
+ endpoint.HandlePacket(id, pkt.Clone())
}
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
- endpoint.HandlePacket(r, id, pkt)
+ endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket takes ownership of pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(r, id, pkt.Clone())
+ ep.handlePacket(id, pkt.Clone())
}
- destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
@@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
- if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
+ if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
- // single destination; the addresses must be unicast.
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ // single destination; the addresses must be unicast.e
+ d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
@@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
- ep.handlePacket(r, id, pkt)
+ ep.handlePacket(id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
@@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, pkt)
+ rawEP.HandlePacket(pkt.Clone())
foundRaw = true
}
eps.mu.RUnlock()
@@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
-func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
@@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isInboundMulticastOrBroadcast(r *Route) bool {
- return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
+func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
+ return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {