summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2019-10-09 17:54:51 -0700
committergVisor bot <gvisor-bot@google.com>2019-10-09 17:56:05 -0700
commitbf870c1a423063eb86a62c6268fe5d83fb6b87ba (patch)
treef08f7db5122ad778647fcc7f564f7e5cab657376 /pkg/tcpip/transport/udp
parent7a2d5b2fa7c398f7710a134b5790265bf620fced (diff)
Internal change.
PiperOrigin-RevId: 273861936
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go183
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go14
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go175
4 files changed, 295 insertions, 81 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 8ae83437e..2ad7978c2 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -49,6 +49,22 @@ const (
StateClosed
)
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnected:
+ return "CONNECTING"
+ case StateClosed:
+ return "CLOSED"
+ default:
+ return "UNKNOWN"
+ }
+}
+
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -58,10 +74,11 @@ const (
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
+
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
@@ -76,10 +93,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
- id stack.TransportEndpointID
state EndpointState
- bindNICID tcpip.NICID
- regNICID tcpip.NICID
route stack.Route `state:"manual"`
dstPort uint16
v6only bool
@@ -106,6 +120,9 @@ type endpoint struct {
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
}
// +stateify savable
@@ -114,10 +131,13 @@ type multicastMembership struct {
multicastAddr tcpip.Address
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
- stack: stack,
- netProto: netProto,
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
// requests.
@@ -135,6 +155,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ state: StateInitial,
}
}
@@ -146,12 +167,12 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
e.multicastMemberships = nil
@@ -191,6 +212,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
@@ -253,7 +275,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
- localAddr := e.id.LocalAddress
+ localAddr := e.ID.LocalAddress
if isBroadcastOrMulticast(localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
@@ -279,6 +301,29 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -330,12 +375,12 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
- if e.bindNICID != 0 {
- if nicid != 0 && nicid != e.bindNICID {
+ if e.BindNICID != 0 {
+ if nicid != 0 && nicid != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
}
if to.Addr == header.IPv4Broadcast && !e.broadcast {
@@ -384,7 +429,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl, useDefaultTTL); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -405,7 +450,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
@@ -458,7 +503,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
- if e.bindNICID != 0 && e.bindNICID != nic {
+ if e.BindNICID != 0 && e.BindNICID != nic {
return tcpip.ErrInvalidEndpointState
}
@@ -485,7 +530,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
} else {
- nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrUnknownDevice
@@ -502,7 +547,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
- if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -523,7 +568,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
} else {
- nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrUnknownDevice
@@ -545,7 +590,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -624,7 +669,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
@@ -733,14 +778,18 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
+ if err := r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL); err != nil {
+ r.Stats().UDP.PacketSendErrors.Increment()
+ return err
+ }
+
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
-
- return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL)
+ return nil
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if len(addr.Addr) == 0 {
return netProto, nil
}
@@ -757,14 +806,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
}
// Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ if !allowMismatch && len(e.ID.LocalAddress) == 16 {
return 0, tcpip.ErrNetworkUnreachable
}
}
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -781,27 +830,27 @@ func (e *endpoint) Disconnect() *tcpip.Error {
}
id := stack.TransportEndpointID{}
// Exclude ephemerally bound endpoints.
- if e.bindNICID != 0 || e.id.LocalAddress == "" {
+ if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.id.LocalPort,
- LocalAddress: e.id.LocalAddress,
+ LocalPort: e.ID.LocalPort,
+ LocalAddress: e.ID.LocalAddress,
}
- id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
e.state = StateBound
} else {
- if e.id.LocalPort != 0 {
+ if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
- e.id = id
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.ID = id
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@@ -828,16 +877,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
switch e.state {
case StateInitial:
case StateBound, StateConnected:
- localPort = e.id.LocalPort
- if e.bindNICID == 0 {
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
break
}
- if nicid != 0 && nicid != e.bindNICID {
+ if nicid != 0 && nicid != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -849,7 +898,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
id := stack.TransportEndpointID{
- LocalAddress: e.id.LocalAddress,
+ LocalAddress: e.ID.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
@@ -876,14 +925,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Remove the old registration.
- if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
+ if e.ID.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
}
- e.id = id
+ e.ID = id
e.route = r.Clone()
e.dstPort = addr.Port
- e.regNICID = nicid
+ e.RegisterNICID = nicid
e.effectiveNetProtos = netProtos
e.state = StateConnected
@@ -939,7 +988,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
- if e.id.LocalPort == 0 {
+ if e.ID.LocalPort == 0 {
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
@@ -995,8 +1044,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
- e.regNICID = nicid
+ e.ID = id
+ e.RegisterNICID = nicid
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
@@ -1021,7 +1070,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the effective NICID generated by bindLocked.
- e.bindNICID = e.regNICID
+ e.BindNICID = e.RegisterNICID
return nil
}
@@ -1032,9 +1081,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
}, nil
}
@@ -1048,9 +1097,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
}, nil
}
@@ -1080,6 +1129,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if int(hdr.Length()) > vv.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
@@ -1087,11 +1137,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
+ e.stats.PacketsReceived.Increment()
// Drop the packet if our buffer is currently full.
- if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
}
@@ -1130,6 +1189,20 @@ func (e *endpoint) State() uint32 {
return uint32(e.state)
}
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index be46e6d4e..b227e353b 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -72,7 +72,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
}
@@ -93,13 +93,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
var err *tcpip.Error
if e.state == StateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(err)
}
- } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
@@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
// Our saved state had a port, but we don't actually have a
// reservation. We need to remove the port from our state, but still
// pass it to the reservation machinery.
- id := e.id
- e.id.LocalPort = 0
- e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ id := e.ID
+ e.ID.LocalPort = 0
+ e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 2d0bc5221..d399ec722 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -79,10 +79,10 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
return nil, err
}
- ep.id = r.id
+ ep.ID = r.id
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
- ep.regNICID = r.route.NICID()
+ ep.RegisterNICID = r.route.NICID()
ep.state = StateConnected
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index faa728b68..4ada73475 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -384,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h)
+ c.injectV4Packet(payload, &h, true /* valid */)
} else {
- c.injectV6Packet(payload, &h)
+ c.injectV6Packet(payload, &h, true /* valid */)
}
}
// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -409,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
+ l := uint16(header.UDPMinimumSize + len(payload))
+ if !valid {
+ // Change the UDP payload length to corrupt the header
+ // as requested by the caller.
+ l++
+ }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
+ Length: l,
})
// Calculate the UDP pseudo-header checksum.
@@ -426,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-// injectV6Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
+// injectV4Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -536,7 +546,7 @@ func TestBindToDeviceOption(t *testing.T) {
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) {
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
c.t.Helper()
payload := newPayload()
@@ -547,6 +557,9 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
c.wq.EventRegister(&we, waiter.EventIn)
defer c.wq.EventUnregister(&we)
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
var addr tcpip.FullAddress
v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
@@ -563,6 +576,11 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
}
}
+ if expectReadError && err != nil {
+ c.checkEndpointReadStats(1, epstats, err)
+ return
+ }
+
if err != nil {
c.t.Fatal("Read failed:", err)
}
@@ -581,6 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+ c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
@@ -588,15 +607,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
// its correctness.
func testRead(c *testContext, flow testFlow) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
}
// testFailingRead sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then tries to read it from the UDP
// endpoint and expects this to fail.
-func testFailingRead(c *testContext, flow testFlow) {
+func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
c.t.Helper()
- testReadInternal(c, flow, true /* packetShouldBeDropped */)
+ testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
}
func TestBindEphemeralPort(t *testing.T) {
@@ -771,8 +790,8 @@ func TestReadOnBoundToMulticast(t *testing.T) {
// Check that we receive multicast packets but not unicast or broadcast
// ones.
testRead(c, flow)
- testFailingRead(c, broadcast)
- testFailingRead(c, unicastV4)
+ testFailingRead(c, broadcast, false /* expectReadError */)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
})
}
}
@@ -795,7 +814,7 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
// Check that we receive broadcast packets but not unicast ones.
testRead(c, flow)
- testFailingRead(c, unicastV4)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
})
}
}
@@ -826,7 +845,8 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
// and verifies it fails with the provided error code.
func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
c.t.Helper()
-
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
@@ -834,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
_, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
+ c.checkEndpointWriteStats(1, epstats, gotErr)
if gotErr != wantErr {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
@@ -859,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec
func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
writeOpts := tcpip.WriteOptions{}
if setDest {
@@ -876,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
-
+ c.checkEndpointWriteStats(1, epstats, err)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -945,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Write to V4 mapped address.
testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ const want = 1
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
+ c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
+ }
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
@@ -1453,3 +1480,117 @@ func TestV6UnknownDestination(t *testing.T) {
})
}
}
+
+// TestIncrementMalformedPacketsReceived verifies if the malformed received
+// global and endpoint stats get incremented.
+func TestIncrementMalformedPacketsReceived(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ payload := newPayload()
+ c.t.Helper()
+ h := unicastV6.header4Tuple(incoming)
+ c.injectV6Packet(payload, &h, false /* !valid */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownRead verifies endpoint read shutdown and error
+// stats increment on packet receive.
+func TestShutdownRead(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ testFailingRead(c, unicastV6, true /* expectReadError */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownWrite verifies endpoint write shutdown and error
+// stats increment on packet write.
+func TestShutdownWrite(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+}
+
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil:
+ want.PacketsSent.IncrementBy(incr)
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ want.WriteErrors.InvalidArgs.IncrementBy(incr)
+ case tcpip.ErrClosedForSend:
+ want.WriteErrors.WriteClosed.IncrementBy(incr)
+ case tcpip.ErrInvalidEndpointState:
+ want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
+ case tcpip.ErrNoLinkAddress:
+ want.SendErrors.NoLinkAddr.IncrementBy(incr)
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ want.SendErrors.NoRoute.IncrementBy(incr)
+ default:
+ want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
+
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil, tcpip.ErrWouldBlock:
+ case tcpip.ErrClosedForReceive:
+ want.ReadErrors.ReadClosed.IncrementBy(incr)
+ default:
+ c.t.Errorf("Endpoint error missing stats update err %v", err)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}