diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 78 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 6 |
6 files changed, 137 insertions, 19 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 2278fbf65..79f845225 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -505,7 +505,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -525,16 +525,16 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, vv, id) { + if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } - if n.stack.demux.deliverPacket(r, protocol, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, vv) { + if state.defaultHandler(r, id, netHeader, vv) { return } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5accffa1b..62acd5919 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,7 +64,7 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. @@ -80,6 +80,9 @@ type TransportProtocol interface { // NewEndpoint creates a new endpoint of the transport protocol. NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + // NewRawEndpoint creates a new raw endpoint of the transport protocol. + NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller // than this targeted at this protocol. @@ -113,8 +116,9 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) + // transport protocol endpoint. It also returns the network layer + // header for the enpoint to inspect or pass up the stack. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 252c79317..797489ad9 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -48,7 +48,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -437,7 +437,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -499,6 +499,18 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return t.proto.NewEndpoint(s, network, waiterQueue) } +// NewRawEndpoint creates a new raw transport layer endpoint of the given +// protocol. Raw endpoints receive all traffic for a given protocol regardless +// of address. +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + t, ok := s.transportProtocols[transport] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + return t.proto.NewRawEndpoint(s, network, waiterQueue) +} + // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error { @@ -934,6 +946,42 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip } } +// RegisterRawTransportEndpoint registers the given endpoint with the stack +// transport dispatcher. Received packets that match the provided protocol will +// be delivered to the given endpoint. +func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { + if nicID == 0 { + return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) +} + +// UnregisterRawTransportEndpoint removes the endpoint for the protocol from +// the stack transport dispatcher. +func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { + if nicID == 0 { + s.demux.unregisterRawEndpoint(netProtos, protocol, ep) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic != nil { + nic.demux.unregisterRawEndpoint(netProtos, protocol, ep) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 163fadded..28743f3d5 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -97,7 +97,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c18208dc0..9ab314188 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,8 +32,12 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { + // mu protects all fields of the transportEndpoints. mu sync.RWMutex endpoints map[TransportEndpointID]TransportEndpoint + // rawEndpoints contains endpoints for raw sockets, which receive all + // traffic of a given protocol regardless of port. + rawEndpoints []TransportEndpoint } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -67,7 +71,9 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)} + d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + endpoints: make(map[TransportEndpointID]TransportEndpoint), + } } } @@ -131,22 +137,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { // If this is a broadcast datagram, deliver the datagram to all endpoints // managed by ep. if id.LocalAddress == header.IPv4Broadcast { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, netHeader, vv) break } vvCopy := buffer.NewView(vv.Size()) copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView()) } } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv) } } @@ -250,7 +256,7 @@ var loopbackSubnet = func() tcpip.Subnet { // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if it // found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -276,10 +282,21 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } + + // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via + // raw endpoint first. If there are multipe raw endpoints, they all + // receive the packet. + found := false + for _, rawEP := range eps.rawEndpoints { + // Each endpoint gets its own copy of the packet for the sake + // of save/restore. + rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + found = true + } eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { + if len(destEps) == 0 && !found { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -289,7 +306,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.HandlePacket(r, id, netHeader, vv) } return true @@ -349,3 +366,48 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer return nil } + +// registerRawEndpoint registers the given endpoint with the dispatcher such +// that packets of the appropriate protocol are delivered to it. A single +// packet can be sent to one or more raw endpoints along with a non-raw +// endpoint. +func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { + for i, n := range netProtos { + if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil { + d.unregisterRawEndpoint(netProtos[:i], protocol, ep) + return err + } + } + + return nil +} + +func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return nil + } + + eps.mu.Lock() + defer eps.mu.Unlock() + eps.rawEndpoints = append(eps.rawEndpoints, ep) + + return nil +} + +// unregisterRawEndpoint unregisters the raw endpoint for the given protocol +// such that it won't receive any more packets. +func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { + for _, n := range netProtos { + if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { + eps.mu.Lock() + defer eps.mu.Unlock() + for i, rawEP := range eps.rawEndpoints { + if rawEP == ep { + eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) + return + } + } + } + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index da460db77..3347b5599 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -214,6 +214,10 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto), nil } +func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } |