summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/transport_demuxer.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go45
1 files changed, 37 insertions, 8 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cb805522b..67c21be42 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -389,8 +389,8 @@ var loopbackSubnet = func() tcpip.Subnet {
}()
// deliverPacket attempts to find one or more matching transport endpoints, and
-// then, if matches are found, delivers the packet to them. Returns true if it
-// found one or more endpoints, false otherwise.
+// then, if matches are found, delivers the packet to them. Returns true if
+// the packet no longer needs to be handled.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
@@ -400,13 +400,38 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RLock()
// Determine which transport endpoint or endpoints to deliver this packet to.
- // If the packet is a broadcast or multicast, then find all matching
- // transport endpoints.
+ // If the packet is a UDP broadcast or multicast, then find all matching
+ // transport endpoints. If the packet is a TCP packet with a non-unicast
+ // source or destination address, then do nothing further and instruct
+ // the caller to do the same.
var destEps []*endpointsByNic
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- } else if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ switch protocol {
+ case header.UDPProtocolNumber:
+ if isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, id)
+ break
+ }
+
+ if ep := d.findEndpointLocked(eps, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
+
+ case header.TCPProtocolNumber:
+ if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single
+ // source and a single destination; the addresses must
+ // be unicast.
+ eps.mu.RUnlock()
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
+ }
+
+ fallthrough
+
+ default:
+ if ep := d.findEndpointLocked(eps, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
}
eps.mu.RUnlock()
@@ -587,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
func isMulticastOrBroadcast(addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
}
+
+func isUnicast(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+}