summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/registration.go13
-rw-r--r--pkg/tcpip/stack/stack.go20
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go67
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go76
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go5
-rw-r--r--pkg/tcpip/transport/raw/BUILD45
-rw-r--r--pkg/tcpip/transport/raw/raw.go558
-rw-r--r--pkg/tcpip/transport/raw/state.go88
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--test/syscalls/linux/raw_socket_ipv4.cc421
13 files changed, 1106 insertions, 194 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ff356ea22..f3cc849ec 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -64,13 +64,24 @@ const (
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView)
+ HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
+// RawTransportEndpoint is the interface that needs to be implemented by raw
+// transport protocol endpoints. RawTransportEndpoints receive the entire
+// packet - including the link, network, and transport headers - as delivered
+// to netstack.
+type RawTransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint. The packet contains all data from the link
+ // layer up.
+ HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+}
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 15a268b10..a74c0a7a0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -955,11 +955,11 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
-// transport dispatcher. Received packets that match the provided protocol will
-// be delivered to the given endpoint.
-func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+// transport dispatcher. Received packets that match the provided transport
+// protocol will be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
if nicID == 0 {
- return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
s.mu.RLock()
@@ -970,14 +970,14 @@ func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpi
return tcpip.ErrUnknownNICID
}
- return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ return nic.demux.registerRawEndpoint(netProto, transProto, ep)
}
-// UnregisterRawTransportEndpoint removes the endpoint for the protocol from
-// the stack transport dispatcher.
-func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+// UnregisterRawTransportEndpoint removes the endpoint for the transport
+// protocol from the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
return
}
@@ -986,7 +986,7 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tc
nic := s.nics[nicID]
if nic != nil {
- nic.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9ab314188..a8ac18e72 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"math/rand"
"sync"
@@ -37,7 +38,7 @@ type transportEndpoints struct {
endpoints map[TransportEndpointID]TransportEndpoint
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
- rawEndpoints []TransportEndpoint
+ rawEndpoints []RawTransportEndpoint
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -60,8 +61,10 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
-// based on endpoints IDs.
+// based on endpoints IDs. It should only be instantiated via
+// newTransportDemuxer.
type transportDemuxer struct {
+ // protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
}
@@ -137,22 +140,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
// If this is a broadcast datagram, deliver the datagram to all endpoints
// managed by ep.
if id.LocalAddress == header.IPv4Broadcast {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, netHeader, vv)
+ endpoint.HandlePacket(r, id, vv)
break
}
vvCopy := buffer.NewView(vv.Size())
copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
} else {
- ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv)
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
}
}
@@ -286,17 +289,17 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
// raw endpoint first. If there are multipe raw endpoints, they all
// receive the packet.
- found := false
+ foundRaw := false
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
- found = true
+ rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ foundRaw = true
}
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 && !found {
+ if len(destEps) == 0 && !foundRaw {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
@@ -306,7 +309,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, netHeader, vv)
+ ep.HandlePacket(r, id, vv)
}
return true
@@ -371,19 +374,8 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
-func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- for i, n := range netProtos {
- if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil {
- d.unregisterRawEndpoint(netProtos[:i], protocol, ep)
- return err
- }
- }
-
- return nil
-}
-
-func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error {
- eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
}
@@ -395,19 +387,20 @@ func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProto
return nil
}
-// unregisterRawEndpoint unregisters the raw endpoint for the given protocol
-// such that it won't receive any more packets.
-func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
- for _, n := range netProtos {
- if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.mu.Lock()
- defer eps.mu.Unlock()
- for i, rawEP := range eps.rawEndpoints {
- if rawEP == ep {
- eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
- return
- }
- }
+// unregisterRawEndpoint unregisters the raw endpoint for the given transport
+// protocol such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
}
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index dfd31557a..0c2589083 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 74d9ff253..9aa6f3978 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/raw",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 182097b46..8f2e3aa20 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -59,10 +59,6 @@ type endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
- // raw indicates whether the endpoint is intended for use by a raw
- // socket, which returns the network layer header along with the
- // payload. It is immutable.
- raw bool
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -80,32 +76,26 @@ type endpoint struct {
shutdownFlags tcpip.ShutdownFlags
id stack.TransportEndpointID
state endpointState
- bindNICID tcpip.NICID
- bindAddr tcpip.Address
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ // bindNICID and bindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ // regNICID is the default NIC to be used when callers don't specify a
+ // NIC.
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) {
- e := &endpoint{
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return &endpoint{
stack: stack,
netProto: netProto,
transProto: transProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
- raw: raw,
- }
-
- // Raw endpoints must be immediately bound because they receive all
- // ICMP traffic starting from when they're created via socket().
- if raw {
- if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
- return nil, err
- }
- }
-
- return e, nil
+ }, nil
}
// Close puts the endpoint in a closed state and frees all resources
@@ -115,11 +105,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- if e.raw {
- e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e)
- } else {
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
- }
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
}
// Close the receive list and drain it.
@@ -244,8 +230,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
route = &e.route
if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route, given that it may
- // need to change in Route.Resolve() call below.
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in Route.Resolve()
+ // call below.
e.mu.RUnlock()
defer e.mu.RLock()
@@ -290,8 +277,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
waker := &sleep.Waker{}
if ch, err := route.Resolve(waker); err != nil {
if err == tcpip.ErrWouldBlock {
- // Link address needs to be resolved. Resolution was triggered the
- // background. Better luck next time.
+ // Link address needs to be resolved.
+ // Resolution was triggered the background.
+ // Better luck next time.
route.RemoveWaker(waker)
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -368,11 +356,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
- if e.raw {
- hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
- }
-
if len(data) < header.ICMPv4EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
@@ -439,11 +422,6 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- // TODO: We don't yet support connect on a raw socket.
- if e.raw {
- return tcpip.ErrNotSupported
- }
-
e.mu.Lock()
defer e.mu.Unlock()
@@ -547,11 +525,6 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
- if e.raw {
- err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false)
- return stack.TransportEndpointID{}, err
- }
-
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
@@ -687,11 +660,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().DroppedPackets.Increment()
e.rcvMu.Unlock()
return
}
@@ -706,13 +680,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, ne
},
}
- if e.raw {
- combinedVV := netHeader.ToVectorisedView()
- combinedVV.Append(vv)
- pkt.data = combinedVV.Clone(pkt.views[:])
- } else {
- pkt.data = vv.Clone(pkt.views[:])
- }
+ pkt.data = vv.Clone(pkt.views[:])
e.rcvList.PushBack(pkt)
e.rcvBufSize += pkt.data.Size()
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 36b70988a..09ee2f892 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -30,6 +30,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -73,7 +74,7 @@ func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtoco
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue, false)
+ return newEndpoint(stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
@@ -82,7 +83,7 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue, true)
+ return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
new file mode 100644
index 000000000..005079639
--- /dev/null
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -0,0 +1,45 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_template_instance(
+ name = "packet_list",
+ out = "packet_list.go",
+ package = "raw",
+ prefix = "packet",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*packet",
+ "Linker": "*packet",
+ },
+)
+
+go_library(
+ name = "raw",
+ srcs = [
+ "packet_list.go",
+ "raw.go",
+ "state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw",
+ imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "packet_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/transport/raw/raw.go b/pkg/tcpip/transport/raw/raw.go
new file mode 100644
index 000000000..8dada2e4f
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw.go
@@ -0,0 +1,558 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package raw provides the implementation of raw sockets (see raw(7)). Raw
+// sockets allow applications to:
+//
+// * manually write and inspect transport layer headers and payloads
+// * receive all traffic of a given transport protcol (e.g. ICMP or UDP)
+// * optionally write and inspect network layer and link layer headers for
+// packets
+//
+// Raw sockets don't have any notion of ports, and incoming packets are
+// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
+// receive every UDP packet received by netstack. bind(2) and connect(2) can be
+// used to filter incoming packets by source and destination.
+package raw
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
+// have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ closed bool
+ connected bool
+ bound bool
+ // registeredNIC is the NIC to which th endpoint is explicitly
+ // registered. Is set when Connect or Bind are used to specify a NIC.
+ registeredNIC tcpip.NICID
+ // boundNIC and boundAddr are set on calls to Bind(). When callers
+ // attempt actions that would invalidate the binding data (e.g. sending
+ // data via a NIC other than boundNIC), the endpoint will return an
+ // error.
+ boundNIC tcpip.NICID
+ boundAddr tcpip.Address
+ // route is the route to a remote network endpoint. It is set via
+ // Connect(), and is valid only when conneted is true.
+ route stack.Route `state:"manual"`
+}
+
+// NewEndpoint returns a raw endpoint for the given protocols.
+// TODO: IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != header.IPv4ProtocolNumber {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ ep := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ return nil, err
+ }
+
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ if ep.connected {
+ ep.route.Release()
+ }
+
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+// Write implements tcpip.Endpoint.Write.
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ ep.mu.RLock()
+
+ if ep.closed {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Check whether we've shutdown writing.
+ if ep.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Did the user caller provide a destination? If not, use the connected
+ // destination.
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if !ep.connected {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrNotConnected
+ }
+
+ if ep.route.IsResolutionRequired() {
+ savedRoute := &ep.route
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in finishWrite.
+ ep.mu.RUnlock()
+ ep.mu.Lock()
+
+ // Make sure that the route didn't change during the
+ // time we didn't hold the lock.
+ if !ep.connected || savedRoute != &ep.route {
+ ep.mu.Unlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ n, ch, err := ep.finishWrite(payload, savedRoute)
+ ep.mu.Unlock()
+ return n, ch, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &ep.route)
+ ep.mu.RUnlock()
+ return n, ch, err
+ }
+
+ // The caller provided a destination. Reject destination address if it
+ // goes through a different NIC than the endpoint was bound to.
+ nic := opts.To.NIC
+ if ep.bound && nic != 0 && nic != ep.boundNIC {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ // We don't support IPv6 yet, so this has to be an IPv4 address.
+ if len(opts.To.Addr) != header.IPv4AddressSize {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Find the route to the destination. If boundAddress is 0,
+ // FindRoute will choose an appropriate source address.
+ route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false)
+ if err != nil {
+ ep.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &route)
+ route.Release()
+ ep.mu.RUnlock()
+ return n, ch, err
+}
+
+// finishWrite writes the payload to a route. It resolves the route if
+// necessary. It's really just a helper to make defer unnecessary in Write.
+func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // We may need to resolve the route (match a link layer address to the
+ // network address). If that requires blocking (e.g. to use ARP),
+ // return a channel on which the caller can wait.
+ if route.IsResolutionRequired() {
+ waker := &sleep.Waker{}
+ if ch, err := route.Resolve(waker); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // Link address needs to be resolved.
+ // Resolution was triggered the background.
+ // Better luck next time.
+ route.RemoveWaker(waker)
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ payloadBytes, err := payload.Get(payload.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch ep.netProto {
+ case header.IPv4ProtocolNumber:
+ hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
+ if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), header.ICMPv4ProtocolNumber, route.DefaultTTL()); err != nil {
+ return 0, nil, err
+ }
+
+ default:
+ return 0, nil, tcpip.ErrUnknownProtocol
+ }
+
+ return uintptr(len(payloadBytes)), nil, nil
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We don't support IPv6 yet.
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nic := addr.NIC
+ if ep.bound {
+ if ep.boundNIC == 0 {
+ // If we're bound, but not to a specific NIC, the NIC
+ // in addr will be used. Nothing to do here.
+ } else if addr.NIC == 0 {
+ // If we're bound to a specific NIC, but addr doesn't
+ // specify a NIC, use the bound NIC.
+ nic = ep.boundNIC
+ } else if addr.NIC != ep.boundNIC {
+ // We're bound and addr specifies a NIC. They must be
+ // the same.
+ return tcpip.ErrInvalidEndpointState
+ }
+ }
+
+ // Find a route to the destination.
+ route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ // Save the route and NIC we've connected via.
+ ep.route = route.Clone()
+ ep.registeredNIC = nic
+ ep.connected = true
+
+ return nil
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if !ep.connected {
+ return tcpip.ErrNotConnected
+ }
+
+ ep.shutdownFlags |= flags
+
+ if flags&tcpip.ShutdownRead != 0 {
+ ep.rcvMu.Lock()
+ wasClosed := ep.rcvClosed
+ ep.rcvClosed = true
+ ep.rcvMu.Unlock()
+
+ if !wasClosed {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // Callers must provide an IPv4 address or no network address (for
+ // binding to a NIC, but not an address).
+ if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If a local address was specified, verify that it's valid.
+ if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.registeredNIC = addr.NIC
+ ep.boundNIC = addr.NIC
+ ep.boundAddr = addr.Addr
+ ep.bound = true
+
+ return nil
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ if !ep.connected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: ep.registeredNIC,
+ Addr: ep.route.RemoteAddress,
+ }, nil
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
+ ep.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ ep.rcvMu.Lock()
+ if ep.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := ep.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ if ep.bound {
+ // If bound to a NIC, only accept data for that NIC.
+ if ep.boundNIC != 0 && ep.boundNIC != route.NICID() {
+ ep.rcvMu.Unlock()
+ return
+ }
+ // If bound to an address, only accept data for that address.
+ if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+ }
+
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if ep.connected && ep.route.RemoteAddress != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &packet{
+ senderAddr: tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ },
+ }
+
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ packet.data = combinedVV.Clone(packet.views[:])
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/state.go b/pkg/tcpip/transport/raw/state.go
new file mode 100644
index 000000000..e3891a8b8
--- /dev/null
+++ b/pkg/tcpip/transport/raw/state.go
@@ -0,0 +1,88 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // If the endpoint is connected, re-connect via the save/restore stack.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ if err != nil {
+ panic(*err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind via the save/restore stack.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 0427af34f..41c87cc7e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1438,7 +1438,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
s := newSegment(r, id, vv)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 5637f46e3..19e532180 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -940,7 +940,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
index 19cded07f..b13806dcb 100644
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ b/test/syscalls/linux/raw_socket_ipv4.cc
@@ -16,8 +16,11 @@
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>
+#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
+#include <unistd.h>
+#include <algorithm>
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -39,22 +42,26 @@ class RawSocketTest : public ::testing::Test {
// Closes the socket created by SetUp().
void TearDown() override;
- // The socket used for both reading and writing.
- int s_;
+ // Checks that both an ICMP echo request and reply are received. Calls should
+ // be wrapped in ASSERT_NO_FATAL_FAILURE.
+ void ExpectICMPSuccess(const struct icmphdr& icmp);
- // The loopback address.
- struct sockaddr_in addr_;
+ void SendEmptyICMP(const struct icmphdr& icmp);
+
+ void SendEmptyICMPTo(int sock, struct sockaddr_in* addr,
+ const struct icmphdr& icmp);
- void SendEmptyICMP(struct icmphdr *icmp);
+ void ReceiveICMP(char* recv_buf, size_t recv_buf_len, size_t expected_size,
+ struct sockaddr_in* src);
- void SendEmptyICMPTo(int sock, struct sockaddr_in *addr,
- struct icmphdr *icmp);
+ void ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len,
+ size_t expected_size, struct sockaddr_in* src, int sock);
- void ReceiveICMP(char *recv_buf, size_t recv_buf_len, size_t expected_size,
- struct sockaddr_in *src);
+ // The socket used for both reading and writing.
+ int s_;
- void ReceiveICMPFrom(char *recv_buf, size_t recv_buf_len,
- size_t expected_size, struct sockaddr_in *src, int sock);
+ // The loopback address.
+ struct sockaddr_in addr_;
};
void RawSocketTest::SetUp() {
@@ -100,49 +107,9 @@ TEST_F(RawSocketTest, SendAndReceive) {
icmp.checksum = 2011;
icmp.un.echo.sequence = 2012;
icmp.un.echo.id = 2014;
- ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(&icmp));
-
- // We're going to receive both the echo request and reply, but the order is
- // indeterminate.
- char recv_buf[512];
- struct sockaddr_in src;
- bool received_request = false;
- bool received_reply = false;
-
- for (int i = 0; i < 2; i++) {
- // Receive the packet.
- ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf),
- sizeof(struct icmphdr), &src));
- EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
- struct icmphdr *recvd_icmp =
- reinterpret_cast<struct icmphdr *>(recv_buf + sizeof(struct iphdr));
- switch (recvd_icmp->type) {
- case ICMP_ECHO:
- EXPECT_FALSE(received_request);
- received_request = true;
- // The packet should be identical to what we sent.
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)),
- 0);
- break;
-
- case ICMP_ECHOREPLY:
- EXPECT_FALSE(received_reply);
- received_reply = true;
- // Most fields should be the same.
- EXPECT_EQ(recvd_icmp->code, icmp.code);
- EXPECT_EQ(recvd_icmp->un.echo.sequence, icmp.un.echo.sequence);
- EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id);
- // A couple are different.
- EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY);
- // The checksum is computed in such a way that it is guaranteed to have
- // changed.
- EXPECT_NE(recvd_icmp->checksum, icmp.checksum);
- break;
- }
- }
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
- ASSERT_TRUE(received_request);
- ASSERT_TRUE(received_reply);
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
}
// We should be able to create multiple raw sockets for the same protocol and
@@ -162,7 +129,7 @@ TEST_F(RawSocketTest, MultipleSocketReceive) {
icmp.checksum = 2014;
icmp.un.echo.sequence = 2016;
icmp.un.echo.id = 2018;
- ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(&icmp));
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
// Both sockets will receive the echo request and reply in indeterminate
// order, so we'll need to read 2 packets from each.
@@ -191,12 +158,14 @@ TEST_F(RawSocketTest, MultipleSocketReceive) {
int types[] = {ICMP_ECHO, ICMP_ECHOREPLY};
for (int type : types) {
auto match_type = [=](char buf[kBufSize]) {
- struct icmphdr *icmp =
- reinterpret_cast<struct icmphdr *>(buf + sizeof(struct iphdr));
+ struct icmphdr* icmp =
+ reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr));
return icmp->type == type;
};
- char *icmp1 = *std::find_if(recv_buf1.begin(), recv_buf1.end(), match_type);
- char *icmp2 = *std::find_if(recv_buf2.begin(), recv_buf2.end(), match_type);
+ const char* icmp1 =
+ *std::find_if(recv_buf1.begin(), recv_buf1.end(), match_type);
+ const char* icmp2 =
+ *std::find_if(recv_buf2.begin(), recv_buf2.end(), match_type);
ASSERT_NE(icmp1, *recv_buf1.end());
ASSERT_NE(icmp2, *recv_buf2.end());
EXPECT_EQ(memcmp(icmp1 + sizeof(struct iphdr), icmp2 + sizeof(struct iphdr),
@@ -217,10 +186,10 @@ TEST_F(RawSocketTest, RawAndPingSockets) {
struct icmphdr icmp;
icmp.type = ICMP_ECHO;
icmp.code = 0;
- icmp.un.echo.sequence =
- *static_cast<unsigned short *>(&icmp.un.echo.sequence);
+ icmp.un.echo.sequence = *static_cast<unsigned short*>(&icmp.un.echo.sequence);
ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, sizeof(icmp), 0,
- (struct sockaddr *)&addr_, sizeof(addr_)),
+ reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_)),
SyscallSucceedsWithValue(sizeof(icmp)));
// Both sockets will receive the echo request and reply in indeterminate
@@ -247,12 +216,12 @@ TEST_F(RawSocketTest, RawAndPingSockets) {
int types[] = {ICMP_ECHO, ICMP_ECHOREPLY};
for (int type : types) {
auto match_type_ping = [=](char buf[kBufSize]) {
- struct icmphdr *icmp = reinterpret_cast<struct icmphdr *>(buf);
+ struct icmphdr* icmp = reinterpret_cast<struct icmphdr*>(buf);
return icmp->type == type;
};
auto match_type_raw = [=](char buf[kBufSize]) {
- struct icmphdr *icmp =
- reinterpret_cast<struct icmphdr *>(buf + sizeof(struct iphdr));
+ struct icmphdr* icmp =
+ reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr));
return icmp->type == type;
};
@@ -266,39 +235,317 @@ TEST_F(RawSocketTest, RawAndPingSockets) {
}
}
-void RawSocketTest::SendEmptyICMP(struct icmphdr *icmp) {
+// Test that shutting down an unconnected socket fails.
+TEST_F(RawSocketTest, FailShutdownWithoutConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Test that writing to a shutdown write socket fails.
+TEST_F(RawSocketTest, FailWritingToShutdown) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+ ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds());
+
+ char c;
+ ASSERT_THAT(RetryEINTR(write)(s_, &c, sizeof(c)),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+// Test that reading from a shutdown read socket gets nothing.
+TEST_F(RawSocketTest, FailReadingFromShutdown) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+ ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
+
+ char c;
+ ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(0));
+}
+
+// Test that listen() fails.
+TEST_F(RawSocketTest, FailListen) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that accept() fails.
+TEST_F(RawSocketTest, FailAccept) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr saddr;
+ socklen_t addrlen;
+ ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that getpeername() returns nothing before connect().
+TEST_F(RawSocketTest, FailGetPeerNameBeforeConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr saddr;
+ socklen_t addrlen;
+ ASSERT_THAT(getpeername(s_, &saddr, &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Test that getpeername() returns something after connect().
+TEST_F(RawSocketTest, GetPeerName) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+ struct sockaddr saddr;
+ socklen_t addrlen;
+ ASSERT_THAT(getpeername(s_, &saddr, &addrlen), SyscallSucceeds());
+ ASSERT_GT(addrlen, 0);
+}
+
+// Test that the socket is writable immediately.
+TEST_F(RawSocketTest, PollWritableImmediately) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLOUT;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
+}
+
+// Test that the socket isn't readable before receiving anything.
+TEST_F(RawSocketTest, PollNotReadableInitially) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Try to receive data with MSG_DONTWAIT, which returns immediately if there's
+ // nothing to be read.
+ char buf[117];
+ ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Test that the socket becomes readable once something is written to it.
+TEST_F(RawSocketTest, PollTriggeredOnWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Write something so that there's data to be read.
+ struct icmphdr icmp = {};
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
+}
+
+// Test that we can connect() to a valid IP (loopback).
+TEST_F(RawSocketTest, ConnectToLoopback) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+}
+
+// Test that connect() sends packets to the right place.
+TEST_F(RawSocketTest, SendAndReceiveViaConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 2001;
+ icmp.un.echo.sequence = 2003;
+ icmp.un.echo.id = 2004;
+ ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0),
+ SyscallSucceedsWithValue(sizeof(icmp)));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+// Test that calling send() without connect() fails.
+TEST_F(RawSocketTest, SendWithoutConnectFails) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 2015;
+ icmp.un.echo.sequence = 2017;
+ icmp.un.echo.id = 2019;
+ ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Bind to localhost.
+TEST_F(RawSocketTest, BindToLocalhost) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+}
+
+// Bind to a different address.
+TEST_F(RawSocketTest, BindToInvalid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr_in bind_addr = {};
+ bind_addr.sin_family = AF_INET;
+ bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
+ ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+// Bind to localhost, then send and receive packets.
+TEST_F(RawSocketTest, BindSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 2001;
+ icmp.un.echo.sequence = 2004;
+ icmp.un.echo.id = 2007;
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+// Bind and connect to localhost and send/receive packets.
+TEST_F(RawSocketTest, BindConnectSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Prepare and send an ICMP packet. Use arbitrary junk for checksum, sequence,
+ // and ID. None of that should matter for raw sockets - the kernel should
+ // still give us the packet.
+ struct icmphdr icmp;
+ icmp.type = ICMP_ECHO;
+ icmp.code = 0;
+ icmp.checksum = 2009;
+ icmp.un.echo.sequence = 2010;
+ icmp.un.echo.id = 7;
+ ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp));
+
+ ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
+}
+
+void RawSocketTest::ExpectICMPSuccess(const struct icmphdr& icmp) {
+ // We're going to receive both the echo request and reply, but the order is
+ // indeterminate.
+ char recv_buf[512];
+ struct sockaddr_in src;
+ bool received_request = false;
+ bool received_reply = false;
+
+ for (int i = 0; i < 2; i++) {
+ // Receive the packet.
+ ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf, ABSL_ARRAYSIZE(recv_buf),
+ sizeof(struct icmphdr), &src));
+ EXPECT_EQ(memcmp(&src, &addr_, sizeof(sockaddr_in)), 0);
+ struct icmphdr* recvd_icmp =
+ reinterpret_cast<struct icmphdr*>(recv_buf + sizeof(struct iphdr));
+ switch (recvd_icmp->type) {
+ case ICMP_ECHO:
+ EXPECT_FALSE(received_request);
+ received_request = true;
+ // The packet should be identical to what we sent.
+ EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), &icmp, sizeof(icmp)),
+ 0);
+ break;
+
+ case ICMP_ECHOREPLY:
+ EXPECT_FALSE(received_reply);
+ received_reply = true;
+ // Most fields should be the same.
+ EXPECT_EQ(recvd_icmp->code, icmp.code);
+ EXPECT_EQ(recvd_icmp->un.echo.sequence, icmp.un.echo.sequence);
+ EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id);
+ // A couple are different.
+ EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY);
+ // The checksum is computed in such a way that it is guaranteed to have
+ // changed.
+ EXPECT_NE(recvd_icmp->checksum, icmp.checksum);
+ break;
+ }
+ }
+
+ ASSERT_TRUE(received_request);
+ ASSERT_TRUE(received_reply);
+}
+
+void RawSocketTest::SendEmptyICMP(const struct icmphdr& icmp) {
ASSERT_NO_FATAL_FAILURE(SendEmptyICMPTo(s_, &addr_, icmp));
}
-void RawSocketTest::SendEmptyICMPTo(int sock, struct sockaddr_in *addr,
- struct icmphdr *icmp) {
- struct iovec iov = {.iov_base = icmp, .iov_len = sizeof(*icmp)};
- struct msghdr msg {
- .msg_name = addr, .msg_namelen = sizeof(*addr), .msg_iov = &iov,
- .msg_iovlen = 1, .msg_control = NULL, .msg_controllen = 0, .msg_flags = 0,
- };
- ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(*icmp)));
+void RawSocketTest::SendEmptyICMPTo(int sock, struct sockaddr_in* addr,
+ const struct icmphdr& icmp) {
+ // It's safe to use const_cast here because sendmsg won't modify the iovec.
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<struct icmphdr*>(&icmp));
+ iov.iov_len = sizeof(icmp);
+ struct msghdr msg = {};
+ msg.msg_name = addr;
+ msg.msg_namelen = sizeof(*addr);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(sizeof(icmp)));
}
-void RawSocketTest::ReceiveICMP(char *recv_buf, size_t recv_buf_len,
- size_t expected_size, struct sockaddr_in *src) {
+void RawSocketTest::ReceiveICMP(char* recv_buf, size_t recv_buf_len,
+ size_t expected_size, struct sockaddr_in* src) {
ASSERT_NO_FATAL_FAILURE(
ReceiveICMPFrom(recv_buf, recv_buf_len, expected_size, src, s_));
}
-void RawSocketTest::ReceiveICMPFrom(char *recv_buf, size_t recv_buf_len,
+void RawSocketTest::ReceiveICMPFrom(char* recv_buf, size_t recv_buf_len,
size_t expected_size,
- struct sockaddr_in *src, int sock) {
- struct iovec iov = {.iov_base = recv_buf, .iov_len = recv_buf_len};
- struct msghdr msg = {
- .msg_name = src,
- .msg_namelen = sizeof(*src),
- .msg_iov = &iov,
- .msg_iovlen = 1,
- .msg_control = NULL,
- .msg_controllen = 0,
- .msg_flags = 0,
- };
+ struct sockaddr_in* src, int sock) {
+ struct iovec iov = {};
+ iov.iov_base = recv_buf;
+ iov.iov_len = recv_buf_len;
+ struct msghdr msg = {};
+ msg.msg_name = src;
+ msg.msg_namelen = sizeof(*src);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
// We should receive the ICMP packet plus 20 bytes of IP header.
ASSERT_THAT(recvmsg(sock, &msg, 0),
SyscallSucceedsWithValue(expected_size + sizeof(struct iphdr)));