diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/registration.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 |
4 files changed, 53 insertions, 49 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ff356ea22..f3cc849ec 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,13 +64,24 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) } +// RawTransportEndpoint is the interface that needs to be implemented by raw +// transport protocol endpoints. RawTransportEndpoints receive the entire +// packet - including the link, network, and transport headers - as delivered +// to netstack. +type RawTransportEndpoint interface { + // HandlePacket is called by the stack when new packets arrive to + // this transport endpoint. The packet contains all data from the link + // layer up. + HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) +} + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 15a268b10..a74c0a7a0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -955,11 +955,11 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip } // RegisterRawTransportEndpoint registers the given endpoint with the stack -// transport dispatcher. Received packets that match the provided protocol will -// be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { +// transport dispatcher. Received packets that match the provided transport +// protocol will be delivered to the given endpoint. +func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { if nicID == 0 { - return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } s.mu.RLock() @@ -970,14 +970,14 @@ func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpi return tcpip.ErrUnknownNICID } - return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + return nic.demux.registerRawEndpoint(netProto, transProto, ep) } -// UnregisterRawTransportEndpoint removes the endpoint for the protocol from -// the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { +// UnregisterRawTransportEndpoint removes the endpoint for the transport +// protocol from the stack transport dispatcher. +func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { if nicID == 0 { - s.demux.unregisterRawEndpoint(netProtos, protocol, ep) + s.demux.unregisterRawEndpoint(netProto, transProto, ep) return } @@ -986,7 +986,7 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tc nic := s.nics[nicID] if nic != nil { - nic.demux.unregisterRawEndpoint(netProtos, protocol, ep) + nic.demux.unregisterRawEndpoint(netProto, transProto, ep) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9ab314188..a8ac18e72 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "math/rand" "sync" @@ -37,7 +38,7 @@ type transportEndpoints struct { endpoints map[TransportEndpointID]TransportEndpoint // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. - rawEndpoints []TransportEndpoint + rawEndpoints []RawTransportEndpoint } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -60,8 +61,10 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then -// based on endpoints IDs. +// based on endpoints IDs. It should only be instantiated via +// newTransportDemuxer. type transportDemuxer struct { + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints } @@ -137,22 +140,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { // If this is a broadcast datagram, deliver the datagram to all endpoints // managed by ep. if id.LocalAddress == header.IPv4Broadcast { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, netHeader, vv) + endpoint.HandlePacket(r, id, vv) break } vvCopy := buffer.NewView(vv.Size()) copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } } else { - ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv) + ep.selectEndpoint(id).HandlePacket(r, id, vv) } } @@ -286,17 +289,17 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via // raw endpoint first. If there are multipe raw endpoints, they all // receive the packet. - found := false + foundRaw := false for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) - found = true + rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + foundRaw = true } eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 && !found { + if len(destEps) == 0 && !foundRaw { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -306,7 +309,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, netHeader, vv) + ep.HandlePacket(r, id, vv) } return true @@ -371,19 +374,8 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { - for i, n := range netProtos { - if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil { - d.unregisterRawEndpoint(netProtos[:i], protocol, ep) - return err - } - } - - return nil -} - -func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error { - eps, ok := d.protocol[protocolIDs{netProto, protocol}] +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } @@ -395,19 +387,20 @@ func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProto return nil } -// unregisterRawEndpoint unregisters the raw endpoint for the given protocol -// such that it won't receive any more packets. -func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { - for _, n := range netProtos { - if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.mu.Lock() - defer eps.mu.Unlock() - for i, rawEP := range eps.rawEndpoints { - if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return - } - } +// unregisterRawEndpoint unregisters the raw endpoint for the given transport +// protocol such that it won't receive any more packets. +func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto)) + } + + eps.mu.Lock() + defer eps.mu.Unlock() + for i, rawEP := range eps.rawEndpoints { + if rawEP == ep { + eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) + return } } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index dfd31557a..0c2589083 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { |