diff options
author | Kevin Krakauer <krakauer@google.com> | 2019-04-02 11:12:29 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-02 11:13:49 -0700 |
commit | 52a51a8e20b3e5c28eb1e66bd57203216cf644c5 (patch) | |
tree | 4d98f67dc4abb15e6de568b2af12a3e12f41aa99 | |
parent | 1df3fa69977477092efa65a8de407bd6f0f88db4 (diff) |
Add a raw socket transport endpoint and use it for raw ICMP sockets.
Having raw socket code together will make it easier to add support for other raw
network protocols. Currently, only ICMP uses the raw endpoint. However, adding
support for other protocols such as UDP shouldn't be much more difficult than
adding a few switch cases.
PiperOrigin-RevId: 241564875
Change-Id: I77e03adafe4ce0fd29ba2d5dfdc547d2ae8f25bf
-rw-r--r-- | pkg/tcpip/stack/registration.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/BUILD | 45 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/raw.go | 558 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/state.go | 88 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 | ||||
-rw-r--r-- | test/syscalls/linux/raw_socket_ipv4.cc | 421 |
13 files changed, 1106 insertions, 194 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ff356ea22..f3cc849ec 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,13 +64,24 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) } +// RawTransportEndpoint is the interface that needs to be implemented by raw +// transport protocol endpoints. RawTransportEndpoints receive the entire +// packet - including the link, network, and transport headers - as delivered +// to netstack. +type RawTransportEndpoint interface { + // HandlePacket is called by the stack when new packets arrive to + // this transport endpoint. The packet contains all data from the link + // layer up. + HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) +} + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 15a268b10..a74c0a7a0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -955,11 +955,11 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip } // RegisterRawTransportEndpoint registers the given endpoint with the stack -// transport dispatcher. Received packets that match the provided protocol will -// be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { +// transport dispatcher. Received packets that match the provided transport +// protocol will be delivered to the given endpoint. +func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { if nicID == 0 { - return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } s.mu.RLock() @@ -970,14 +970,14 @@ func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpi return tcpip.ErrUnknownNICID } - return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort) + return nic.demux.registerRawEndpoint(netProto, transProto, ep) } -// UnregisterRawTransportEndpoint removes the endpoint for the protocol from -// the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { +// UnregisterRawTransportEndpoint removes the endpoint for the transport +// protocol from the stack transport dispatcher. +func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { if nicID == 0 { - s.demux.unregisterRawEndpoint(netProtos, protocol, ep) + s.demux.unregisterRawEndpoint(netProto, transProto, ep) return } @@ -986,7 +986,7 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tc nic := s.nics[nicID] if nic != nil { - nic.demux.unregisterRawEndpoint(netProtos, protocol, ep) + nic.demux.unregisterRawEndpoint(netProto, transProto, ep) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9ab314188..a8ac18e72 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "math/rand" "sync" @@ -37,7 +38,7 @@ type transportEndpoints struct { endpoints map[TransportEndpointID]TransportEndpoint // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. - rawEndpoints []TransportEndpoint + rawEndpoints []RawTransportEndpoint } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -60,8 +61,10 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then -// based on endpoints IDs. +// based on endpoints IDs. It should only be instantiated via +// newTransportDemuxer. type transportDemuxer struct { + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints } @@ -137,22 +140,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { // If this is a broadcast datagram, deliver the datagram to all endpoints // managed by ep. if id.LocalAddress == header.IPv4Broadcast { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, netHeader, vv) + endpoint.HandlePacket(r, id, vv) break } vvCopy := buffer.NewView(vv.Size()) copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } } else { - ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv) + ep.selectEndpoint(id).HandlePacket(r, id, vv) } } @@ -286,17 +289,17 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via // raw endpoint first. If there are multipe raw endpoints, they all // receive the packet. - found := false + foundRaw := false for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) - found = true + rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + foundRaw = true } eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 && !found { + if len(destEps) == 0 && !foundRaw { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -306,7 +309,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, netHeader, vv) + ep.HandlePacket(r, id, vv) } return true @@ -371,19 +374,8 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error { - for i, n := range netProtos { - if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil { - d.unregisterRawEndpoint(netProtos[:i], protocol, ep) - return err - } - } - - return nil -} - -func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error { - eps, ok := d.protocol[protocolIDs{netProto, protocol}] +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } @@ -395,19 +387,20 @@ func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProto return nil } -// unregisterRawEndpoint unregisters the raw endpoint for the given protocol -// such that it won't receive any more packets. -func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) { - for _, n := range netProtos { - if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.mu.Lock() - defer eps.mu.Unlock() - for i, rawEP := range eps.rawEndpoints { - if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return - } - } +// unregisterRawEndpoint unregisters the raw endpoint for the given transport +// protocol such that it won't receive any more packets. +func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto)) + } + + eps.mu.Lock() + defer eps.mu.Unlock() + for i, rawEP := range eps.rawEndpoints { + if rawEP == ep { + eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) + return } } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index dfd31557a..0c2589083 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 74d9ff253..9aa6f3978 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/raw", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 182097b46..8f2e3aa20 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -59,10 +59,6 @@ type endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue - // raw indicates whether the endpoint is intended for use by a raw - // socket, which returns the network layer header along with the - // payload. It is immutable. - raw bool // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -80,32 +76,26 @@ type endpoint struct { shutdownFlags tcpip.ShutdownFlags id stack.TransportEndpointID state endpointState - bindNICID tcpip.NICID - bindAddr tcpip.Address - regNICID tcpip.NICID - route stack.Route `state:"manual"` + // bindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + bindNICID tcpip.NICID + bindAddr tcpip.Address + // regNICID is the default NIC to be used when callers don't specify a + // NIC. + regNICID tcpip.NICID + route stack.Route `state:"manual"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) { - e := &endpoint{ +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return &endpoint{ stack: stack, netProto: netProto, transProto: transProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, - raw: raw, - } - - // Raw endpoints must be immediately bound because they receive all - // ICMP traffic starting from when they're created via socket(). - if raw { - if err := e.bindLocked(tcpip.FullAddress{}); err != nil { - return nil, err - } - } - - return e, nil + }, nil } // Close puts the endpoint in a closed state and frees all resources @@ -115,11 +105,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - if e.raw { - e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e) - } else { - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) - } + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) } // Close the receive list and drain it. @@ -244,8 +230,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c route = &e.route if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, given that it may - // need to change in Route.Resolve() call below. + // Promote lock to exclusive if using a shared route, + // given that it may need to change in Route.Resolve() + // call below. e.mu.RUnlock() defer e.mu.RLock() @@ -290,8 +277,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c waker := &sleep.Waker{} if ch, err := route.Resolve(waker); err != nil { if err == tcpip.ErrWouldBlock { - // Link address needs to be resolved. Resolution was triggered the - // background. Better luck next time. + // Link address needs to be resolved. + // Resolution was triggered the background. + // Better luck next time. route.RemoveWaker(waker) return 0, ch, tcpip.ErrNoLinkAddress } @@ -368,11 +356,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { - if e.raw { - hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) - } - if len(data) < header.ICMPv4EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -439,11 +422,6 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - // TODO: We don't yet support connect on a raw socket. - if e.raw { - return tcpip.ErrNotSupported - } - e.mu.Lock() defer e.mu.Unlock() @@ -547,11 +525,6 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.raw { - err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false) - return stack.TransportEndpointID{}, err - } - if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -687,11 +660,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().DroppedPackets.Increment() e.rcvMu.Unlock() return } @@ -706,13 +680,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, ne }, } - if e.raw { - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - pkt.data = combinedVV.Clone(pkt.views[:]) - } else { - pkt.data = vv.Clone(pkt.views[:]) - } + pkt.data = vv.Clone(pkt.views[:]) e.rcvList.PushBack(pkt) e.rcvBufSize += pkt.data.Size() diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 36b70988a..09ee2f892 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -30,6 +30,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -73,7 +74,7 @@ func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtoco if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue, false) + return newEndpoint(stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements @@ -82,7 +83,7 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue, true) + return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD new file mode 100644 index 000000000..005079639 --- /dev/null +++ b/pkg/tcpip/transport/raw/BUILD @@ -0,0 +1,45 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "raw", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "raw", + srcs = [ + "packet_list.go", + "raw.go", + "state.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw", + imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +filegroup( + name = "autogen", + srcs = [ + "packet_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/transport/raw/raw.go b/pkg/tcpip/transport/raw/raw.go new file mode 100644 index 000000000..8dada2e4f --- /dev/null +++ b/pkg/tcpip/transport/raw/raw.go @@ -0,0 +1,558 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package raw provides the implementation of raw sockets (see raw(7)). Raw +// sockets allow applications to: +// +// * manually write and inspect transport layer headers and payloads +// * receive all traffic of a given transport protcol (e.g. ICMP or UDP) +// * optionally write and inspect network layer and link layer headers for +// packets +// +// Raw sockets don't have any notion of ports, and incoming packets are +// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will +// receive every UDP packet received by netstack. bind(2) and connect(2) can be +// used to filter incoming packets by source and destination. +package raw + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // views is pre-allocated space to back data. As long as the packet is + // made up of fewer than 8 buffer.Views, no extra allocation is + // necessary to store packet data. + views [8]buffer.View `state:"nosave"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to +// have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + closed bool + connected bool + bound bool + // registeredNIC is the NIC to which th endpoint is explicitly + // registered. Is set when Connect or Bind are used to specify a NIC. + registeredNIC tcpip.NICID + // boundNIC and boundAddr are set on calls to Bind(). When callers + // attempt actions that would invalidate the binding data (e.g. sending + // data via a NIC other than boundNIC), the endpoint will return an + // error. + boundNIC tcpip.NICID + boundAddr tcpip.Address + // route is the route to a remote network endpoint. It is set via + // Connect(), and is valid only when conneted is true. + route stack.Route `state:"manual"` +} + +// NewEndpoint returns a raw endpoint for the given protocols. +// TODO: IP_HDRINCL, IPPROTO_RAW, and AF_PACKET. +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != header.IPv4ProtocolNumber { + return nil, tcpip.ErrUnknownProtocol + } + + ep := &endpoint{ + stack: stack, + netProto: netProto, + transProto: transProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + return nil, err + } + + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + if ep.connected { + ep.route.Release() + } + + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +// Write implements tcpip.Endpoint.Write. +func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + ep.mu.RLock() + + if ep.closed { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + // Check whether we've shutdown writing. + if ep.shutdownFlags&tcpip.ShutdownWrite != 0 { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrClosedForSend + } + + // Did the user caller provide a destination? If not, use the connected + // destination. + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if !ep.connected { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrNotConnected + } + + if ep.route.IsResolutionRequired() { + savedRoute := &ep.route + // Promote lock to exclusive if using a shared route, + // given that it may need to change in finishWrite. + ep.mu.RUnlock() + ep.mu.Lock() + + // Make sure that the route didn't change during the + // time we didn't hold the lock. + if !ep.connected || savedRoute != &ep.route { + ep.mu.Unlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + n, ch, err := ep.finishWrite(payload, savedRoute) + ep.mu.Unlock() + return n, ch, err + } + + n, ch, err := ep.finishWrite(payload, &ep.route) + ep.mu.RUnlock() + return n, ch, err + } + + // The caller provided a destination. Reject destination address if it + // goes through a different NIC than the endpoint was bound to. + nic := opts.To.NIC + if ep.bound && nic != 0 && nic != ep.boundNIC { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrNoRoute + } + + // We don't support IPv6 yet, so this has to be an IPv4 address. + if len(opts.To.Addr) != header.IPv4AddressSize { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + // Find the route to the destination. If boundAddress is 0, + // FindRoute will choose an appropriate source address. + route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + if err != nil { + ep.mu.RUnlock() + return 0, nil, err + } + + n, ch, err := ep.finishWrite(payload, &route) + route.Release() + ep.mu.RUnlock() + return n, ch, err +} + +// finishWrite writes the payload to a route. It resolves the route if +// necessary. It's really just a helper to make defer unnecessary in Write. +func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { + // We may need to resolve the route (match a link layer address to the + // network address). If that requires blocking (e.g. to use ARP), + // return a channel on which the caller can wait. + if route.IsResolutionRequired() { + waker := &sleep.Waker{} + if ch, err := route.Resolve(waker); err != nil { + if err == tcpip.ErrWouldBlock { + // Link address needs to be resolved. + // Resolution was triggered the background. + // Better luck next time. + route.RemoveWaker(waker) + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + payloadBytes, err := payload.Get(payload.Size()) + if err != nil { + return 0, nil, err + } + + switch ep.netProto { + case header.IPv4ProtocolNumber: + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), header.ICMPv4ProtocolNumber, route.DefaultTTL()); err != nil { + return 0, nil, err + } + + default: + return 0, nil, tcpip.ErrUnknownProtocol + } + + return uintptr(len(payloadBytes)), nil, nil +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Connect implements tcpip.Endpoint.Connect. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return tcpip.ErrInvalidEndpointState + } + + // We don't support IPv6 yet. + if len(addr.Addr) != header.IPv4AddressSize { + return tcpip.ErrInvalidEndpointState + } + + nic := addr.NIC + if ep.bound { + if ep.boundNIC == 0 { + // If we're bound, but not to a specific NIC, the NIC + // in addr will be used. Nothing to do here. + } else if addr.NIC == 0 { + // If we're bound to a specific NIC, but addr doesn't + // specify a NIC, use the bound NIC. + nic = ep.boundNIC + } else if addr.NIC != ep.boundNIC { + // We're bound and addr specifies a NIC. They must be + // the same. + return tcpip.ErrInvalidEndpointState + } + } + + // Find a route to the destination. + route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + if err != nil { + return err + } + defer route.Release() + + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + // Save the route and NIC we've connected via. + ep.route = route.Clone() + ep.registeredNIC = nic + ep.connected = true + + return nil +} + +// Shutdown implements tcpip.Endpoint.Shutdown. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + if !ep.connected { + return tcpip.ErrNotConnected + } + + ep.shutdownFlags |= flags + + if flags&tcpip.ShutdownRead != 0 { + ep.rcvMu.Lock() + wasClosed := ep.rcvClosed + ep.rcvClosed = true + ep.rcvMu.Unlock() + + if !wasClosed { + ep.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen implements tcpip.Endpoint.Listen. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + // Callers must provide an IPv4 address or no network address (for + // binding to a NIC, but not an address). + if len(addr.Addr) != 0 && len(addr.Addr) != 4 { + return tcpip.ErrInvalidEndpointState + } + + // If a local address was specified, verify that it's valid. + if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + ep.registeredNIC = addr.NIC + ep.boundNIC = addr.NIC + ep.boundAddr = addr.Addr + ep.bound = true + + return nil +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if !ep.connected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: ep.registeredNIC, + Addr: ep.route.RemoteAddress, + }, nil +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + ep.mu.Lock() + *o = tcpip.SendBufferSizeOption(ep.sndBufSize) + ep.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) + ep.rcvMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + ep.rcvMu.Lock() + if ep.rcvList.Empty() { + *o = 0 + } else { + p := ep.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + ep.rcvMu.Unlock() + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// HandlePacket implements stack.RawTransportEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.stack.Stats().DroppedPackets.Increment() + ep.rcvMu.Unlock() + return + } + + if ep.bound { + // If bound to a NIC, only accept data for that NIC. + if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { + ep.rcvMu.Unlock() + return + } + // If bound to an address, only accept data for that address. + if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { + ep.rcvMu.Unlock() + return + } + } + + // If connected, only accept packets from the remote address we + // connected to. + if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { + ep.rcvMu.Unlock() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + packet := &packet{ + senderAddr: tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + }, + } + + combinedVV := netHeader.ToVectorisedView() + combinedVV.Append(vv) + packet.data = combinedVV.Clone(packet.views[:]) + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} diff --git a/pkg/tcpip/transport/raw/state.go b/pkg/tcpip/transport/raw/state.go new file mode 100644 index 000000000..e3891a8b8 --- /dev/null +++ b/pkg/tcpip/transport/raw/state.go @@ -0,0 +1,88 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // If the endpoint is connected, re-connect via the save/restore stack. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(*err) + } + } + + // If the endpoint is bound, re-bind via the save/restore stack. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0427af34f..41c87cc7e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1438,7 +1438,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5637f46e3..19e532180 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -940,7 +940,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { // Get the header then trim it from the view. hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc index 19cded07f..b13806dcb 100644 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ b/test/syscalls/linux/raw_socket_ipv4.cc @@ -16,8 +16,11 @@ #include <netinet/in.h> #include <netinet/ip.h> #include <netinet/ip_icmp.h> +#include <sys/poll.h> #include <sys/socket.h> #include <sys/types.h> +#include <unistd.h> +#include <algorithm> #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" @@ -39,22 +42,26 @@ class RawSocketTest : public ::testing::Test { // Closes the socket created by SetUp(). void TearDown() override; - // The socket used for both reading and writing. - int s_; + // Checks that both an ICMP echo request and reply are received. Calls should + // be wrapped in ASSERT_NO_FATAL_FAILURE. + void ExpectICMPSuccess(const struct icmphdr& icmp); - // The loopback address. - struct sockaddr_in addr_; + void SendEmptyICMP(const struct icmphdr& icmp); + + void SendEmptyICMPTo(int sock, struct sockaddr_in* addr, + const struct icmphdr& icmp); - void SendEmptyICMP(struct icmphdr *icmp); + void ReceiveICMP(char* recv_buf, size_t recv_buf_len, size_t expected_size, + struct sockaddr_in* src); - void SendEmptyICMPTo(int sock, struct sockaddr_in *addr, - struct icmphdr *icmp); + void ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len, + size_t expected_size, struct sockaddr_in* src, int sock); - void ReceiveICMP(char *recv_buf, size_t recv_buf_len, size_t expected_size, - struct sockaddr_in *src); + // The socket used for both reading and writing. + int s_; - void ReceiveICMPFrom(char *recv_buf, size_t recv_buf_len, - size_t expected_size, struct sockaddr_in *src, int sock); + // The loopback address. + struct sockaddr_in addr_; }; void RawSocketTest::SetUp() { @@ -100,49 +107,9 @@ TEST_F(RawSocketTest, SendAndReceive) { icmp.checksum = 2011; icmp.un.echo.sequence = 2012; icmp.un.echo.id = 2014; - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(&icmp)); - - // We're going to receive both the echo request and reply, but the order is - // indeterminate. - char recv_buf[512]; - struct sockaddr_in src; - bool received_request = false; - bool received_reply = false; - - for (int i = 0; i < 2; i++) { - // Receive the packet. - ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf), - sizeof(struct icmphdr), &src)); - EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); - struct icmphdr *recvd_icmp = - reinterpret_cast<struct icmphdr *>(recv_buf + sizeof(struct iphdr)); - switch (recvd_icmp->type) { - case ICMP_ECHO: - EXPECT_FALSE(received_request); - received_request = true; - // The packet should be identical to what we sent. - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), - 0); - break; - - case ICMP_ECHOREPLY: - EXPECT_FALSE(received_reply); - received_reply = true; - // Most fields should be the same. - EXPECT_EQ(recvd_icmp->code, icmp.code); - EXPECT_EQ(recvd_icmp->un.echo.sequence, icmp.un.echo.sequence); - EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id); - // A couple are different. - EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY); - // The checksum is computed in such a way that it is guaranteed to have - // changed. - EXPECT_NE(recvd_icmp->checksum, icmp.checksum); - break; - } - } + ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); - ASSERT_TRUE(received_request); - ASSERT_TRUE(received_reply); + ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); } // We should be able to create multiple raw sockets for the same protocol and @@ -162,7 +129,7 @@ TEST_F(RawSocketTest, MultipleSocketReceive) { icmp.checksum = 2014; icmp.un.echo.sequence = 2016; icmp.un.echo.id = 2018; - ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(&icmp)); + ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); // Both sockets will receive the echo request and reply in indeterminate // order, so we'll need to read 2 packets from each. @@ -191,12 +158,14 @@ TEST_F(RawSocketTest, MultipleSocketReceive) { int types[] = {ICMP_ECHO, ICMP_ECHOREPLY}; for (int type : types) { auto match_type = [=](char buf[kBufSize]) { - struct icmphdr *icmp = - reinterpret_cast<struct icmphdr *>(buf + sizeof(struct iphdr)); + struct icmphdr* icmp = + reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr)); return icmp->type == type; }; - char *icmp1 = *std::find_if(recv_buf1.begin(), recv_buf1.end(), match_type); - char *icmp2 = *std::find_if(recv_buf2.begin(), recv_buf2.end(), match_type); + const char* icmp1 = + *std::find_if(recv_buf1.begin(), recv_buf1.end(), match_type); + const char* icmp2 = + *std::find_if(recv_buf2.begin(), recv_buf2.end(), match_type); ASSERT_NE(icmp1, *recv_buf1.end()); ASSERT_NE(icmp2, *recv_buf2.end()); EXPECT_EQ(memcmp(icmp1 + sizeof(struct iphdr), icmp2 + sizeof(struct iphdr), @@ -217,10 +186,10 @@ TEST_F(RawSocketTest, RawAndPingSockets) { struct icmphdr icmp; icmp.type = ICMP_ECHO; icmp.code = 0; - icmp.un.echo.sequence = - *static_cast<unsigned short *>(&icmp.un.echo.sequence); + icmp.un.echo.sequence = *static_cast<unsigned short*>(&icmp.un.echo.sequence); ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, sizeof(icmp), 0, - (struct sockaddr *)&addr_, sizeof(addr_)), + reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), SyscallSucceedsWithValue(sizeof(icmp))); // Both sockets will receive the echo request and reply in indeterminate @@ -247,12 +216,12 @@ TEST_F(RawSocketTest, RawAndPingSockets) { int types[] = {ICMP_ECHO, ICMP_ECHOREPLY}; for (int type : types) { auto match_type_ping = [=](char buf[kBufSize]) { - struct icmphdr *icmp = reinterpret_cast<struct icmphdr *>(buf); + struct icmphdr* icmp = reinterpret_cast<struct icmphdr*>(buf); return icmp->type == type; }; auto match_type_raw = [=](char buf[kBufSize]) { - struct icmphdr *icmp = - reinterpret_cast<struct icmphdr *>(buf + sizeof(struct iphdr)); + struct icmphdr* icmp = + reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr)); return icmp->type == type; }; @@ -266,39 +235,317 @@ TEST_F(RawSocketTest, RawAndPingSockets) { } } -void RawSocketTest::SendEmptyICMP(struct icmphdr *icmp) { +// Test that shutting down an unconnected socket fails. +TEST_F(RawSocketTest, FailShutdownWithoutConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); +} + +// Test that writing to a shutdown write socket fails. +TEST_F(RawSocketTest, FailWritingToShutdown) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); + + char c; + ASSERT_THAT(RetryEINTR(write)(s_, &c, sizeof(c)), + SyscallFailsWithErrno(EPIPE)); +} + +// Test that reading from a shutdown read socket gets nothing. +TEST_F(RawSocketTest, FailReadingFromShutdown) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + char c; + ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(0)); +} + +// Test that listen() fails. +TEST_F(RawSocketTest, FailListen) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP)); +} + +// Test that accept() fails. +TEST_F(RawSocketTest, FailAccept) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr saddr; + socklen_t addrlen; + ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP)); +} + +// Test that getpeername() returns nothing before connect(). +TEST_F(RawSocketTest, FailGetPeerNameBeforeConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr saddr; + socklen_t addrlen; + ASSERT_THAT(getpeername(s_, &saddr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +// Test that getpeername() returns something after connect(). +TEST_F(RawSocketTest, GetPeerName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + struct sockaddr saddr; + socklen_t addrlen; + ASSERT_THAT(getpeername(s_, &saddr, &addrlen), SyscallSucceeds()); + ASSERT_GT(addrlen, 0); +} + +// Test that the socket is writable immediately. +TEST_F(RawSocketTest, PollWritableImmediately) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct pollfd pfd = {}; + pfd.fd = s_; + pfd.events = POLLOUT; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); +} + +// Test that the socket isn't readable before receiving anything. +TEST_F(RawSocketTest, PollNotReadableInitially) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Try to receive data with MSG_DONTWAIT, which returns immediately if there's + // nothing to be read. + char buf[117]; + ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Test that the socket becomes readable once something is written to it. +TEST_F(RawSocketTest, PollTriggeredOnWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Write something so that there's data to be read. + struct icmphdr icmp = {}; + ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); + + struct pollfd pfd = {}; + pfd.fd = s_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); +} + +// Test that we can connect() to a valid IP (loopback). +TEST_F(RawSocketTest, ConnectToLoopback) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); +} + +// Test that connect() sends packets to the right place. +TEST_F(RawSocketTest, SendAndReceiveViaConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, + // and ID. None of that should matter for raw sockets - the kernel should + // still give us the packet. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.checksum = 2001; + icmp.un.echo.sequence = 2003; + icmp.un.echo.id = 2004; + ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0), + SyscallSucceedsWithValue(sizeof(icmp))); + + ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); +} + +// Test that calling send() without connect() fails. +TEST_F(RawSocketTest, SendWithoutConnectFails) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, + // and ID. None of that should matter for raw sockets - the kernel should + // still give us the packet. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.checksum = 2015; + icmp.un.echo.sequence = 2017; + icmp.un.echo.id = 2019; + ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0), + SyscallFailsWithErrno(ENOTCONN)); +} + +// Bind to localhost. +TEST_F(RawSocketTest, BindToLocalhost) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); +} + +// Bind to a different address. +TEST_F(RawSocketTest, BindToInvalid) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +// Bind to localhost, then send and receive packets. +TEST_F(RawSocketTest, BindSendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, + // and ID. None of that should matter for raw sockets - the kernel should + // still give us the packet. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.checksum = 2001; + icmp.un.echo.sequence = 2004; + icmp.un.echo.id = 2007; + ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); + + ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); +} + +// Bind and connect to localhost and send/receive packets. +TEST_F(RawSocketTest, BindConnectSendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence, + // and ID. None of that should matter for raw sockets - the kernel should + // still give us the packet. + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.checksum = 2009; + icmp.un.echo.sequence = 2010; + icmp.un.echo.id = 7; + ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); + + ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); +} + +void RawSocketTest::ExpectICMPSuccess(const struct icmphdr& icmp) { + // We're going to receive both the echo request and reply, but the order is + // indeterminate. + char recv_buf[512]; + struct sockaddr_in src; + bool received_request = false; + bool received_reply = false; + + for (int i = 0; i < 2; i++) { + // Receive the packet. + ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf), + sizeof(struct icmphdr), &src)); + EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0); + struct icmphdr* recvd_icmp = + reinterpret_cast<struct icmphdr*>(recv_buf + sizeof(struct iphdr)); + switch (recvd_icmp->type) { + case ICMP_ECHO: + EXPECT_FALSE(received_request); + received_request = true; + // The packet should be identical to what we sent. + EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)), + 0); + break; + + case ICMP_ECHOREPLY: + EXPECT_FALSE(received_reply); + received_reply = true; + // Most fields should be the same. + EXPECT_EQ(recvd_icmp->code, icmp.code); + EXPECT_EQ(recvd_icmp->un.echo.sequence, icmp.un.echo.sequence); + EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id); + // A couple are different. + EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY); + // The checksum is computed in such a way that it is guaranteed to have + // changed. + EXPECT_NE(recvd_icmp->checksum, icmp.checksum); + break; + } + } + + ASSERT_TRUE(received_request); + ASSERT_TRUE(received_reply); +} + +void RawSocketTest::SendEmptyICMP(const struct icmphdr& icmp) { ASSERT_NO_FATAL_FAILURE(SendEmptyICMPTo(s_, &addr_, icmp)); } -void RawSocketTest::SendEmptyICMPTo(int sock, struct sockaddr_in *addr, - struct icmphdr *icmp) { - struct iovec iov = {.iov_base = icmp, .iov_len = sizeof(*icmp)}; - struct msghdr msg { - .msg_name = addr, .msg_namelen = sizeof(*addr), .msg_iov = &iov, - .msg_iovlen = 1, .msg_control = NULL, .msg_controllen = 0, .msg_flags = 0, - }; - ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(*icmp))); +void RawSocketTest::SendEmptyICMPTo(int sock, struct sockaddr_in* addr, + const struct icmphdr& icmp) { + // It's safe to use const_cast here because sendmsg won't modify the iovec. + struct iovec iov = {}; + iov.iov_base = static_cast<void*>(const_cast<struct icmphdr*>(&icmp)); + iov.iov_len = sizeof(icmp); + struct msghdr msg = {}; + msg.msg_name = addr; + msg.msg_namelen = sizeof(*addr); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(icmp))); } -void RawSocketTest::ReceiveICMP(char *recv_buf, size_t recv_buf_len, - size_t expected_size, struct sockaddr_in *src) { +void RawSocketTest::ReceiveICMP(char* recv_buf, size_t recv_buf_len, + size_t expected_size, struct sockaddr_in* src) { ASSERT_NO_FATAL_FAILURE( ReceiveICMPFrom(recv_buf, recv_buf_len, expected_size, src, s_)); } -void RawSocketTest::ReceiveICMPFrom(char *recv_buf, size_t recv_buf_len, +void RawSocketTest::ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len, size_t expected_size, - struct sockaddr_in *src, int sock) { - struct iovec iov = {.iov_base = recv_buf, .iov_len = recv_buf_len}; - struct msghdr msg = { - .msg_name = src, - .msg_namelen = sizeof(*src), - .msg_iov = &iov, - .msg_iovlen = 1, - .msg_control = NULL, - .msg_controllen = 0, - .msg_flags = 0, - }; + struct sockaddr_in* src, int sock) { + struct iovec iov = {}; + iov.iov_base = recv_buf; + iov.iov_len = recv_buf_len; + struct msghdr msg = {}; + msg.msg_name = src; + msg.msg_namelen = sizeof(*src); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; // We should receive the ICMP packet plus 20 bytes of IP header. ASSERT_THAT(recvmsg(sock, &msg, 0), SyscallSucceedsWithValue(expected_size + sizeof(struct iphdr))); |