diff options
author | Kevin Krakauer <krakauer@google.com> | 2019-02-27 14:30:20 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-02-27 14:31:21 -0800 |
commit | 121db29a93c651b8b62e8701bb0f16c231b08257 (patch) | |
tree | c6b235b72340de44c8ee0e9398d337a8cc6c30f3 | |
parent | 6df212b831dcc3350b7677423ec7835ed40b3f22 (diff) |
Ping support via IPv4 raw sockets.
Broadly, this change:
* Enables sockets to be created via `socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)`.
* Passes the network-layer (IP) header up the stack to the transport endpoint,
which can pass it up to the socket layer. This allows a raw socket to return
the entire IP packet to users.
* Adds functions to stack.TransportProtocol, stack.Stack, stack.transportDemuxer
that enable incoming packets to be delivered to raw endpoints. New raw sockets
of other protocols (not ICMP) just need to register with the stack.
* Enables ping.endpoint to return IP headers when created via SOCK_RAW.
PiperOrigin-RevId: 235993280
Change-Id: I60ed994f5ff18b2cbd79f063a7fdf15d093d845a
-rw-r--r-- | pkg/sentry/socket/epsocket/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/provider.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 78 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 81 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 6 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 16 | ||||
-rw-r--r-- | test/syscalls/linux/raw_socket_ipv4.cc | 245 |
22 files changed, 533 insertions, 59 deletions
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 45e418db3..44bb97b5b 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/inet", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/kdefs", "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index 0184d8e3e..0d9c2df24 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -18,8 +18,10 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -38,9 +40,9 @@ type provider struct { netProto tcpip.NetworkProtocolNumber } -// GetTransportProtocol figures out transport protocol. Currently only TCP, +// getTransportProtocol figures out transport protocol. Currently only TCP, // UDP, and ICMP are supported. -func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { @@ -57,6 +59,18 @@ func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.Transpo case syscall.IPPROTO_ICMPV6: return header.ICMPv6ProtocolNumber, nil } + + case linux.SOCK_RAW: + // Raw sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(ctx) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return 0, syserr.ErrPermissionDenied + } + + switch protocol { + case syscall.IPPROTO_ICMP: + return header.ICMPv4ProtocolNumber, nil + } } return 0, syserr.ErrInvalidArgument } @@ -76,14 +90,20 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int } // Figure out the transport protocol. - transProto, err := GetTransportProtocol(stype, protocol) + transProto, err := getTransportProtocol(t, stype, protocol) if err != nil { return nil, err } // Create the endpoint. + var ep tcpip.Endpoint + var e *tcpip.Error wq := &waiter.Queue{} - ep, e := eps.Stack.NewEndpoint(transProto, p.netProto, wq) + if stype == linux.SOCK_RAW { + ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq) + } else { + ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + } if e != nil { return nil, syserr.TranslateNetstackError(e) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 5c1e88e56..97a43aece 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -94,7 +94,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index f82dc098f..ea8392c98 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -55,7 +55,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv4MinimumSize { return @@ -67,19 +67,22 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { if len(v) < header.ICMPv4EchoMinimumSize { return } - vv.TrimFront(header.ICMPv4MinimumSize) - req := echoRequest{r: r.Clone(), v: vv.ToView()} + echoPayload := vv.ToView() + echoPayload.TrimFront(header.ICMPv4MinimumSize) + req := echoRequest{r: r.Clone(), v: echoPayload} select { case e.echoRequests <- req: default: req.r.Release() } + // It's possible that a raw socket expects to receive this. + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4EchoReply: if len(v) < header.ICMPv4EchoMinimumSize { return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4DstUnreachable: if len(v) < header.ICMPv4DstUnreachableMinimumSize { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 0c41519df..bfc3c08fa 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -131,7 +131,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) + headerView := vv.First() + h := header.IPv4(headerView) if !h.IsValid(vv.Size()) { return } @@ -153,11 +154,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - e.handleICMP(r, vv) + headerView.CapLength(hlen) + e.handleICMP(r, headerView, vv) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, vv) + e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 14107443b..5a3c17768 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv6MinimumSize { return @@ -148,7 +148,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { if len(v) < header.ICMPv6EchoMinimumSize { return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4d0b6ee9c..5f68ef7d5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -102,7 +102,8 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) + headerView := vv.First() + h := header.IPv6(headerView) if !h.IsValid(vv.Size()) { return } @@ -112,12 +113,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, vv) + e.handleICMP(r, headerView, vv) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, vv) + e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 2278fbf65..79f845225 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -505,7 +505,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -525,16 +525,16 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, vv, id) { + if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } - if n.stack.demux.deliverPacket(r, protocol, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, vv) { + if state.defaultHandler(r, id, netHeader, vv) { return } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5accffa1b..62acd5919 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,7 +64,7 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. @@ -80,6 +80,9 @@ type TransportProtocol interface { // NewEndpoint creates a new endpoint of the transport protocol. NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + // NewRawEndpoint creates a new raw endpoint of the transport protocol. + NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller // than this targeted at this protocol. @@ -113,8 +116,9 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) + // transport protocol endpoint. It also returns the network layer + // header for the enpoint to inspect or pass up the stack. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 252c79317..797489ad9 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -48,7 +48,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -437,7 +437,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -499,6 +499,18 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return t.proto.NewEndpoint(s, network, waiterQueue) } +// NewRawEndpoint creates a new raw transport layer endpoint of the given +// protocol. Raw endpoints receive all traffic for a given protocol regardless +// of address. +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + t, ok := s.transportProtocols[transport] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + return t.proto.NewRawEndpoint(s, network, waiterQueue) +} + // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error { @@ -934,6 +946,42 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip } } +// RegisterRawTransportEndpoint registers the given endpoint with the stack +// transport dispatcher. Received packets that match the provided protocol will +// be delivered to the given endpoint. +func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { + if nicID == 0 { + return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) +} + +// UnregisterRawTransportEndpoint removes the endpoint for the protocol from +// the stack transport dispatcher. +func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { + if nicID == 0 { + s.demux.unregisterRawEndpoint(netProtos, protocol, ep) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic != nil { + nic.demux.unregisterRawEndpoint(netProtos, protocol, ep) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 163fadded..28743f3d5 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -97,7 +97,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c18208dc0..9ab314188 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,8 +32,12 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { + // mu protects all fields of the transportEndpoints. mu sync.RWMutex endpoints map[TransportEndpointID]TransportEndpoint + // rawEndpoints contains endpoints for raw sockets, which receive all + // traffic of a given protocol regardless of port. + rawEndpoints []TransportEndpoint } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -67,7 +71,9 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)} + d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + endpoints: make(map[TransportEndpointID]TransportEndpoint), + } } } @@ -131,22 +137,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { // If this is a broadcast datagram, deliver the datagram to all endpoints // managed by ep. if id.LocalAddress == header.IPv4Broadcast { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, netHeader, vv) break } vvCopy := buffer.NewView(vv.Size()) copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView()) } } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv) } } @@ -250,7 +256,7 @@ var loopbackSubnet = func() tcpip.Subnet { // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if it // found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -276,10 +282,21 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } + + // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via + // raw endpoint first. If there are multipe raw endpoints, they all + // receive the packet. + found := false + for _, rawEP := range eps.rawEndpoints { + // Each endpoint gets its own copy of the packet for the sake + // of save/restore. + rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + found = true + } eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { + if len(destEps) == 0 && !found { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -289,7 +306,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.HandlePacket(r, id, netHeader, vv) } return true @@ -349,3 +366,48 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer return nil } + +// registerRawEndpoint registers the given endpoint with the dispatcher such +// that packets of the appropriate protocol are delivered to it. A single +// packet can be sent to one or more raw endpoints along with a non-raw +// endpoint. +func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { + for i, n := range netProtos { + if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil { + d.unregisterRawEndpoint(netProtos[:i], protocol, ep) + return err + } + } + + return nil +} + +func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return nil + } + + eps.mu.Lock() + defer eps.mu.Unlock() + eps.rawEndpoints = append(eps.rawEndpoints, ep) + + return nil +} + +// unregisterRawEndpoint unregisters the raw endpoint for the given protocol +// such that it won't receive any more packets. +func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { + for _, n := range netProtos { + if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { + eps.mu.Lock() + defer eps.mu.Unlock() + for i, rawEP := range eps.rawEndpoints { + if rawEP == ep { + eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) + return + } + } + } + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index da460db77..3347b5599 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -214,6 +214,10 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto), nil } +func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d87bfe048..b3b7a1d0e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -46,17 +46,23 @@ const ( stateClosed ) -// endpoint represents an ICMP (ping) endpoint. This struct serves as the -// interface between users of the endpoint and the protocol implementation; it -// is legal to have concurrent goroutines make calls into the endpoint, they -// are properly synchronized. +// endpoint represents an ICMP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// +stateify savable type endpoint struct { - // The following fields are initialized at creation time and do not - // change throughout the lifetime of the endpoint. + // The following fields are initialized at creation time and are + // immutable. stack *stack.Stack `state:"manual"` netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + // raw indicates whether the endpoint is intended for use by a raw + // socket, which returns the network layer header along with the + // payload. It is immutable. + raw bool // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -80,15 +86,26 @@ type endpoint struct { route stack.Route `state:"manual"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - return &endpoint{ +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) { + e := &endpoint{ stack: stack, netProto: netProto, transProto: transProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + raw: raw, + } + + // Raw endpoints must be immediately bound because they receive all + // ICMP traffic starting from when they're created via socket(). + if raw { + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return nil, err + } } + + return e, nil } // Close puts the endpoint in a closed state and frees all resources @@ -98,7 +115,11 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + if e.raw { + e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e) + } else { + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + } } // Close the receive list and drain it. @@ -285,10 +306,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = sendPing4(route, e.id.LocalPort, v) + err = e.send4(route, v) case header.IPv6ProtocolNumber: - err = sendPing6(route, e.id.LocalPort, v) + err = send6(route, e.id.LocalPort, v) } if err != nil { @@ -346,13 +367,19 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { + if e.raw { + hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength())) + return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + } + if len(data) < header.ICMPv4EchoMinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) @@ -371,7 +398,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } -func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -412,6 +439,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // TODO: We don't yet support connect on a raw socket. + if e.raw { + return tcpip.ErrNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -515,6 +547,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.raw { + err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false) + return stack.TransportEndpointID{}, err + } + if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -657,7 +694,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -675,9 +712,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + + if e.raw { + combinedVV := netHeader.ToVectorisedView() + combinedVV.Append(vv) + pkt.data = combinedVV.Clone(pkt.views[:]) + } else { + pkt.data = vv.Clone(pkt.views[:]) + } + e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + e.rcvBufSize += pkt.data.Size() pkt.timestamp = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9f0a2bf71..36b70988a 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -47,6 +47,7 @@ const ( ProtocolNumber6 = header.ICMPv6ProtocolNumber ) +// protocol implements stack.TransportProtocol. type protocol struct { number tcpip.TransportProtocolNumber } @@ -66,12 +67,22 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { panic(fmt.Sprint("unknown protocol number: ", p.number)) } -// NewEndpoint creates a new icmp endpoint. +// NewEndpoint creates a new icmp endpoint. It implements +// stack.TransportProtocol.NewEndpoint. func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue), nil + return newEndpoint(stack, netProto, p.number, waiterQueue, false) +} + +// NewRawEndpoint creates a new raw icmp endpoint. It implements +// stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return newEndpoint(stack, netProto, p.number, waiterQueue, true) } // MinimumPacketSize returns the minimum valid icmp packet size. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a8618bb4a..c48a27d8f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1441,7 +1441,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f90839e9..ca53a076f 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 753e1419e..639ad3fae 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -101,6 +101,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN return newEndpoint(stack, netProto, waiterQueue), nil } +// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + // MinimumPacketSize returns the minimum valid tcp packet size. func (*protocol) MinimumPacketSize() int { return header.TCPMinimumSize diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 05d35e526..44b9cdf6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -934,7 +934,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { // Get the header then trim it from the view. hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index b3fbed6e4..616a9f388 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -48,6 +48,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN return newEndpoint(stack, netProto, waiterQueue), nil } +// NewRawEndpoint creates a new raw UDP endpoint. Raw UDP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + // MinimumPacketSize returns the minimum valid udp packet size. func (*protocol) MinimumPacketSize() int { return header.UDPMinimumSize diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index beece8930..4c818238b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1530,6 +1530,22 @@ cc_binary( ) cc_binary( + name = "raw_socket_ipv4_test", + testonly = 1, + srcs = ["raw_socket_ipv4.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "read_test", testonly = 1, srcs = ["read.cc"], diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc new file mode 100644 index 000000000..c6749321c --- /dev/null +++ b/test/syscalls/linux/raw_socket_ipv4.cc @@ -0,0 +1,245 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/capability.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Fixture for tests parameterized by address family (currently only AF_INET). +class RawSocketTest : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // The socket used for both reading and writing. + int s_; + + // The loopback address. + struct sockaddr_in addr_; + + void sendEmptyICMP(struct icmphdr *icmp); + + void sendEmptyICMPTo(int sock, struct sockaddr_in *addr, + struct icmphdr *icmp); + + void receiveICMP(char *recv_buf, size_t recv_buf_len, size_t expected_size, + struct sockaddr_in *src); + + void receiveICMPFrom(char *recv_buf, size_t recv_buf_len, + size_t expected_size, struct sockaddr_in *src, int sock); +}; + +void RawSocketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); + + addr_ = {}; + + // We don't set ports because raw sockets don't have a notion of ports. + addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr_.sin_family = AF_INET; +} + +void RawSocketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + EXPECT_THAT(close(s_), SyscallSucceeds()); +} + +// We should be able to create multiple raw sockets for the same protocol. +// BasicRawSocket::Setup creates the first one, so we only have to create one +// more here. +TEST_F(RawSocketTest, MultipleCreation) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int s2; + ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); + + ASSERT_THAT(close(s2), SyscallSucceeds()); +} + +// Send and receive an ICMP packet. +TEST_F(RawSocketTest, SendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, + // and ID. None of that should matter for raw sockets - the kernel should + // still give us the packet. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.checksum = *(unsigned short *)&icmp.checksum; + icmp.un.echo.sequence = *(unsigned short *)&icmp.un.echo.sequence; + icmp.un.echo.id = *(unsigned short *)&icmp.un.echo.id; + ASSERT_NO_FATAL_FAILURE(sendEmptyICMP(&icmp)); + + // Receive the packet and make sure it's identical. + char recv_buf[512]; + struct sockaddr_in src; + ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf), + sizeof(struct icmphdr), &src)); + EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); + EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), 0); + + // We should also receive the automatically generated echo reply. + ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf), + sizeof(struct icmphdr), &src)); + EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); + struct icmphdr *reply_icmp = + (struct icmphdr *)(recv_buf + sizeof(struct iphdr)); + // Most fields should be the same. + EXPECT_EQ(reply_icmp->code, icmp.code); + EXPECT_EQ(reply_icmp->un.echo.sequence, icmp.un.echo.sequence); + EXPECT_EQ(reply_icmp->un.echo.id, icmp.un.echo.id); + // A couple are different. + EXPECT_EQ(reply_icmp->type, ICMP_ECHOREPLY); + // The checksum is computed in such a way that it is guaranteed to have + // changed. + EXPECT_NE(reply_icmp->checksum, icmp.checksum); +} + +// We should be able to create multiple raw sockets for the same protocol and +// receive the same packet on both. +TEST_F(RawSocketTest, MultipleSocketReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor s2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)); + + // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, + // and ID. None of that should matter for raw sockets - the kernel should + // still give us the packet. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.checksum = *(unsigned short *)&icmp.checksum; + icmp.un.echo.sequence = *(unsigned short *)&icmp.un.echo.sequence; + icmp.un.echo.id = *(unsigned short *)&icmp.un.echo.id; + ASSERT_NO_FATAL_FAILURE(sendEmptyICMP(&icmp)); + + // Receive it on socket 1. + char recv_buf1[512]; + struct sockaddr_in src; + ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf1, ABSL_ARRAYSIZE(recv_buf1), + sizeof(struct icmphdr), &src)); + EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); + + // Receive it on socket 2. + char recv_buf2[512]; + ASSERT_NO_FATAL_FAILURE(receiveICMPFrom(recv_buf2, ABSL_ARRAYSIZE(recv_buf2), + sizeof(struct icmphdr), &src, + s2.get())); + EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); + + EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr), + recv_buf2 + sizeof(struct iphdr), sizeof(icmp)), + 0); +} + +// A raw ICMP socket and ping socket should both receive the ICMP packets +// indended for the ping socket. +TEST_F(RawSocketTest, RawAndPingSockets) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor ping_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); + + // Ping sockets take care of the ICMP ID and checksum. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.un.echo.sequence = *(unsigned short *)&icmp.un.echo.sequence; + ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, sizeof(icmp), 0, + (struct sockaddr *)&addr_, sizeof(addr_)), + SyscallSucceedsWithValue(sizeof(icmp))); + + // Receive the packet via raw socket. + char recv_buf[512]; + struct sockaddr_in src; + ASSERT_NO_FATAL_FAILURE(receiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf), + sizeof(struct icmphdr), &src)); + EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); + + // Receive the packet via ping socket. + struct icmphdr ping_header; + ASSERT_THAT( + RetryEINTR(recv)(ping_sock.get(), &ping_header, sizeof(ping_header), 0), + SyscallSucceedsWithValue(sizeof(ping_header))); + + // Packets should be the same. + EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &ping_header, + sizeof(struct icmphdr)), + 0); +} + +void RawSocketTest::sendEmptyICMP(struct icmphdr *icmp) { + ASSERT_NO_FATAL_FAILURE(sendEmptyICMPTo(s_, &addr_, icmp)); +} + +void RawSocketTest::sendEmptyICMPTo(int sock, struct sockaddr_in *addr, + struct icmphdr *icmp) { + struct iovec iov = {.iov_base = icmp, .iov_len = sizeof(*icmp)}; + struct msghdr msg { + .msg_name = addr, .msg_namelen = sizeof(*addr), .msg_iov = &iov, + .msg_iovlen = 1, .msg_control = NULL, .msg_controllen = 0, .msg_flags = 0, + }; + ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(*icmp))); +} + +void RawSocketTest::receiveICMP(char *recv_buf, size_t recv_buf_len, + size_t expected_size, struct sockaddr_in *src) { + ASSERT_NO_FATAL_FAILURE( + receiveICMPFrom(recv_buf, recv_buf_len, expected_size, src, s_)); +} + +void RawSocketTest::receiveICMPFrom(char *recv_buf, size_t recv_buf_len, + size_t expected_size, + struct sockaddr_in *src, int sock) { + struct iovec iov = {.iov_base = recv_buf, .iov_len = recv_buf_len}; + struct msghdr msg = { + .msg_name = src, + .msg_namelen = sizeof(*src), + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = 0, + }; + // We should receive the ICMP packet plus 20 bytes of IP header. + ASSERT_THAT(recvmsg(sock, &msg, 0), + SyscallSucceedsWithValue(expected_size + sizeof(struct iphdr))); +} + +} // namespace + +} // namespace testing +} // namespace gvisor |