diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 78 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 81 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 6 |
18 files changed, 247 insertions, 55 deletions
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 5c1e88e56..97a43aece 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -94,7 +94,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index f82dc098f..ea8392c98 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -55,7 +55,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv4MinimumSize { return @@ -67,19 +67,22 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { if len(v) < header.ICMPv4EchoMinimumSize { return } - vv.TrimFront(header.ICMPv4MinimumSize) - req := echoRequest{r: r.Clone(), v: vv.ToView()} + echoPayload := vv.ToView() + echoPayload.TrimFront(header.ICMPv4MinimumSize) + req := echoRequest{r: r.Clone(), v: echoPayload} select { case e.echoRequests <- req: default: req.r.Release() } + // It's possible that a raw socket expects to receive this. + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4EchoReply: if len(v) < header.ICMPv4EchoMinimumSize { return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4DstUnreachable: if len(v) < header.ICMPv4DstUnreachableMinimumSize { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 0c41519df..bfc3c08fa 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -131,7 +131,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) + headerView := vv.First() + h := header.IPv4(headerView) if !h.IsValid(vv.Size()) { return } @@ -153,11 +154,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - e.handleICMP(r, vv) + headerView.CapLength(hlen) + e.handleICMP(r, headerView, vv) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, vv) + e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 14107443b..5a3c17768 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv6MinimumSize { return @@ -148,7 +148,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { if len(v) < header.ICMPv6EchoMinimumSize { return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4d0b6ee9c..5f68ef7d5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -102,7 +102,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) + headerView := vv.First() + h := header.IPv6(headerView) if !h.IsValid(vv.Size()) { return } @@ -112,12 +113,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, vv) + e.handleICMP(r, headerView, vv) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, vv) + e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 2278fbf65..79f845225 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -505,7 +505,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -525,16 +525,16 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, vv, id) { + if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } - if n.stack.demux.deliverPacket(r, protocol, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, vv) { + if state.defaultHandler(r, id, netHeader, vv) { return } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5accffa1b..62acd5919 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,7 +64,7 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. @@ -80,6 +80,9 @@ type TransportProtocol interface { // NewEndpoint creates a new endpoint of the transport protocol. NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + // NewRawEndpoint creates a new raw endpoint of the transport protocol. + NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller // than this targeted at this protocol. @@ -113,8 +116,9 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) + // transport protocol endpoint. It also returns the network layer + // header for the enpoint to inspect or pass up the stack. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 252c79317..797489ad9 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -48,7 +48,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -437,7 +437,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -499,6 +499,18 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return t.proto.NewEndpoint(s, network, waiterQueue) } +// NewRawEndpoint creates a new raw transport layer endpoint of the given +// protocol. Raw endpoints receive all traffic for a given protocol regardless +// of address. +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + t, ok := s.transportProtocols[transport] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + return t.proto.NewRawEndpoint(s, network, waiterQueue) +} + // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error { @@ -934,6 +946,42 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip } } +// RegisterRawTransportEndpoint registers the given endpoint with the stack +// transport dispatcher. Received packets that match the provided protocol will +// be delivered to the given endpoint. +func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { + if nicID == 0 { + return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) +} + +// UnregisterRawTransportEndpoint removes the endpoint for the protocol from +// the stack transport dispatcher. +func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { + if nicID == 0 { + s.demux.unregisterRawEndpoint(netProtos, protocol, ep) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic != nil { + nic.demux.unregisterRawEndpoint(netProtos, protocol, ep) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 163fadded..28743f3d5 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -97,7 +97,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c18208dc0..9ab314188 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,8 +32,12 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { + // mu protects all fields of the transportEndpoints. mu sync.RWMutex endpoints map[TransportEndpointID]TransportEndpoint + // rawEndpoints contains endpoints for raw sockets, which receive all + // traffic of a given protocol regardless of port. + rawEndpoints []TransportEndpoint } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -67,7 +71,9 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)} + d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + endpoints: make(map[TransportEndpointID]TransportEndpoint), + } } } @@ -131,22 +137,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { // If this is a broadcast datagram, deliver the datagram to all endpoints // managed by ep. if id.LocalAddress == header.IPv4Broadcast { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, netHeader, vv) break } vvCopy := buffer.NewView(vv.Size()) copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView()) } } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv) } } @@ -250,7 +256,7 @@ var loopbackSubnet = func() tcpip.Subnet { // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if it // found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -276,10 +282,21 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } + + // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via + // raw endpoint first. If there are multipe raw endpoints, they all + // receive the packet. + found := false + for _, rawEP := range eps.rawEndpoints { + // Each endpoint gets its own copy of the packet for the sake + // of save/restore. + rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + found = true + } eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { + if len(destEps) == 0 && !found { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -289,7 +306,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.HandlePacket(r, id, netHeader, vv) } return true @@ -349,3 +366,48 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer return nil } + +// registerRawEndpoint registers the given endpoint with the dispatcher such +// that packets of the appropriate protocol are delivered to it. A single +// packet can be sent to one or more raw endpoints along with a non-raw +// endpoint. +func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { + for i, n := range netProtos { + if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil { + d.unregisterRawEndpoint(netProtos[:i], protocol, ep) + return err + } + } + + return nil +} + +func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return nil + } + + eps.mu.Lock() + defer eps.mu.Unlock() + eps.rawEndpoints = append(eps.rawEndpoints, ep) + + return nil +} + +// unregisterRawEndpoint unregisters the raw endpoint for the given protocol +// such that it won't receive any more packets. +func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { + for _, n := range netProtos { + if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { + eps.mu.Lock() + defer eps.mu.Unlock() + for i, rawEP := range eps.rawEndpoints { + if rawEP == ep { + eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) + return + } + } + } + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index da460db77..3347b5599 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -214,6 +214,10 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto), nil } +func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d87bfe048..b3b7a1d0e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -46,17 +46,23 @@ const ( stateClosed ) -// endpoint represents an ICMP (ping) endpoint. This struct serves as the -// interface between users of the endpoint and the protocol implementation; it -// is legal to have concurrent goroutines make calls into the endpoint, they -// are properly synchronized. +// endpoint represents an ICMP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// +stateify savable type endpoint struct { - // The following fields are initialized at creation time and do not - // change throughout the lifetime of the endpoint. + // The following fields are initialized at creation time and are + // immutable. stack *stack.Stack `state:"manual"` netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + // raw indicates whether the endpoint is intended for use by a raw + // socket, which returns the network layer header along with the + // payload. It is immutable. + raw bool // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -80,15 +86,26 @@ type endpoint struct { route stack.Route `state:"manual"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - return &endpoint{ +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) { + e := &endpoint{ stack: stack, netProto: netProto, transProto: transProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + raw: raw, + } + + // Raw endpoints must be immediately bound because they receive all + // ICMP traffic starting from when they're created via socket(). + if raw { + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return nil, err + } } + + return e, nil } // Close puts the endpoint in a closed state and frees all resources @@ -98,7 +115,11 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + if e.raw { + e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e) + } else { + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + } } // Close the receive list and drain it. @@ -285,10 +306,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = sendPing4(route, e.id.LocalPort, v) + err = e.send4(route, v) case header.IPv6ProtocolNumber: - err = sendPing6(route, e.id.LocalPort, v) + err = send6(route, e.id.LocalPort, v) } if err != nil { @@ -346,13 +367,19 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { + if e.raw { + hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength())) + return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + } + if len(data) < header.ICMPv4EchoMinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) @@ -371,7 +398,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } -func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -412,6 +439,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // TODO: We don't yet support connect on a raw socket. + if e.raw { + return tcpip.ErrNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -515,6 +547,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.raw { + err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false) + return stack.TransportEndpointID{}, err + } + if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -657,7 +694,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -675,9 +712,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + + if e.raw { + combinedVV := netHeader.ToVectorisedView() + combinedVV.Append(vv) + pkt.data = combinedVV.Clone(pkt.views[:]) + } else { + pkt.data = vv.Clone(pkt.views[:]) + } + e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + e.rcvBufSize += pkt.data.Size() pkt.timestamp = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9f0a2bf71..36b70988a 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -47,6 +47,7 @@ const ( ProtocolNumber6 = header.ICMPv6ProtocolNumber ) +// protocol implements stack.TransportProtocol. type protocol struct { number tcpip.TransportProtocolNumber } @@ -66,12 +67,22 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { panic(fmt.Sprint("unknown protocol number: ", p.number)) } -// NewEndpoint creates a new icmp endpoint. +// NewEndpoint creates a new icmp endpoint. It implements +// stack.TransportProtocol.NewEndpoint. func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue), nil + return newEndpoint(stack, netProto, p.number, waiterQueue, false) +} + +// NewRawEndpoint creates a new raw icmp endpoint. It implements +// stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return newEndpoint(stack, netProto, p.number, waiterQueue, true) } // MinimumPacketSize returns the minimum valid icmp packet size. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a8618bb4a..c48a27d8f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1441,7 +1441,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f90839e9..ca53a076f 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 753e1419e..639ad3fae 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -101,6 +101,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN return newEndpoint(stack, netProto, waiterQueue), nil } +// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + // MinimumPacketSize returns the minimum valid tcp packet size. func (*protocol) MinimumPacketSize() int { return header.TCPMinimumSize diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 05d35e526..44b9cdf6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -934,7 +934,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { // Get the header then trim it from the view. hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index b3fbed6e4..616a9f388 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -48,6 +48,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN return newEndpoint(stack, netProto, waiterQueue), nil } +// NewRawEndpoint creates a new raw UDP endpoint. Raw UDP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + // MinimumPacketSize returns the minimum valid udp packet size. func (*protocol) MinimumPacketSize() int { return header.UDPMinimumSize |