diff options
author | Kevin Krakauer <krakauer@google.com> | 2019-04-02 11:12:29 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-02 11:13:49 -0700 |
commit | 52a51a8e20b3e5c28eb1e66bd57203216cf644c5 (patch) | |
tree | 4d98f67dc4abb15e6de568b2af12a3e12f41aa99 /pkg/tcpip/stack | |
parent | 1df3fa69977477092efa65a8de407bd6f0f88db4 (diff) |
Add a raw socket transport endpoint and use it for raw ICMP sockets.
Having raw socket code together will make it easier to add support for other raw
network protocols. Currently, only ICMP uses the raw endpoint. However, adding
support for other protocols such as UDP shouldn't be much more difficult than
adding a few switch cases.
PiperOrigin-RevId: 241564875
Change-Id: I77e03adafe4ce0fd29ba2d5dfdc547d2ae8f25bf
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/registration.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 |
4 files changed, 53 insertions, 49 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ff356ea22..f3cc849ec 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,13 +64,24 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) } +// RawTransportEndpoint is the interface that needs to be implemented by raw +// transport protocol endpoints. RawTransportEndpoints receive the entire +// packet - including the link, network, and transport headers - as delivered +// to netstack. +type RawTransportEndpoint interface { + // HandlePacket is called by the stack when new packets arrive to + // this transport endpoint. The packet contains all data from the link + // layer up. + HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) +} + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 15a268b10..a74c0a7a0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -955,11 +955,11 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip } // RegisterRawTransportEndpoint registers the given endpoint with the stack -// transport dispatcher. Received packets that match the provided protocol will -// be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { +// transport dispatcher. Received packets that match the provided transport +// protocol will be delivered to the given endpoint. +func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { if nicID == 0 { - return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } s.mu.RLock() @@ -970,14 +970,14 @@ func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpi return tcpip.ErrUnknownNICID } - return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + return nic.demux.registerRawEndpoint(netProto, transProto, ep) } -// UnregisterRawTransportEndpoint removes the endpoint for the protocol from -// the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { +// UnregisterRawTransportEndpoint removes the endpoint for the transport +// protocol from the stack transport dispatcher. +func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { if nicID == 0 { - s.demux.unregisterRawEndpoint(netProtos, protocol, ep) + s.demux.unregisterRawEndpoint(netProto, transProto, ep) return } @@ -986,7 +986,7 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tc nic := s.nics[nicID] if nic != nil { - nic.demux.unregisterRawEndpoint(netProtos, protocol, ep) + nic.demux.unregisterRawEndpoint(netProto, transProto, ep) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9ab314188..a8ac18e72 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "math/rand" "sync" @@ -37,7 +38,7 @@ type transportEndpoints struct { endpoints map[TransportEndpointID]TransportEndpoint // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. - rawEndpoints []TransportEndpoint + rawEndpoints []RawTransportEndpoint } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -60,8 +61,10 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then -// based on endpoints IDs. +// based on endpoints IDs. It should only be instantiated via +// newTransportDemuxer. type transportDemuxer struct { + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints } @@ -137,22 +140,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { // If this is a broadcast datagram, deliver the datagram to all endpoints // managed by ep. if id.LocalAddress == header.IPv4Broadcast { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, netHeader, vv) + endpoint.HandlePacket(r, id, vv) break } vvCopy := buffer.NewView(vv.Size()) copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } } else { - ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv) + ep.selectEndpoint(id).HandlePacket(r, id, vv) } } @@ -286,17 +289,17 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via // raw endpoint first. If there are multipe raw endpoints, they all // receive the packet. - found := false + foundRaw := false for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) - found = true + rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + foundRaw = true } eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 && !found { + if len(destEps) == 0 && !foundRaw { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -306,7 +309,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, netHeader, vv) + ep.HandlePacket(r, id, vv) } return true @@ -371,19 +374,8 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { - for i, n := range netProtos { - if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil { - d.unregisterRawEndpoint(netProtos[:i], protocol, ep) - return err - } - } - - return nil -} - -func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error { - eps, ok := d.protocol[protocolIDs{netProto, protocol}] +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } @@ -395,19 +387,20 @@ func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProto return nil } -// unregisterRawEndpoint unregisters the raw endpoint for the given protocol -// such that it won't receive any more packets. -func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { - for _, n := range netProtos { - if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.mu.Lock() - defer eps.mu.Unlock() - for i, rawEP := range eps.rawEndpoints { - if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return - } - } +// unregisterRawEndpoint unregisters the raw endpoint for the given transport +// protocol such that it won't receive any more packets. +func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto)) + } + + eps.mu.Lock() + defer eps.mu.Unlock() + for i, rawEP := range eps.rawEndpoints { + if rawEP == ep { + eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) + return } } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index dfd31557a..0c2589083 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { |