summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/epsocket/BUILD1
-rw-r--r--pkg/sentry/socket/epsocket/provider.go28
-rw-r--r--pkg/tcpip/network/ip_test.go2
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go11
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go8
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go7
-rw-r--r--pkg/tcpip/stack/nic.go8
-rw-r--r--pkg/tcpip/stack/registration.go10
-rw-r--r--pkg/tcpip/stack/stack.go52
-rw-r--r--pkg/tcpip/stack/stack_test.go2
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go78
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go81
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go15
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go6
-rw-r--r--test/syscalls/linux/BUILD16
-rw-r--r--test/syscalls/linux/raw_socket_ipv4.cc245
22 files changed, 533 insertions, 59 deletions
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
index 45e418db3..44bb97b5b 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -27,6 +27,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/kdefs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 0184d8e3e..0d9c2df24 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -18,8 +18,10 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -38,9 +40,9 @@ type provider struct {
netProto tcpip.NetworkProtocolNumber
}
-// GetTransportProtocol figures out transport protocol. Currently only TCP,
+// getTransportProtocol figures out transport protocol. Currently only TCP,
// UDP, and ICMP are supported.
-func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
switch stype {
case linux.SOCK_STREAM:
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
@@ -57,6 +59,18 @@ func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.Transpo
case syscall.IPPROTO_ICMPV6:
return header.ICMPv6ProtocolNumber, nil
}
+
+ case linux.SOCK_RAW:
+ // Raw sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ switch protocol {
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, nil
+ }
}
return 0, syserr.ErrInvalidArgument
}
@@ -76,14 +90,20 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int
}
// Figure out the transport protocol.
- transProto, err := GetTransportProtocol(stype, protocol)
+ transProto, err := getTransportProtocol(t, stype, protocol)
if err != nil {
return nil, err
}
// Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
wq := &waiter.Queue{}
- ep, e := eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ }
if e != nil {
return nil, syserr.TranslateNetstackError(e)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 5c1e88e56..97a43aece 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -94,7 +94,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index f82dc098f..ea8392c98 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -55,7 +55,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
-func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv4MinimumSize {
return
@@ -67,19 +67,22 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
if len(v) < header.ICMPv4EchoMinimumSize {
return
}
- vv.TrimFront(header.ICMPv4MinimumSize)
- req := echoRequest{r: r.Clone(), v: vv.ToView()}
+ echoPayload := vv.ToView()
+ echoPayload.TrimFront(header.ICMPv4MinimumSize)
+ req := echoRequest{r: r.Clone(), v: echoPayload}
select {
case e.echoRequests <- req:
default:
req.r.Release()
}
+ // It's possible that a raw socket expects to receive this.
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
case header.ICMPv4EchoReply:
if len(v) < header.ICMPv4EchoMinimumSize {
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
case header.ICMPv4DstUnreachable:
if len(v) < header.ICMPv4DstUnreachableMinimumSize {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 0c41519df..bfc3c08fa 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -131,7 +131,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- h := header.IPv4(vv.First())
+ headerView := vv.First()
+ h := header.IPv4(headerView)
if !h.IsValid(vv.Size()) {
return
}
@@ -153,11 +154,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- e.handleICMP(r, vv)
+ headerView.CapLength(hlen)
+ e.handleICMP(r, headerView, vv)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
}
// Close cleans up resources associated with the endpoint.
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 14107443b..5a3c17768 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
-func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv6MinimumSize {
return
@@ -148,7 +148,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
if len(v) < header.ICMPv6EchoMinimumSize {
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 4d0b6ee9c..5f68ef7d5 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -102,7 +102,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- h := header.IPv6(vv.First())
+ headerView := vv.First()
+ h := header.IPv6(headerView)
if !h.IsValid(vv.Size()) {
return
}
@@ -112,12 +113,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
p := h.TransportProtocol()
if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, vv)
+ e.handleICMP(r, headerView, vv)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
}
// Close cleans up resources associated with the endpoint.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 2278fbf65..79f845225 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -505,7 +505,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -525,16 +525,16 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, vv, id) {
+ if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
- if n.stack.demux.deliverPacket(r, protocol, vv, id) {
+ if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, vv) {
+ if state.defaultHandler(r, id, netHeader, vv) {
return
}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5accffa1b..62acd5919 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -64,7 +64,7 @@ const (
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+ HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
@@ -80,6 +80,9 @@ type TransportProtocol interface {
// NewEndpoint creates a new endpoint of the transport protocol.
NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ // NewRawEndpoint creates a new raw endpoint of the transport protocol.
+ NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
// than this targeted at this protocol.
@@ -113,8 +116,9 @@ type TransportProtocol interface {
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
- // transport protocol endpoint.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView)
+ // transport protocol endpoint. It also returns the network layer
+ // header for the enpoint to inspect or pass up the stack.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 252c79317..797489ad9 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -48,7 +48,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool
+ defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -437,7 +437,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -499,6 +499,18 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return t.proto.NewEndpoint(s, network, waiterQueue)
}
+// NewRawEndpoint creates a new raw transport layer endpoint of the given
+// protocol. Raw endpoints receive all traffic for a given protocol regardless
+// of address.
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewRawEndpoint(s, network, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
@@ -934,6 +946,42 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
}
}
+// RegisterRawTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided protocol will
+// be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+}
+
+// UnregisterRawTransportEndpoint removes the endpoint for the protocol from
+// the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ }
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 163fadded..28743f3d5 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -97,7 +97,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index c18208dc0..9ab314188 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,8 +32,12 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
mu sync.RWMutex
endpoints map[TransportEndpointID]TransportEndpoint
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []TransportEndpoint
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -67,7 +71,9 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
+ d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ }
}
}
@@ -131,22 +137,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
// If this is a broadcast datagram, deliver the datagram to all endpoints
// managed by ep.
if id.LocalAddress == header.IPv4Broadcast {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
+ endpoint.HandlePacket(r, id, netHeader, vv)
break
}
vvCopy := buffer.NewView(vv.Size())
copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView())
}
} else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv)
}
}
@@ -250,7 +256,7 @@ var loopbackSubnet = func() tcpip.Subnet {
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if it
// found one or more endpoints, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -276,10 +282,21 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multipe raw endpoints, they all
+ // receive the packet.
+ found := false
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ found = true
+ }
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
+ if len(destEps) == 0 && !found {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
@@ -289,7 +306,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.HandlePacket(r, id, netHeader, vv)
}
return true
@@ -349,3 +366,48 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
return nil
}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil {
+ d.unregisterRawEndpoint(netProtos[:i], protocol, ep)
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given protocol
+// such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
+ }
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index da460db77..3347b5599 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -214,6 +214,10 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto), nil
}
+func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
func (*fakeTransportProtocol) MinimumPacketSize() int {
return fakeTransHeaderLen
}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d87bfe048..b3b7a1d0e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -46,17 +46,23 @@ const (
stateClosed
)
-// endpoint represents an ICMP (ping) endpoint. This struct serves as the
-// interface between users of the endpoint and the protocol implementation; it
-// is legal to have concurrent goroutines make calls into the endpoint, they
-// are properly synchronized.
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// +stateify savable
type endpoint struct {
- // The following fields are initialized at creation time and do not
- // change throughout the lifetime of the endpoint.
+ // The following fields are initialized at creation time and are
+ // immutable.
stack *stack.Stack `state:"manual"`
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
+ // raw indicates whether the endpoint is intended for use by a raw
+ // socket, which returns the network layer header along with the
+ // payload. It is immutable.
+ raw bool
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -80,15 +86,26 @@ type endpoint struct {
route stack.Route `state:"manual"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) {
+ e := &endpoint{
stack: stack,
netProto: netProto,
transProto: transProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ raw: raw,
+ }
+
+ // Raw endpoints must be immediately bound because they receive all
+ // ICMP traffic starting from when they're created via socket().
+ if raw {
+ if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
+ return nil, err
+ }
}
+
+ return e, nil
}
// Close puts the endpoint in a closed state and frees all resources
@@ -98,7 +115,11 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ if e.raw {
+ e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e)
+ } else {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ }
}
// Close the receive list and drain it.
@@ -285,10 +306,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
switch e.netProto {
case header.IPv4ProtocolNumber:
- err = sendPing4(route, e.id.LocalPort, v)
+ err = e.send4(route, v)
case header.IPv6ProtocolNumber:
- err = sendPing6(route, e.id.LocalPort, v)
+ err = send6(route, e.id.LocalPort, v)
}
if err != nil {
@@ -346,13 +367,19 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
+ if e.raw {
+ hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength()))
+ return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+ }
+
if len(data) < header.ICMPv4EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
@@ -371,7 +398,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
-func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
@@ -412,6 +439,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO: We don't yet support connect on a raw socket.
+ if e.raw {
+ return tcpip.ErrNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -515,6 +547,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if e.raw {
+ err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false)
+ return stack.TransportEndpointID{}, err
+ }
+
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
@@ -657,7 +694,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -675,9 +712,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
Addr: id.RemoteAddress,
},
}
- pkt.data = vv.Clone(pkt.views[:])
+
+ if e.raw {
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ pkt.data = combinedVV.Clone(pkt.views[:])
+ } else {
+ pkt.data = vv.Clone(pkt.views[:])
+ }
+
e.rcvList.PushBack(pkt)
- e.rcvBufSize += vv.Size()
+ e.rcvBufSize += pkt.data.Size()
pkt.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9f0a2bf71..36b70988a 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -47,6 +47,7 @@ const (
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
+// protocol implements stack.TransportProtocol.
type protocol struct {
number tcpip.TransportProtocolNumber
}
@@ -66,12 +67,22 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
-// NewEndpoint creates a new icmp endpoint.
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue), nil
+ return newEndpoint(stack, netProto, p.number, waiterQueue, false)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue, true)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a8618bb4a..c48a27d8f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1441,7 +1441,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
s := newSegment(r, id, vv)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2f90839e9..ca53a076f 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 753e1419e..639ad3fae 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -101,6 +101,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN
return newEndpoint(stack, netProto, waiterQueue), nil
}
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
// MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 05d35e526..44b9cdf6a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -934,7 +934,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index b3fbed6e4..616a9f388 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -48,6 +48,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN
return newEndpoint(stack, netProto, waiterQueue), nil
}
+// NewRawEndpoint creates a new raw UDP endpoint. Raw UDP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
// MinimumPacketSize returns the minimum valid udp packet size.
func (*protocol) MinimumPacketSize() int {
return header.UDPMinimumSize
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index beece8930..4c818238b 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1530,6 +1530,22 @@ cc_binary(
)
cc_binary(
+ name = "raw_socket_ipv4_test",
+ testonly = 1,
+ srcs = ["raw_socket_ipv4.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "read_test",
testonly = 1,
srcs = ["read.cc"],
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
new file mode 100644
index 000000000..c6749321c
--- /dev/null
+++ b/test/syscalls/linux/raw_socket_ipv4.cc
@@ -0,0 +1,245 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/capability.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Fixture for tests parameterized by address family (currently only AF_INET).
+class RawSocketTest : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // The socket used for both reading and writing.
+ int s_;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+
+ void sendEmptyICMP(struct icmphdr *icmp);
+
+ void sendEmptyICMPTo(int sock, struct sockaddr_in *addr,
+ struct icmphdr *icmp);
+
+ void receiveICMP(char *recv_buf, size_t recv_buf_len, size_t expected_size,
+ struct sockaddr_in *src);
+
+ void receiveICMPFrom(char *recv_buf, size_t recv_buf_len,
+ size_t expected_size, struct sockaddr_in *src, int sock);
+};
+
+void RawSocketTest::SetUp() {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
+
+ addr_ = {};
+
+ // We don't set ports because raw sockets don't have a notion of ports.
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void RawSocketTest::TearDown() {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+}
+
+// We should be able to create multiple raw sockets for the same protocol.
+// BasicRawSocket::Setup creates the first one, so we only have to create one
+// more here.
+TEST_F(RawSocketTest, MultipleCreation) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int s2;
+ ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
+
+ ASSERT_THAT(close(s2), SyscallSucceeds());
+}
+
+// Send and receive an ICMP packet.
+TEST_F(RawSocketTest, SendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = *(unsigned short *)&icmp.checksum;
+ icmp.un.echo.sequence = *(unsigned short *)&icmp.un.echo.sequence;
+ icmp.un.echo.id = *(unsigned short *)&icmp.un.echo.id;
+ ASSERT_NO_FATAL_FAILURE(sendEmptyICMP(&icmp));
+
+ // Receive the packet and make sure it's identical.
+ char recv_buf[512];
+ struct sockaddr_in src;
+ ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
+ EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), 0);
+
+ // We should also receive the automatically generated echo reply.
+ ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
+ struct icmphdr *reply_icmp =
+ (struct icmphdr *)(recv_buf + sizeof(struct iphdr));
+ // Most fields should be the same.
+ EXPECT_EQ(reply_icmp->code, icmp.code);
+ EXPECT_EQ(reply_icmp->un.echo.sequence, icmp.un.echo.sequence);
+ EXPECT_EQ(reply_icmp->un.echo.id, icmp.un.echo.id);
+ // A couple are different.
+ EXPECT_EQ(reply_icmp->type, ICMP_ECHOREPLY);
+ // The checksum is computed in such a way that it is guaranteed to have
+ // changed.
+ EXPECT_NE(reply_icmp->checksum, icmp.checksum);
+}
+
+// We should be able to create multiple raw sockets for the same protocol and
+// receive the same packet on both.
+TEST_F(RawSocketTest, MultipleSocketReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ FileDescriptor s2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP));
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = *(unsigned short *)&icmp.checksum;
+ icmp.un.echo.sequence = *(unsigned short *)&icmp.un.echo.sequence;
+ icmp.un.echo.id = *(unsigned short *)&icmp.un.echo.id;
+ ASSERT_NO_FATAL_FAILURE(sendEmptyICMP(&icmp));
+
+ // Receive it on socket 1.
+ char recv_buf1[512];
+ struct sockaddr_in src;
+ ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf1, ABSL_ARRAYSIZE(recv_buf1),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
+
+ // Receive it on socket 2.
+ char recv_buf2[512];
+ ASSERT_NO_FATAL_FAILURE(receiveICMPFrom(recv_buf2, ABSL_ARRAYSIZE(recv_buf2),
+ sizeof(struct icmphdr), &src,
+ s2.get()));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
+
+ EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr),
+ recv_buf2 + sizeof(struct iphdr), sizeof(icmp)),
+ 0);
+}
+
+// A raw ICMP socket and ping socket should both receive the ICMP packets
+// indended for the ping socket.
+TEST_F(RawSocketTest, RawAndPingSockets) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ FileDescriptor ping_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+
+ // Ping sockets take care of the ICMP ID and checksum.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.un.echo.sequence = *(unsigned short *)&icmp.un.echo.sequence;
+ ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, sizeof(icmp), 0,
+ (struct sockaddr *)&addr_, sizeof(addr_)),
+ SyscallSucceedsWithValue(sizeof(icmp)));
+
+ // Receive the packet via raw socket.
+ char recv_buf[512];
+ struct sockaddr_in src;
+ ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
+
+ // Receive the packet via ping socket.
+ struct icmphdr ping_header;
+ ASSERT_THAT(
+ RetryEINTR(recv)(ping_sock.get(), &ping_header, sizeof(ping_header), 0),
+ SyscallSucceedsWithValue(sizeof(ping_header)));
+
+ // Packets should be the same.
+ EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &ping_header,
+ sizeof(struct icmphdr)),
+ 0);
+}
+
+void RawSocketTest::sendEmptyICMP(struct icmphdr *icmp) {
+ ASSERT_NO_FATAL_FAILURE(sendEmptyICMPTo(s_, &addr_, icmp));
+}
+
+void RawSocketTest::sendEmptyICMPTo(int sock, struct sockaddr_in *addr,
+ struct icmphdr *icmp) {
+ struct iovec iov = {.iov_base = icmp, .iov_len = sizeof(*icmp)};
+ struct msghdr msg {
+ .msg_name = addr, .msg_namelen = sizeof(*addr), .msg_iov = &iov,
+ .msg_iovlen = 1, .msg_control = NULL, .msg_controllen = 0, .msg_flags = 0,
+ };
+ ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(*icmp)));
+}
+
+void RawSocketTest::receiveICMP(char *recv_buf, size_t recv_buf_len,
+ size_t expected_size, struct sockaddr_in *src) {
+ ASSERT_NO_FATAL_FAILURE(
+ receiveICMPFrom(recv_buf, recv_buf_len, expected_size, src, s_));
+}
+
+void RawSocketTest::receiveICMPFrom(char *recv_buf, size_t recv_buf_len,
+ size_t expected_size,
+ struct sockaddr_in *src, int sock) {
+ struct iovec iov = {.iov_base = recv_buf, .iov_len = recv_buf_len};
+ struct msghdr msg = {
+ .msg_name = src,
+ .msg_namelen = sizeof(*src),
+ .msg_iov = &iov,
+ .msg_iovlen = 1,
+ .msg_control = NULL,
+ .msg_controllen = 0,
+ .msg_flags = 0,
+ };
+ // We should receive the ICMP packet plus 20 bytes of IP header.
+ ASSERT_THAT(recvmsg(sock, &msg, 0),
+ SyscallSucceedsWithValue(expected_size + sizeof(struct iphdr)));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor