diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 65 |
2 files changed, 69 insertions, 2 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 51abe32a7..e92ec0c24 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1408,6 +1408,12 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } +// CheckRegisterTransportEndpoint checks if an endpoint can be registered with +// the stack transport dispatcher. +func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) +} + // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 118b449d5..b902c6ca9 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -221,6 +221,18 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + return nil + } + + return multiPortEp.singleCheckEndpoint(flags) +} + // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() @@ -289,6 +301,17 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +// checkEndpoint checks if an endpoint can be registered with the dispatcher. +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + for _, n := range netProtos { + if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { + return err + } + } + + return nil +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -380,7 +403,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() + bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. @@ -395,6 +418,22 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { + ep.mu.RLock() + defer ep.mu.RUnlock() + + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { + // If it was previously bound, we need to check if we can bind again. + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + return tcpip.ErrPortInUse + } + } + + return nil +} + // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() @@ -406,7 +445,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports ep.endpoints[len(ep.endpoints)-1] = nil ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] - ep.flags.DropRef(flags.Bits()) + ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask) break } } @@ -439,6 +478,28 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return tcpip.ErrUnknownProtocol + } + + eps.mu.RLock() + defer eps.mu.RUnlock() + + epsByNIC, ok := eps.endpoints[id] + if !ok { + return nil + } + + return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { |