From 9b4d3280e172063a6563d9e72a75b500442ed9b9 Mon Sep 17 00:00:00 2001
From: Kevin Krakauer <krakauer@google.com>
Date: Fri, 12 Jul 2019 18:08:03 -0700
Subject: Add IPPROTO_RAW, which allows raw sockets to write IP headers.

iptables also relies on IPPROTO_RAW in a way. It opens such a socket to
manipulate the kernel's tables, but it doesn't actually use any of the
functionality. Blegh.

PiperOrigin-RevId: 257903078
---
 pkg/tcpip/stack/registration.go | 20 ++++++++++++++++++++
 pkg/tcpip/stack/route.go        | 12 ++++++++++++
 pkg/tcpip/stack/stack.go        | 10 +++++++++-
 pkg/tcpip/stack/stack_test.go   |  4 ++++
 4 files changed, 45 insertions(+), 1 deletion(-)

(limited to 'pkg/tcpip/stack')

diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0ecaa0833..462265281 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -174,6 +174,10 @@ type NetworkEndpoint interface {
 	// protocol.
 	WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
 
+	// WriteHeaderIncludedPacket writes a packet that includes a network
+	// header to the given destination address.
+	WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
+
 	// ID returns the network protocol endpoint ID.
 	ID() *NetworkEndpointID
 
@@ -357,10 +361,19 @@ type TransportProtocolFactory func() TransportProtocol
 // instantiate network protocols.
 type NetworkProtocolFactory func() NetworkProtocol
 
+// UnassociatedEndpointFactory produces endpoints that are not associated with a
+// transport protocol and can write arbitrary packets that include the IP header.
+type UnassociatedEndpointFactory interface {
+	// NewUnassociatedRawEndpoint returns such an endpoint, or an error.
+	NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+}
+
 var (
 	transportProtocols = make(map[string]TransportProtocolFactory)
 	networkProtocols   = make(map[string]NetworkProtocolFactory)
 
+	unassociatedFactory UnassociatedEndpointFactory
+
 	linkEPMu           sync.RWMutex
 	nextLinkEndpointID tcpip.LinkEndpointID = 1
 	linkEndpoints                           = make(map[tcpip.LinkEndpointID]LinkEndpoint)
@@ -380,6 +393,13 @@ func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
 	networkProtocols[name] = p
 }
 
+// RegisterUnassociatedFactory registers the factory used to produce endpoints
+// not associated with any particular transport protocol. This function is
+// intended to be called by init() functions of the protocols.
+func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
+	unassociatedFactory = f // Overwrites any previously registered factory.
+}
+
 // RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
 // ID that can be used to refer to it.
 func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 36d7b6ac7..391ab4344 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -163,6 +163,18 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
 	return err
 }
 
+// WriteHeaderIncludedPacket sends payload, which must already contain a network
+// header, through r's endpoint and updates the send statistics.
+func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+	if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
+		r.Stats().IP.OutgoingPacketErrors.Increment() // Record the failed write.
+		return err
+	}
+	r.ref.nic.stats.Tx.Packets.Increment() // Only successful writes count as Tx.
+	r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size()))
+	return nil
+}
+
 // DefaultTTL returns the default TTL of the underlying network endpoint.
 func (r *Route) DefaultTTL() uint8 {
 	return r.ref.ep.DefaultTTL()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 2d7f56ca9..3e8fb2a6c 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -340,6 +340,8 @@ type Stack struct {
 	networkProtocols   map[tcpip.NetworkProtocolNumber]NetworkProtocol
 	linkAddrResolvers  map[tcpip.NetworkProtocolNumber]LinkAddressResolver
 
+	unassociatedFactory UnassociatedEndpointFactory
+
 	demux *transportDemuxer
 
 	stats tcpip.Stats
@@ -442,6 +444,8 @@ func New(network []string, transport []string, opts Options) *Stack {
 		}
 	}
 
+	s.unassociatedFactory = unassociatedFactory
+
 	// Create the global transport demuxer.
 	s.demux = newTransportDemuxer(s)
 
@@ -574,11 +578,15 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
 // NewRawEndpoint creates a new raw transport layer endpoint of the given
 // protocol. Raw endpoints receive all traffic for a given protocol regardless
 // of address.
-func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
 	if !s.raw {
 		return nil, tcpip.ErrNotPermitted
 	}
 
+	if !associated {
+		return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
+	}
+
 	t, ok := s.transportProtocols[transport]
 	if !ok {
 		return nil, tcpip.ErrUnknownProtocol
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 69884af03..959071dbe 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -137,6 +137,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
 	return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
 }
 
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+	return tcpip.ErrNotSupported // The fake endpoint always rejects header-included writes.
+}
+
 func (*fakeNetworkEndpoint) Close() {}
 
 type fakeNetGoodOption bool
-- 
cgit v1.2.3