summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/epsocket/BUILD1
-rw-r--r--pkg/sentry/socket/epsocket/provider.go28
-rw-r--r--pkg/tcpip/network/ip_test.go2
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go11
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go8
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go7
-rw-r--r--pkg/tcpip/stack/nic.go8
-rw-r--r--pkg/tcpip/stack/registration.go10
-rw-r--r--pkg/tcpip/stack/stack.go52
-rw-r--r--pkg/tcpip/stack/stack_test.go2
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go78
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go81
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go15
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go6
20 files changed, 272 insertions, 59 deletions
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
index 45e418db3..44bb97b5b 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -27,6 +27,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/kdefs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 0184d8e3e..0d9c2df24 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -18,8 +18,10 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -38,9 +40,9 @@ type provider struct {
netProto tcpip.NetworkProtocolNumber
}
-// GetTransportProtocol figures out transport protocol. Currently only TCP,
+// getTransportProtocol figures out transport protocol. Currently only TCP,
// UDP, and ICMP are supported.
-func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
switch stype {
case linux.SOCK_STREAM:
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
@@ -57,6 +59,18 @@ func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.Transpo
case syscall.IPPROTO_ICMPV6:
return header.ICMPv6ProtocolNumber, nil
}
+
+ case linux.SOCK_RAW:
+ // Raw sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ switch protocol {
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, nil
+ }
}
return 0, syserr.ErrInvalidArgument
}
@@ -76,14 +90,20 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int
}
// Figure out the transport protocol.
- transProto, err := GetTransportProtocol(stype, protocol)
+ transProto, err := getTransportProtocol(t, stype, protocol)
if err != nil {
return nil, err
}
// Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
wq := &waiter.Queue{}
- ep, e := eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ }
if e != nil {
return nil, syserr.TranslateNetstackError(e)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 5c1e88e56..97a43aece 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -94,7 +94,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index f82dc098f..ea8392c98 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -55,7 +55,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
-func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv4MinimumSize {
return
@@ -67,19 +67,22 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
if len(v) < header.ICMPv4EchoMinimumSize {
return
}
- vv.TrimFront(header.ICMPv4MinimumSize)
- req := echoRequest{r: r.Clone(), v: vv.ToView()}
+ echoPayload := vv.ToView()
+ echoPayload.TrimFront(header.ICMPv4MinimumSize)
+ req := echoRequest{r: r.Clone(), v: echoPayload}
select {
case e.echoRequests <- req:
default:
req.r.Release()
}
+ // It's possible that a raw socket expects to receive this.
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
case header.ICMPv4EchoReply:
if len(v) < header.ICMPv4EchoMinimumSize {
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
case header.ICMPv4DstUnreachable:
if len(v) < header.ICMPv4DstUnreachableMinimumSize {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 0c41519df..bfc3c08fa 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -131,7 +131,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- h := header.IPv4(vv.First())
+ headerView := vv.First()
+ h := header.IPv4(headerView)
if !h.IsValid(vv.Size()) {
return
}
@@ -153,11 +154,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- e.handleICMP(r, vv)
+ headerView.CapLength(hlen)
+ e.handleICMP(r, headerView, vv)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
}
// Close cleans up resources associated with the endpoint.
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 14107443b..5a3c17768 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
-func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv6MinimumSize {
return
@@ -148,7 +148,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
if len(v) < header.ICMPv6EchoMinimumSize {
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 4d0b6ee9c..5f68ef7d5 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -102,7 +102,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- h := header.IPv6(vv.First())
+ headerView := vv.First()
+ h := header.IPv6(headerView)
if !h.IsValid(vv.Size()) {
return
}
@@ -112,12 +113,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
p := h.TransportProtocol()
if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, vv)
+ e.handleICMP(r, headerView, vv)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
}
// Close cleans up resources associated with the endpoint.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 2278fbf65..79f845225 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -505,7 +505,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -525,16 +525,16 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, vv, id) {
+ if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
- if n.stack.demux.deliverPacket(r, protocol, vv, id) {
+ if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, vv) {
+ if state.defaultHandler(r, id, netHeader, vv) {
return
}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5accffa1b..62acd5919 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -64,7 +64,7 @@ const (
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+ HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
@@ -80,6 +80,9 @@ type TransportProtocol interface {
// NewEndpoint creates a new endpoint of the transport protocol.
NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ // NewRawEndpoint creates a new raw endpoint of the transport protocol.
+ NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
// than this targeted at this protocol.
@@ -113,8 +116,9 @@ type TransportProtocol interface {
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
- // transport protocol endpoint.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView)
+ // transport protocol endpoint. It also returns the network layer
+ // header for the enpoint to inspect or pass up the stack.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 252c79317..797489ad9 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -48,7 +48,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool
+ defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -437,7 +437,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -499,6 +499,18 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return t.proto.NewEndpoint(s, network, waiterQueue)
}
+// NewRawEndpoint creates a new raw transport layer endpoint of the given
+// protocol. Raw endpoints receive all traffic for a given protocol regardless
+// of address.
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewRawEndpoint(s, network, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
@@ -934,6 +946,42 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
}
}
+// RegisterRawTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided protocol will
+// be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+}
+
+// UnregisterRawTransportEndpoint removes the endpoint for the protocol from
+// the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ }
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 163fadded..28743f3d5 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -97,7 +97,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index c18208dc0..9ab314188 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,8 +32,12 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
mu sync.RWMutex
endpoints map[TransportEndpointID]TransportEndpoint
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []TransportEndpoint
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -67,7 +71,9 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
+ d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ }
}
}
@@ -131,22 +137,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
// If this is a broadcast datagram, deliver the datagram to all endpoints
// managed by ep.
if id.LocalAddress == header.IPv4Broadcast {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
+ endpoint.HandlePacket(r, id, netHeader, vv)
break
}
vvCopy := buffer.NewView(vv.Size())
copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView())
}
} else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv)
}
}
@@ -250,7 +256,7 @@ var loopbackSubnet = func() tcpip.Subnet {
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if it
// found one or more endpoints, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -276,10 +282,21 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multipe raw endpoints, they all
+ // receive the packet.
+ found := false
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ found = true
+ }
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
+ if len(destEps) == 0 && !found {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
@@ -289,7 +306,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.HandlePacket(r, id, netHeader, vv)
}
return true
@@ -349,3 +366,48 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
return nil
}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil {
+ d.unregisterRawEndpoint(netProtos[:i], protocol, ep)
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given protocol
+// such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
+ }
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index da460db77..3347b5599 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -214,6 +214,10 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto), nil
}
+func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
func (*fakeTransportProtocol) MinimumPacketSize() int {
return fakeTransHeaderLen
}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d87bfe048..b3b7a1d0e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -46,17 +46,23 @@ const (
stateClosed
)
-// endpoint represents an ICMP (ping) endpoint. This struct serves as the
-// interface between users of the endpoint and the protocol implementation; it
-// is legal to have concurrent goroutines make calls into the endpoint, they
-// are properly synchronized.
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// +stateify savable
type endpoint struct {
- // The following fields are initialized at creation time and do not
- // change throughout the lifetime of the endpoint.
+ // The following fields are initialized at creation time and are
+ // immutable.
stack *stack.Stack `state:"manual"`
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
+ // raw indicates whether the endpoint is intended for use by a raw
+ // socket, which returns the network layer header along with the
+ // payload. It is immutable.
+ raw bool
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -80,15 +86,26 @@ type endpoint struct {
route stack.Route `state:"manual"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) {
+ e := &endpoint{
stack: stack,
netProto: netProto,
transProto: transProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ raw: raw,
+ }
+
+ // Raw endpoints must be immediately bound because they receive all
+ // ICMP traffic starting from when they're created via socket().
+ if raw {
+ if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
+ return nil, err
+ }
}
+
+ return e, nil
}
// Close puts the endpoint in a closed state and frees all resources
@@ -98,7 +115,11 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ if e.raw {
+ e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e)
+ } else {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ }
}
// Close the receive list and drain it.
@@ -285,10 +306,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
switch e.netProto {
case header.IPv4ProtocolNumber:
- err = sendPing4(route, e.id.LocalPort, v)
+ err = e.send4(route, v)
case header.IPv6ProtocolNumber:
- err = sendPing6(route, e.id.LocalPort, v)
+ err = send6(route, e.id.LocalPort, v)
}
if err != nil {
@@ -346,13 +367,19 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
+ if e.raw {
+ hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength()))
+ return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+ }
+
if len(data) < header.ICMPv4EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
@@ -371,7 +398,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
-func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
@@ -412,6 +439,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO: We don't yet support connect on a raw socket.
+ if e.raw {
+ return tcpip.ErrNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -515,6 +547,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if e.raw {
+ err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false)
+ return stack.TransportEndpointID{}, err
+ }
+
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
@@ -657,7 +694,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -675,9 +712,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
Addr: id.RemoteAddress,
},
}
- pkt.data = vv.Clone(pkt.views[:])
+
+ if e.raw {
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ pkt.data = combinedVV.Clone(pkt.views[:])
+ } else {
+ pkt.data = vv.Clone(pkt.views[:])
+ }
+
e.rcvList.PushBack(pkt)
- e.rcvBufSize += vv.Size()
+ e.rcvBufSize += pkt.data.Size()
pkt.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9f0a2bf71..36b70988a 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -47,6 +47,7 @@ const (
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
+// protocol implements stack.TransportProtocol.
type protocol struct {
number tcpip.TransportProtocolNumber
}
@@ -66,12 +67,22 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
-// NewEndpoint creates a new icmp endpoint.
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue), nil
+ return newEndpoint(stack, netProto, p.number, waiterQueue, false)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue, true)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a8618bb4a..c48a27d8f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1441,7 +1441,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
s := newSegment(r, id, vv)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2f90839e9..ca53a076f 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 753e1419e..639ad3fae 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -101,6 +101,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN
return newEndpoint(stack, netProto, waiterQueue), nil
}
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
// MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 05d35e526..44b9cdf6a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -934,7 +934,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index b3fbed6e4..616a9f388 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -48,6 +48,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN
return newEndpoint(stack, netProto, waiterQueue), nil
}
+// NewRawEndpoint creates a new raw UDP endpoint. Raw UDP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
// MinimumPacketSize returns the minimum valid udp packet size.
func (*protocol) MinimumPacketSize() int {
return header.UDPMinimumSize