diff options
author | gVisor bot <gvisor-bot@google.com> | 2019-10-09 17:54:51 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-10-09 17:56:05 -0700 |
commit | bf870c1a423063eb86a62c6268fe5d83fb6b87ba (patch) | |
tree | f08f7db5122ad778647fcc7f564f7e5cab657376 /pkg/tcpip/transport/udp | |
parent | 7a2d5b2fa7c398f7710a134b5790265bf620fced (diff) |
Internal change.
PiperOrigin-RevId: 273861936
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 183 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 175 |
4 files changed, 295 insertions, 81 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 8ae83437e..2ad7978c2 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -49,6 +49,22 @@ const ( StateClosed ) +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnected: + return "CONNECTING" + case StateClosed: + return "CLOSED" + default: + return "UNKNOWN" + } +} + // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -58,10 +74,11 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue // The following fields are used to manage the receive queue, and are @@ -76,10 +93,7 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int - id stack.TransportEndpointID state EndpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID route stack.Route `state:"manual"` dstPort uint16 v6only bool @@ -106,6 +120,9 @@ type endpoint struct { // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). effectiveNetProtos []tcpip.NetworkProtocolNumber + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` } // +stateify savable @@ -114,10 +131,13 @@ type multicastMembership struct { multicastAddr tcpip.Address } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership // requests. @@ -135,6 +155,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: StateInitial, } } @@ -146,12 +167,12 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } e.multicastMemberships = nil @@ -191,6 +212,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -253,7 +275,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // configured multicast interface if no interface is specified and the // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { - localAddr := e.id.LocalAddress + localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -279,6 +301,29 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -330,12 +375,12 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + if e.BindNICID != 0 { + if nicid != 0 && nicid != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicid = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -384,7 +429,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl, useDefaultTTL); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -405,7 +450,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -458,7 +503,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if e.bindNICID != 0 && e.bindNICID != nic { + if e.BindNICID != 0 && e.BindNICID != nic { return tcpip.ErrInvalidEndpointState } @@ -485,7 +530,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -502,7 +547,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -523,7 +568,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -545,7 +590,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrBadLocalAddress } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -624,7 +669,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -733,14 +778,18 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u udp.SetChecksum(^udp.CalculateChecksum(xsum)) } + if err := r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL); err != nil { + r.Stats().UDP.PacketSendErrors.Increment() + return err + } + // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - - return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL) + return nil } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if len(addr.Addr) == 0 { return netProto, nil } @@ -757,14 +806,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t } // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.id.LocalAddress) == 16 { + if !allowMismatch && len(e.ID.LocalAddress) == 16 { return 0, tcpip.ErrNetworkUnreachable } } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -781,27 +830,27 @@ func (e *endpoint) Disconnect() *tcpip.Error { } id := stack.TransportEndpointID{} // Exclude ephemerally bound endpoints. - if e.bindNICID != 0 || e.id.LocalAddress == "" { + if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.id.LocalPort, - LocalAddress: e.id.LocalAddress, + LocalPort: e.ID.LocalPort, + LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } e.state = StateBound } else { - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) - e.id = id + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.ID = id e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -828,16 +877,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { switch e.state { case StateInitial: case StateBound, StateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicid != 0 && nicid != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicid = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -849,7 +898,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: e.id.LocalAddress, + LocalAddress: e.ID.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, @@ -876,14 +925,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Remove the old registration. - if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + if e.ID.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) } - e.id = id + e.ID = id e.route = r.Clone() e.dstPort = addr.Port - e.regNICID = nicid + e.RegisterNICID = nicid e.effectiveNetProtos = netProtos e.state = StateConnected @@ -939,7 +988,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.id.LocalPort == 0 { + if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err @@ -995,8 +1044,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = nicid + e.ID = id + e.RegisterNICID = nicid e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -1021,7 +1070,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // Save the effective NICID generated by bindLocked. - e.bindNICID = e.regNICID + e.BindNICID = e.RegisterNICID return nil } @@ -1032,9 +1081,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -1048,9 +1097,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -1080,6 +1129,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if int(hdr.Length()) > vv.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } @@ -1087,11 +1137,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() + e.stats.PacketsReceived.Increment() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } @@ -1130,6 +1189,20 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index be46e6d4e..b227e353b 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -72,7 +72,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } } @@ -93,13 +93,13 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } @@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) { // Our saved state had a port, but we don't actually have a // reservation. We need to remove the port from our state, but still // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id := e.ID + e.ID.LocalPort = 0 + e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 2d0bc5221..d399ec722 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -79,10 +79,10 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, return nil, err } - ep.id = r.id + ep.ID = r.id ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort - ep.regNICID = r.route.NICID() + ep.RegisterNICID = r.route.NICID() ep.state = StateConnected diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index faa728b68..4ada73475 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -384,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h) + c.injectV4Packet(payload, &h, true /* valid */) } else { - c.injectV6Packet(payload, &h) + c.injectV6Packet(payload, &h, true /* valid */) } } // injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -409,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) + l := uint16(header.UDPMinimumSize + len(payload)) + if !valid { + // Change the UDP payload length to corrupt the header + // as requested by the caller. + l++ + } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), + Length: l, }) // Calculate the UDP pseudo-header checksum. @@ -426,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -// injectV6Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -536,7 +546,7 @@ func TestBindToDeviceOption(t *testing.T) { // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its // correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) { +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -547,6 +557,9 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) c.wq.EventRegister(&we, waiter.EventIn) defer c.wq.EventUnregister(&we) + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + var addr tcpip.FullAddress v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { @@ -563,6 +576,11 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) } } + if expectReadError && err != nil { + c.checkEndpointReadStats(1, epstats, err) + return + } + if err != nil { c.t.Fatal("Read failed:", err) } @@ -581,6 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it @@ -588,15 +607,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) // its correctness. func testRead(c *testContext, flow testFlow) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) } // testFailingRead sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then tries to read it from the UDP // endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow) { +func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */) + testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) } func TestBindEphemeralPort(t *testing.T) { @@ -771,8 +790,8 @@ func TestReadOnBoundToMulticast(t *testing.T) { // Check that we receive multicast packets but not unicast or broadcast // ones. testRead(c, flow) - testFailingRead(c, broadcast) - testFailingRead(c, unicastV4) + testFailingRead(c, broadcast, false /* expectReadError */) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } @@ -795,7 +814,7 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { // Check that we receive broadcast packets but not unicast ones. testRead(c, flow) - testFailingRead(c, unicastV4) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } @@ -826,7 +845,8 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // and verifies it fails with the provided error code. func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { c.t.Helper() - + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) @@ -834,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) + c.checkEndpointWriteStats(1, epstats, gotErr) if gotErr != wantErr { c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } @@ -859,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() writeOpts := tcpip.WriteOptions{} if setDest { @@ -876,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - + c.checkEndpointWriteStats(1, epstats, err) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -945,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Write to V4 mapped address. testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + const want = 1 + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { + c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -1453,3 +1480,117 @@ func TestV6UnknownDestination(t *testing.T) { }) } } + +// TestIncrementMalformedPacketsReceived verifies if the malformed received +// global and endpoint stats get incremented. +func TestIncrementMalformedPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + payload := newPayload() + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + c.injectV6Packet(payload, &h, false /* !valid */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } +} + +// TestShutdownRead verifies endpoint read shutdown and error +// stats increment on packet receive. +func TestShutdownRead(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingRead(c, unicastV6, true /* expectReadError */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { + t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) + } +} + +// TestShutdownWrite verifies endpoint write shutdown and error +// stats increment on packet write. +func TestShutdownWrite(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) +} + +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil: + want.PacketsSent.IncrementBy(incr) + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + want.WriteErrors.InvalidArgs.IncrementBy(incr) + case tcpip.ErrClosedForSend: + want.WriteErrors.WriteClosed.IncrementBy(incr) + case tcpip.ErrInvalidEndpointState: + want.WriteErrors.InvalidEndpointState.IncrementBy(incr) + case tcpip.ErrNoLinkAddress: + want.SendErrors.NoLinkAddr.IncrementBy(incr) + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + want.SendErrors.NoRoute.IncrementBy(incr) + default: + want.SendErrors.SendToNetworkFailed.IncrementBy(incr) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} + +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil, tcpip.ErrWouldBlock: + case tcpip.ErrClosedForReceive: + want.ReadErrors.ReadClosed.IncrementBy(incr) + default: + c.t.Errorf("Endpoint error missing stats update err %v", err) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} |