summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2019-02-27 14:30:20 -0800
committerShentubot <shentubot@google.com>2019-02-27 14:31:21 -0800
commit121db29a93c651b8b62e8701bb0f16c231b08257 (patch)
treec6b235b72340de44c8ee0e9398d337a8cc6c30f3 /pkg/tcpip/stack
parent6df212b831dcc3350b7677423ec7835ed40b3f22 (diff)
Ping support via IPv4 raw sockets.
Broadly, this change: * Enables sockets to be created via `socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)`. * Passes the network-layer (IP) header up the stack to the transport endpoint, which can pass it up to the socket layer. This allows a raw socket to return the entire IP packet to users. * Adds functions to stack.TransportProtocol, stack.Stack, stack.transportDemuxer that enable incoming packets to be delivered to raw endpoints. New raw sockets of other protocols (not ICMP) just need to register with the stack. * Enables ping.endpoint to return IP headers when created via SOCK_RAW. PiperOrigin-RevId: 235993280 Change-Id: I60ed994f5ff18b2cbd79f063a7fdf15d093d845a
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go8
-rw-r--r--pkg/tcpip/stack/registration.go10
-rw-r--r--pkg/tcpip/stack/stack.go52
-rw-r--r--pkg/tcpip/stack/stack_test.go2
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go78
-rw-r--r--pkg/tcpip/stack/transport_test.go6
6 files changed, 137 insertions, 19 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 2278fbf65..79f845225 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -505,7 +505,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -525,16 +525,16 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, vv, id) {
+ if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
- if n.stack.demux.deliverPacket(r, protocol, vv, id) {
+ if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, vv) {
+ if state.defaultHandler(r, id, netHeader, vv) {
return
}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5accffa1b..62acd5919 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -64,7 +64,7 @@ const (
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+ HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
@@ -80,6 +80,9 @@ type TransportProtocol interface {
// NewEndpoint creates a new endpoint of the transport protocol.
NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ // NewRawEndpoint creates a new raw endpoint of the transport protocol.
+ NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
// than this targeted at this protocol.
@@ -113,8 +116,9 @@ type TransportProtocol interface {
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
- // transport protocol endpoint.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView)
+ // transport protocol endpoint. It also returns the network layer
+ // header for the enpoint to inspect or pass up the stack.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 252c79317..797489ad9 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -48,7 +48,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool
+ defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -437,7 +437,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -499,6 +499,18 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return t.proto.NewEndpoint(s, network, waiterQueue)
}
+// NewRawEndpoint creates a new raw transport layer endpoint of the given
+// protocol. Raw endpoints receive all traffic for a given protocol regardless
+// of address.
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewRawEndpoint(s, network, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
@@ -934,6 +946,42 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
}
}
+// RegisterRawTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided protocol will
+// be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+}
+
+// UnregisterRawTransportEndpoint removes the endpoint for the protocol from
+// the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ }
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 163fadded..28743f3d5 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -97,7 +97,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index c18208dc0..9ab314188 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,8 +32,12 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
mu sync.RWMutex
endpoints map[TransportEndpointID]TransportEndpoint
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []TransportEndpoint
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -67,7 +71,9 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
+ d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ }
}
}
@@ -131,22 +137,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
// If this is a broadcast datagram, deliver the datagram to all endpoints
// managed by ep.
if id.LocalAddress == header.IPv4Broadcast {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
+ endpoint.HandlePacket(r, id, netHeader, vv)
break
}
vvCopy := buffer.NewView(vv.Size())
copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView())
}
} else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv)
}
}
@@ -250,7 +256,7 @@ var loopbackSubnet = func() tcpip.Subnet {
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if it
// found one or more endpoints, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -276,10 +282,21 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multipe raw endpoints, they all
+ // receive the packet.
+ found := false
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ found = true
+ }
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
+ if len(destEps) == 0 && !found {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
@@ -289,7 +306,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.HandlePacket(r, id, netHeader, vv)
}
return true
@@ -349,3 +366,48 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
return nil
}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil {
+ d.unregisterRawEndpoint(netProtos[:i], protocol, ep)
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given protocol
+// such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
+ }
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index da460db77..3347b5599 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -214,6 +214,10 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto), nil
}
+func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
func (*fakeTransportProtocol) MinimumPacketSize() int {
return fakeTransHeaderLen
}