From bf870c1a423063eb86a62c6268fe5d83fb6b87ba Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 9 Oct 2019 17:54:51 -0700 Subject: Internal change. PiperOrigin-RevId: 273861936 --- pkg/tcpip/stack/stack.go | 28 ++ pkg/tcpip/stack/transport_test.go | 27 +- pkg/tcpip/tcpip.go | 128 +++++++- pkg/tcpip/transport/icmp/endpoint.go | 145 ++++++--- pkg/tcpip/transport/icmp/endpoint_state.go | 10 +- pkg/tcpip/transport/raw/endpoint.go | 346 ++++++++++++--------- pkg/tcpip/transport/raw/endpoint_state.go | 6 +- pkg/tcpip/transport/tcp/accept.go | 18 +- pkg/tcpip/transport/tcp/connect.go | 44 ++- pkg/tcpip/transport/tcp/endpoint.go | 273 +++++++++++----- pkg/tcpip/transport/tcp/endpoint_state.go | 28 +- pkg/tcpip/transport/tcp/snd.go | 3 + pkg/tcpip/transport/tcp/tcp_noracedetector_test.go | 8 + pkg/tcpip/transport/tcp/tcp_sack_test.go | 8 + pkg/tcpip/transport/tcp/tcp_test.go | 55 +++- pkg/tcpip/transport/udp/endpoint.go | 183 +++++++---- pkg/tcpip/transport/udp/endpoint_state.go | 14 +- pkg/tcpip/transport/udp/forwarder.go | 4 +- pkg/tcpip/transport/udp/udp_test.go | 175 ++++++++++- 19 files changed, 1093 insertions(+), 410 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ff574a055..7d73389cc 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -424,6 +424,34 @@ type Options struct { UnassociatedFactory UnassociatedEndpointFactory } +// TransportEndpointInfo holds useful information about a transport endpoint +// which can be queried by monitoring tools. +// +// +stateify savable +type TransportEndpointInfo struct { + // The following fields are initialized at creation time and are + // immutable. + + NetProto tcpip.NetworkProtocolNumber + TransProto tcpip.TransportProtocolNumber + + // The following fields are protected by endpoint mu. + + ID TransportEndpointID + // BindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + BindNICID tcpip.NICID + BindAddr tcpip.Address + // RegisterNICID is the default NICID registered as a side-effect of + // connect or datagram write. + RegisterNICID tcpip.NICID +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*TransportEndpointInfo) IsEndpointInfo() {} + // New allocates a new networking stack with only the requested networking and // transport protocols configured with default options. // diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index db290c404..63811c684 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -38,9 +38,8 @@ const ( // Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't // use it. type fakeTransportEndpoint struct { - id stack.TransportEndpointID + stack.TransportEndpointInfo stack *stack.Stack - netProto tcpip.NetworkProtocolNumber proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -49,8 +48,16 @@ type fakeTransportEndpoint struct { acceptQueue []fakeTransportEndpoint } -func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} +func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { + return &f.TransportEndpointInfo +} + +func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { + return nil +} + +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} } func (f *fakeTransportEndpoint) Close() { @@ -126,8 +133,8 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() // Try to register so that we can start receiving packets. - f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */) + f.ID.RemoteAddress = addr.Addr + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -190,9 +197,11 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - id: id, - stack: f.stack, - netProto: f.netProto, + stack: f.stack, + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, proto: f.proto, peerAddr: r.RemoteAddress, route: r.Clone(), diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a35d6562e..60ba98a4c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -429,6 +429,26 @@ type Endpoint interface { // IPTables returns the iptables for this endpoint's stack. IPTables() (iptables.IPTables, error) + + // Info returns a copy to the transport endpoint info. + Info() EndpointInfo + + // Stats returns a reference to the endpoint stats. + Stats() EndpointStats +} + +// EndpointInfo is the interface implemented by each endpoint info struct. +type EndpointInfo interface { + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo + // marker interface. + IsEndpointInfo() +} + +// EndpointStats is the interface implemented by each endpoint stats struct. +type EndpointStats interface { + // IsEndpointStats is an empty method to implement the tcpip.EndpointStats + // marker interface. + IsEndpointStats() } // WriteOptions contains options for Endpoint.Write. @@ -879,6 +899,9 @@ type TCPStats struct { // SegmentsSent is the number of TCP segments sent. SegmentsSent *StatCounter + // SegmentSendErrors is the number of TCP segments failed to be sent. + SegmentSendErrors *StatCounter + // ResetsSent is the number of TCP resets sent. ResetsSent *StatCounter @@ -931,6 +954,9 @@ type UDPStats struct { // PacketsSent is the number of UDP datagrams sent via sendUDP. PacketsSent *StatCounter + + // PacketSendErrors is the number of datagrams failed to be sent. + PacketSendErrors *StatCounter } // Stats holds statistics about the networking stack. @@ -941,7 +967,7 @@ type Stats struct { // stack that were for an unknown or unsupported protocol. UnknownProtocolRcvdPackets *StatCounter - // MalformedRcvPackets is the number of packets received by the stack + // MalformedRcvdPackets is the number of packets received by the stack // that were deemed malformed. MalformedRcvdPackets *StatCounter @@ -961,6 +987,86 @@ type Stats struct { UDP UDPStats } +// ReceiveErrors collects packet receive errors within transport endpoint. +type ReceiveErrors struct { + // ReceiveBufferOverflow is the number of received packets dropped + // due to the receive buffer being full. + ReceiveBufferOverflow StatCounter + + // MalformedPacketsReceived is the number of incoming packets + // dropped due to the packet header being in a malformed state. + MalformedPacketsReceived StatCounter + + // ClosedReceiver is the number of received packets dropped because + // of receiving endpoint state being closed. + ClosedReceiver StatCounter +} + +// SendErrors collects packet send errors within the transport layer for +// an endpoint. +type SendErrors struct { + // SendToNetworkFailed is the number of packets failed to be written to + // the network endpoint. + SendToNetworkFailed StatCounter + + // NoRoute is the number of times we failed to resolve IP route. + NoRoute StatCounter + + // NoLinkAddr is the number of times we failed to resolve ARP. + NoLinkAddr StatCounter +} + +// ReadErrors collects segment read errors from an endpoint read call. +type ReadErrors struct { + // ReadClosed is the number of received packet drops because the endpoint + // was shutdown for read. + ReadClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter +} + +// WriteErrors collects packet write errors from an endpoint write call. +type WriteErrors struct { + // WriteClosed is the number of packet drops because the endpoint + // was shutdown for write. + WriteClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter + + // InvalidArgs is the number of times invalid input arguments were + // provided for endpoint Write call. + InvalidArgs StatCounter +} + +// TransportEndpointStats collects statistics about the endpoint. +type TransportEndpointStats struct { + // PacketsReceived is the number of successful packet receives. + PacketsReceived StatCounter + + // PacketsSent is the number of successful packet sends. + PacketsSent StatCounter + + // ReceiveErrors collects packet receive errors within transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects packet read errors from an endpoint read call. + ReadErrors ReadErrors + + // SendErrors collects packet send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects packet write errors from an endpoint write call. + WriteErrors WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*TransportEndpointStats) IsEndpointStats() {} + func fillIn(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) @@ -983,6 +1089,26 @@ func (s Stats) FillIn() Stats { return s } +// Clone returns a copy of the TransportEndpointStats by atomically reading +// each field. +func (src *TransportEndpointStats) Clone() TransportEndpointStats { + var dst TransportEndpointStats + clone(reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src).Elem()) + return dst +} + +func clone(dst reflect.Value, src reflect.Value) { + for i := 0; i < dst.NumField(); i++ { + d := dst.Field(i) + s := src.Field(i) + if c, ok := s.Addr().Interface().(*StatCounter); ok { + d.Addr().Interface().(*StatCounter).IncrementBy(c.Value()) + } else { + clone(d, s) + } + } +} + // String implements the fmt.Stringer interface. func (a Address) String() string { switch len(a) { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e17e737e4..d0cfdcda1 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -52,11 +52,11 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue // The following fields are used to manage the receive queue, and are @@ -73,28 +73,23 @@ type endpoint struct { sndBufSize int // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags - id stack.TransportEndpointID state endpointState - // bindNICID and bindAddr are set via calls to Bind(). They are used to - // reject attempts to send data or connect via a different NIC or - // address - bindNICID tcpip.NICID - bindAddr tcpip.Address - // regNICID is the default NIC to be used when callers don't specify a - // NIC. - regNICID tcpip.NICID - route stack.Route `state:"manual"` - ttl uint8 -} - -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + route stack.Route `state:"manual"` + ttl uint8 + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: stateInitial, }, nil } @@ -105,7 +100,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -144,6 +139,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -206,6 +202,29 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -256,12 +275,12 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + if e.BindNICID != 0 { + if nicid != 0 && nicid != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicid = e.BindNICID } toCopy := *to @@ -272,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -295,12 +314,12 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.id.LocalPort, v, e.ttl) + err = send4(route, e.ID.LocalPort, v, e.ttl) case header.IPv6ProtocolNumber: - err = send6(route, e.id.LocalPort, v, e.ttl) + err = send6(route, e.ID.LocalPort, v, e.ttl) } if err != nil { @@ -430,14 +449,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { return 0, tcpip.ErrNoRoute } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -458,16 +477,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { localPort := uint16(0) switch e.state { case stateBound, stateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicid != 0 && nicid != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicid = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -478,7 +497,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -500,9 +519,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id + e.ID = id e.route = r.Clone() - e.regNICID = nicid + e.RegisterNICID = nicid e.state = stateConnected @@ -557,14 +576,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -611,8 +630,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = addr.NIC + e.ID = id + e.RegisterNICID = addr.NIC // Mark endpoint as bound. e.state = stateBound @@ -635,8 +654,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return err } - e.bindNICID = addr.NIC - e.bindAddr = addr.Addr + e.BindNICID = addr.NIC + e.BindAddr = addr.Addr return nil } @@ -647,9 +666,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -663,9 +682,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -691,17 +710,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { // Only accept echo replies. - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(vv.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: h := header.ICMPv6(vv.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } } @@ -709,9 +730,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } @@ -733,7 +762,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv pkt.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() - + e.stats.PacketsReceived.Increment() // Notify any waiters that there's data to be read now. if wasEmpty { e.waiterQueue.Notify(waiter.EventIn) @@ -749,3 +778,17 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index c587b96b6..9d263c0ec 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -76,19 +76,19 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { panic(err) } - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + e.ID.LocalAddress = e.route.LocalAddress + } else if len(e.ID.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 8ae132efa..4f5c286cf 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -62,11 +62,10 @@ type packet struct { // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool @@ -84,18 +83,10 @@ type endpoint struct { closed bool connected bool bound bool - // registeredNIC is the NIC to which th endpoint is explicitly - // registered. Is set when Connect or Bind are used to specify a NIC. - registeredNIC tcpip.NICID - // boundNIC and boundAddr are set on calls to Bind(). When callers - // attempt actions that would invalidate the binding data (e.g. sending - // data via a NIC other than boundNIC), the endpoint will return an - // error. - boundNIC tcpip.NICID - boundAddr tcpip.Address // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route stack.Route `state:"manual"` + stats tcpip.TransportEndpointStats `state:"nosave"` } // NewEndpoint returns a raw endpoint for the given protocols. @@ -104,15 +95,17 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } - ep := &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + e := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, @@ -123,81 +116,82 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - ep.rcvBufSizeMax = 0 - ep.waiterQueue = nil - return ep, nil + e.rcvBufSizeMax = 0 + e.waiterQueue = nil + return e, nil } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { return nil, err } - return ep, nil + return e, nil } // Close implements tcpip.Endpoint.Close. -func (ep *endpoint) Close() { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed || !ep.associated { + if e.closed || !e.associated { return } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) - ep.rcvMu.Lock() - defer ep.rcvMu.Unlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() // Clear the receive list. - ep.rcvClosed = true - ep.rcvBufSize = 0 - for !ep.rcvList.Empty() { - ep.rcvList.Remove(ep.rcvList.Front()) + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + e.rcvList.Remove(e.rcvList.Front()) } - if ep.connected { - ep.route.Release() - ep.connected = false + if e.connected { + e.route.Release() + e.connected = false } - ep.closed = true + e.closed = true - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { - return ep.stack.IPTables(), nil +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil } // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !ep.associated { +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !e.associated { return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue } - ep.rcvMu.Lock() + e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. - if ep.rcvList.Empty() { + if e.rcvList.Empty() { err := tcpip.ErrWouldBlock - if ep.rcvClosed { + if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return buffer.View{}, tcpip.ControlMessages{}, err } - packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + packet := e.rcvList.Front() + e.rcvList.Remove(packet) + e.rcvBufSize -= packet.data.Size() - ep.rcvMu.Unlock() + e.rcvMu.Unlock() if addr != nil { *addr = packet.senderAddr @@ -207,30 +201,54 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - ep.mu.RLock() + e.mu.RLock() - if ep.closed { - ep.mu.RUnlock() + if e.closed { + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { + e.mu.RUnlock() return 0, nil, err } // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !ep.associated { + if !e.associated { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -251,66 +269,66 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <- if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if !ep.connected { - ep.mu.RUnlock() + if !e.connected { + e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if ep.route.IsResolutionRequired() { - savedRoute := &ep.route + if e.route.IsResolutionRequired() { + savedRoute := &e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. - ep.mu.RUnlock() - ep.mu.Lock() + e.mu.RUnlock() + e.mu.Lock() // Make sure that the route didn't change during the // time we didn't hold the lock. - if !ep.connected || savedRoute != &ep.route { - ep.mu.Unlock() + if !e.connected || savedRoute != &e.route { + e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payloadBytes, savedRoute) - ep.mu.Unlock() + n, ch, err := e.finishWrite(payloadBytes, savedRoute) + e.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payloadBytes, &ep.route) - ep.mu.RUnlock() + n, ch, err := e.finishWrite(payloadBytes, &e.route) + e.mu.RUnlock() return n, ch, err } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC - if ep.bound && nic != 0 && nic != ep.boundNIC { - ep.mu.RUnlock() + if e.bound && nic != 0 && nic != e.BindNICID { + e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } // We don't support IPv6 yet, so this has to be an IPv4 address. if len(opts.To.Addr) != header.IPv4AddressSize { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } - // Find the route to the destination. If boundAddress is 0, + // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. - route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, err } - n, ch, err := ep.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, &route) route.Release() - ep.mu.RUnlock() + e.mu.RUnlock() return n, ch, err } // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -323,16 +341,16 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch ep.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - if !ep.associated { + if !e.associated { if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, 0, true /* useDefaultTTL */); err != nil { + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), e.TransProto, 0, true /* useDefaultTTL */); err != nil { return 0, nil, err } @@ -344,7 +362,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -354,11 +372,11 @@ func (*endpoint) Disconnect() *tcpip.Error { } // Connect implements tcpip.Endpoint.Connect. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed { + if e.closed { return tcpip.ErrInvalidEndpointState } @@ -368,15 +386,15 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } nic := addr.NIC - if ep.bound { - if ep.boundNIC == 0 { + if e.bound { + if e.BindNICID == 0 { // If we're bound, but not to a specific NIC, the NIC // in addr will be used. Nothing to do here. } else if addr.NIC == 0 { // If we're bound to a specific NIC, but addr doesn't // specify a NIC, use the bound NIC. - nic = ep.boundNIC - } else if addr.NIC != ep.boundNIC { + nic = e.BindNICID + } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. return tcpip.ErrInvalidEndpointState @@ -384,53 +402,53 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the destination. - route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) if err != nil { return err } defer route.Release() - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = nic + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = nic } // Save the route we've connected via. - ep.route = route.Clone() - ep.connected = true + e.route = route.Clone() + e.connected = true return nil } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if !ep.connected { + if !e.connected { return tcpip.ErrNotConnected } return nil } // Listen implements tcpip.Endpoint.Listen. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() // Callers must provide an IPv4 address or no network address (for // binding to a NIC, but not an address). @@ -439,56 +457,56 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = addr.NIC + e.BindNICID = addr.NIC } - ep.boundAddr = addr.Addr - ep.bound = true + e.BindAddr = addr.Addr + e.bound = true return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } // Readiness implements tcpip.Endpoint.Readiness. -func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // The endpoint is always writable. result := waiter.EventOut & mask // Determine whether the endpoint is readable. if (mask & waiter.EventIn) != 0 { - ep.rcvMu.Lock() - if !ep.rcvList.Empty() || ep.rcvClosed { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { result |= waiter.EventIn } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() } return result } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -498,28 +516,28 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 - ep.rcvMu.Lock() - if !ep.rcvList.Empty() { - p := ep.rcvList.Front() + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() v = p.data.Size() } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return v, nil case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSize - ep.mu.Unlock() + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() return v, nil case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() return v, nil } @@ -528,7 +546,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: return nil @@ -543,37 +561,45 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { - ep.rcvMu.Lock() +func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { - ep.stack.Stats().DroppedPackets.Increment() - ep.rcvMu.Unlock() + if e.rcvClosed { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } - if ep.bound { + if e.bound { // If bound to a NIC, only accept data for that NIC. - if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { - ep.rcvMu.Unlock() + if e.BindNICID != 0 && e.BindNICID != route.NICID() { + e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + e.rcvMu.Unlock() return } } // If connected, only accept packets from the remote address we // connected to. - if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.connected && e.route.RemoteAddress != route.RemoteAddress { + e.rcvMu.Unlock() return } - wasEmpty := ep.rcvBufSize == 0 + wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. packet := &packet{ @@ -586,20 +612,34 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b combinedVV := netHeader.ToVectorisedView() combinedVV.Append(vv) packet.data = combinedVV.Clone(packet.views[:]) - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = e.stack.NowNanoseconds() - ep.rcvList.PushBack(packet) - ep.rcvBufSize += packet.data.Size() - - ep.rcvMu.Unlock() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.rcvMu.Unlock() + e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { - ep.waiterQueue.Notify(waiter.EventIn) + e.waiterQueue.Notify(waiter.EventIn) } } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 168953dec..a6c7cc43a 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -73,7 +73,7 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if ep.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) if err != nil { panic(err) } @@ -81,12 +81,12 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b8b4bcee8..8f5572195 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -229,7 +229,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n := newEndpoint(l.stack, netProto, nil) n.v6only = l.v6only - n.id = s.id + n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } @@ -290,7 +290,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h.resetToSynRcvd(cookie, irs, opts) if err := h.execute(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -311,14 +310,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.id] = n + l.pendingEndpoints[n.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.id) + delete(l.pendingEndpoints, n.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -363,6 +362,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) @@ -414,6 +414,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } decSynRcvdCount() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } else { @@ -421,6 +422,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // is full then drop the syn. if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -439,7 +441,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, MSS: uint16(mss), } - sendSynTCP(&s.route, s.id, e.ttl, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, s.id, e.ttl, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } @@ -451,6 +453,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // complete the connection at the time of retransmit if // the backlog has space. e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -505,6 +508,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } @@ -536,7 +540,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 1d6e7f5f3..cb8cfd619 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -255,8 +255,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = s.route.DefaultTTL() } - sendSynTCP(&s.route, h.ep.id, ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - + h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -300,7 +299,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -387,6 +386,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { switch index { case wakerForResolution: if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + if err == tcpip.ErrNoLinkAddress { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + } else if err != nil { + h.ep.stats.SendErrors.NoRoute.Increment() + } // Either success (err == nil) or failure. return err } @@ -464,7 +468,7 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - sendSynTCP(&h.ep.route, h.ep.id, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -473,7 +477,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - sendSynTCP(&h.ep.route, h.ep.id, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -583,11 +587,22 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) { options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, ttl, flags, seq, ack, rcvWnd, options, nil) + // We ignore SYN send errors and let the callers re-attempt send. + if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, flags, seq, ack, rcvWnd, options, nil); err != nil { + e.stats.SendErrors.SynSendToNetworkFailed.Increment() + } putOptions(options) - return err +} + +func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + if err := sendTCP(r, id, data, ttl, flags, seq, ack, rcvWnd, opts, gso); err != nil { + e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() + return err + } + e.stats.SegmentsSent.Increment() + return nil } // sendTCP sends a TCP segment with the provided options via the provided @@ -628,12 +643,15 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } + if err := r.WritePacket(gso, hdr, data, ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */); err != nil { + r.Stats().TCP.SegmentSendErrors.Increment() + return err + } r.Stats().TCP.SegmentsSent.Increment() if (flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } - - return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) + return nil } // makeOptions makes an options slice. @@ -682,7 +700,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.ttl, flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, e.ID, data, e.ttl, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } @@ -732,7 +750,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.state = StateError - e.hardError = err + e.HardError = err if err != tcpip.ErrConnectionReset { e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) } @@ -902,7 +920,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.mu.Lock() e.state = StateError - e.hardError = err + e.HardError = err // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 83d92b3e1..090a8eb24 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -172,6 +172,101 @@ type rcvBufAutoTuneParams struct { disabled bool } +// ReceiveErrors collect segment receive errors within transport layer. +type ReceiveErrors struct { + tcpip.ReceiveErrors + + // SegmentQueueDropped is the number of segments dropped due to + // a full segment queue. + SegmentQueueDropped tcpip.StatCounter + + // ChecksumErrors is the number of segments dropped due to bad checksums. + ChecksumErrors tcpip.StatCounter + + // ListenOverflowSynDrop is the number of times the listen queue overflowed + // and a SYN was dropped. + ListenOverflowSynDrop tcpip.StatCounter + + // ListenOverflowAckDrop is the number of times the final ACK + // in the handshake was dropped due to overflow. + ListenOverflowAckDrop tcpip.StatCounter + + // ZeroRcvWindowState is the number of times we advertised + // a zero receive window when rcvList is full. + ZeroRcvWindowState tcpip.StatCounter +} + +// SendErrors collect segment send errors within the transport layer. +type SendErrors struct { + tcpip.SendErrors + + // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent + // to the network endpoint. + SegmentSendToNetworkFailed tcpip.StatCounter + + // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent + // to the network endpoint. + SynSendToNetworkFailed tcpip.StatCounter + + // Retransmits is the number of TCP segments retransmitted. + Retransmits tcpip.StatCounter + + // FastRetransmit is the number of segments retransmitted in fast + // recovery. + FastRetransmit tcpip.StatCounter + + // Timeouts is the number of times the RTO expired. + Timeouts tcpip.StatCounter +} + +// Stats holds statistics about the endpoint. +type Stats struct { + // SegmentsReceived is the number of TCP segments received that + // the transport layer successfully parsed. + SegmentsReceived tcpip.StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent tcpip.StatCounter + + // FailedConnectionAttempts is the number of times we saw Connect and + // Accept errors. + FailedConnectionAttempts tcpip.StatCounter + + // ReceiveErrors collects segment receive errors within the + // transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects segment read errors from an endpoint read call. + ReadErrors tcpip.ReadErrors + + // SendErrors collects segment send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects segment write errors from an endpoint write call. + WriteErrors tcpip.WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*Stats) IsEndpointStats() {} + +// EndpointInfo holds useful information about a transport endpoint which +// can be queried by monitoring tools. +// +// +stateify savable +type EndpointInfo struct { + stack.TransportEndpointInfo + + // HardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. HardError is protected by endpoint mu. + HardError *tcpip.Error `state:".(string)"` +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*EndpointInfo) IsEndpointInfo() {} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -180,6 +275,8 @@ type rcvBufAutoTuneParams struct { // // +stateify savable type endpoint struct { + EndpointInfo + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -188,8 +285,7 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. - stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber + stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` // lastError represents the last error that the endpoint reported; @@ -220,7 +316,6 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID state EndpointState `state:".(EndpointState)"` @@ -243,11 +338,6 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` - // hardError is meaningful only when state is stateError, it stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. hardError is protected by mu. - hardError *tcpip.Error `state:".(string)"` - // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -399,13 +489,15 @@ type endpoint struct { probe stack.TCPProbeFunc `state:"nosave"` // The following are only used to assist the restore run to re-connect. - bindAddress tcpip.Address connectingAddress tcpip.Address // amss is the advertised MSS to the peer by this endpoint. amss uint16 gso *stack.GSO + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats Stats `state:"nosave"` } // StopWork halts packet processing. Only to be used in tests. @@ -433,10 +525,15 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + EndpointInfo: EndpointInfo{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + }, waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, @@ -452,26 +549,26 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite } var ss SendBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } var rs ReceiveBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } var cs tcpip.CongestionControlOption - if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } var mrb tcpip.ModerateReceiveBufferOption - if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - if p := stack.GetTCPProbe(); p != nil { + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -570,11 +667,11 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -631,12 +728,12 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -737,11 +834,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.hardError + he := e.HardError e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } + e.stats.ReadErrors.InvalidEndpointState.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -750,6 +848,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RUnlock() + if err == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } return v, tcpip.ControlMessages{}, err } @@ -793,7 +894,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { if !e.state.connected() { switch e.state { case StateError: - return 0, e.hardError + return 0, e.HardError default: return 0, tcpip.ErrClosedForSend } @@ -824,6 +925,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -858,6 +960,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -869,7 +972,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Add data to the send queue. - s := newSegmentFromView(&e.route, e.id, v) + s := newSegmentFromView(&e.route, e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -901,8 +1004,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.state; !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardError + return 0, tcpip.ControlMessages{}, e.HardError } + e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -911,6 +1015,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.state.connected() { + e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -1102,7 +1207,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -1306,7 +1411,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -1396,7 +1501,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { // Fail if using a v4 mapped address on a v6only endpoint. if e.v6only { @@ -1412,7 +1517,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -1426,7 +1531,12 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return e.connect(addr, true, true) + err := e.connect(addr, true, true) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err } // connect connects the endpoint to its peer. In the normal non-S/R case, the @@ -1435,14 +1545,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() connectingAddr := addr.Addr @@ -1487,29 +1592,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnecting case StateError: - return e.hardError + return e.HardError default: return tcpip.ErrInvalidEndpointState } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.id + origID := e.ID netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.id.LocalAddress = r.LocalAddress - e.id.RemoteAddress = r.RemoteAddress - e.id.RemotePort = addr.Port + e.ID.LocalAddress = r.LocalAddress + e.ID.RemoteAddress = r.RemoteAddress + e.ID.RemotePort = addr.Port - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1518,35 +1623,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.id.LocalAddress == e.id.RemoteAddress + sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress // Calculate a port offset based on the destination IP/port and // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. h := jenkins.Sum32(e.stack.PortSeed()) - h.Write([]byte(e.id.LocalAddress)) - h.Write([]byte(e.id.RemoteAddress)) + h.Write([]byte(e.ID.LocalAddress)) + h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) - binary.LittleEndian.PutUint16(portBuf, e.id.RemotePort) + binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) h.Write(portBuf) portOffset := h.Sum32() if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { - if sameAddr && p == e.id.RemotePort { + if sameAddr && p == e.ID.RemotePort { return false, nil } // reusePort is false below because connect cannot reuse a port even if // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } - id := e.id + id := e.ID id.LocalPort = p switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: - e.id = id + e.ID = id return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1581,7 +1686,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Lock() for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.id + s.id = e.ID s.route = r.Clone() e.sndWaker.Assert() } @@ -1641,7 +1746,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.id, nil) + s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ @@ -1669,14 +1774,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { +func (e *endpoint) Listen(backlog int) *tcpip.Error { + err := e.listen(backlog) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err +} + +func (e *endpoint) listen(backlog int) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -1702,11 +1811,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { + e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1770,7 +1880,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { return tcpip.ErrAlreadyBound } - e.bindAddress = addr.Addr + e.BindAddr = addr.Addr netProto, err := e.checkV4Mapped(&addr) if err != nil { return err @@ -1794,7 +1904,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.isPortReserved = true e.effectiveNetProtos = netProtos - e.id.LocalPort = port + e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. defer func(bindToDevice tcpip.NICID) { @@ -1802,8 +1912,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil - e.id.LocalPort = 0 - e.id.LocalAddress = "" + e.ID.LocalPort = 0 + e.ID.LocalAddress = "" e.boundNICID = 0 } }(e.bindToDevice) @@ -1817,7 +1927,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } e.boundNICID = nic - e.id.LocalAddress = addr.Addr + e.ID.LocalAddress = addr.Addr } // Mark endpoint as bound. @@ -1832,8 +1942,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -1848,8 +1958,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, NIC: e.boundNICID, }, nil } @@ -1861,6 +1971,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() return } @@ -1868,11 +1979,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.csumValid { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() return } e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + e.stats.SegmentsReceived.Increment() if (s.flags & header.TCPFlagRst) != 0 { e.stack.Stats().TCP.ResetsReceived.Increment() } @@ -1883,6 +1996,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } else { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.SegmentQueueDropped.Increment() s.decRef() } } @@ -1932,6 +2046,7 @@ func (e *endpoint) readyToRead(s *segment) { // that a subsequent read of the segment will correctly trigger // a non-zero notification. if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() e.zeroWindow = true } e.rcvList.PushBack(s) @@ -2084,7 +2199,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Copy EndpointID. e.mu.Lock() - s.ID = stack.TCPEndpointID(e.id) + s.ID = stack.TCPEndpointID(e.ID) e.mu.Unlock() // Copy endpoint rcv state. @@ -2191,7 +2306,7 @@ func (e *endpoint) initGSO() { gso.Type = stack.GSOTCPv6 gso.L3HdrLen = header.IPv6MinimumSize default: - panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } gso.NeedsCsum = true gso.CsumOffset = header.TCPChecksumOffset @@ -2207,6 +2322,20 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.EndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + func mssForRoute(r *stack.Route) uint16 { return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 831389ec7..eae17237e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() { case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() @@ -190,10 +190,10 @@ func (e *endpoint) Resume(s *stack.Stack) { bind := func() { e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress + if len(e.BindAddr) == 0 { + e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { panic("endpoint binding failed: " + err.String()) } } @@ -202,19 +202,19 @@ func (e *endpoint) Resume(s *stack.Stack) { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress + e.connectingAddress = e.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectedLoading.Done() @@ -236,7 +236,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -288,21 +288,21 @@ func (e *endpoint) loadLastError(s string) { } // saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { +func (e *EndpointInfo) saveHardError() string { + if e.HardError == nil { return "" } - return e.hardError.String() + return e.HardError.String() } // loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { +func (e *EndpointInfo) loadHardError(s string) { if s == "" { return } - e.hardError = loadError(s) + e.HardError = loadError(s) } var messageToError map[string]*tcpip.Error diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 735edfe55..8332a0179 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -417,6 +417,7 @@ func (s *sender) resendSegment() { s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() + s.ep.stats.SendErrors.FastRetransmit.Increment() // Run SetPipe() as per RFC 6675 section 5 Step 4.4 s.SetPipe() @@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool { } s.ep.stack.Stats().TCP.Timeouts.Increment() + s.ep.stats.SendErrors.Timeouts.Increment() // Give up if we've waited more than a minute since the last resend. if s.rto >= 60*time.Second { @@ -1188,6 +1190,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { func (s *sender) sendSegment(seg *segment) *tcpip.Error { if !seg.xmitTime.IsZero() { s.ep.stack.Stats().TCP.Retransmits.Increment() + s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 9fa97528b..782d7b42c 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 4e7f1a740..afea124ec 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) { t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + } + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) // Acknowledge all pending data to recover point. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a86123829..8eaf9786d 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -99,7 +99,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) } } @@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + } } func TestTCPSegmentsSentIncrement(t *testing.T) { @@ -136,6 +142,9 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { + t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + } } func TestTCPResetsSentIncrement(t *testing.T) { @@ -857,6 +866,10 @@ func TestShutdownRead(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + } } func TestFullWindowReceive(t *testing.T) { @@ -913,6 +926,11 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("got data = %v, want = %v", v, data) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + } + // Check that we get an ACK for the newly non-zero window. checker.IPv4(t, c.GetPacket(), checker.TCP( @@ -2085,6 +2103,13 @@ loop: t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") + } } func TestSendOnResetConnection(t *testing.T) { @@ -2656,6 +2681,17 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + } + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { @@ -2680,6 +2716,9 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } } func TestReceivedIncorrectChecksumIncrement(t *testing.T) { @@ -2706,6 +2745,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { if got := stats.TCP.ChecksumErrors.Value(); got != want { t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) + } } func TestReceivedSegmentQueuing(t *testing.T) { @@ -4158,6 +4200,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + } we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -4196,6 +4241,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + // Expect InvalidEndpointState errors on a read at this point. + if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { + t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + } + if err := ep.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 8ae83437e..2ad7978c2 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -49,6 +49,22 @@ const ( StateClosed ) +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnected: + return "CONNECTING" + case StateClosed: + return "CLOSED" + default: + return "UNKNOWN" + } +} + // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -58,10 +74,11 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue // The following fields are used to manage the receive queue, and are @@ -76,10 +93,7 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int - id stack.TransportEndpointID state EndpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID route stack.Route `state:"manual"` dstPort uint16 v6only bool @@ -106,6 +120,9 @@ type endpoint struct { // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). effectiveNetProtos []tcpip.NetworkProtocolNumber + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` } // +stateify savable @@ -114,10 +131,13 @@ type multicastMembership struct { multicastAddr tcpip.Address } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership // requests. @@ -135,6 +155,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: StateInitial, } } @@ -146,12 +167,12 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } e.multicastMemberships = nil @@ -191,6 +212,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -253,7 +275,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // configured multicast interface if no interface is specified and the // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { - localAddr := e.id.LocalAddress + localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -279,6 +301,29 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -330,12 +375,12 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + if e.BindNICID != 0 { + if nicid != 0 && nicid != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicid = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -384,7 +429,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl, useDefaultTTL); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -405,7 +450,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -458,7 +503,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if e.bindNICID != 0 && e.bindNICID != nic { + if e.BindNICID != 0 && e.BindNICID != nic { return tcpip.ErrInvalidEndpointState } @@ -485,7 +530,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -502,7 +547,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -523,7 +568,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -545,7 +590,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrBadLocalAddress } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -624,7 +669,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -733,14 +778,18 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u udp.SetChecksum(^udp.CalculateChecksum(xsum)) } + if err := r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL); err != nil { + r.Stats().UDP.PacketSendErrors.Increment() + return err + } + // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - - return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL) + return nil } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if len(addr.Addr) == 0 { return netProto, nil } @@ -757,14 +806,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t } // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.id.LocalAddress) == 16 { + if !allowMismatch && len(e.ID.LocalAddress) == 16 { return 0, tcpip.ErrNetworkUnreachable } } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -781,27 +830,27 @@ func (e *endpoint) Disconnect() *tcpip.Error { } id := stack.TransportEndpointID{} // Exclude ephemerally bound endpoints. - if e.bindNICID != 0 || e.id.LocalAddress == "" { + if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.id.LocalPort, - LocalAddress: e.id.LocalAddress, + LocalPort: e.ID.LocalPort, + LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } e.state = StateBound } else { - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) - e.id = id + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.ID = id e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -828,16 +877,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { switch e.state { case StateInitial: case StateBound, StateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicid != 0 && nicid != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicid = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -849,7 +898,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: e.id.LocalAddress, + LocalAddress: e.ID.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, @@ -876,14 +925,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Remove the old registration. - if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + if e.ID.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) } - e.id = id + e.ID = id e.route = r.Clone() e.dstPort = addr.Port - e.regNICID = nicid + e.RegisterNICID = nicid e.effectiveNetProtos = netProtos e.state = StateConnected @@ -939,7 +988,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.id.LocalPort == 0 { + if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err @@ -995,8 +1044,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = nicid + e.ID = id + e.RegisterNICID = nicid e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -1021,7 +1070,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // Save the effective NICID generated by bindLocked. - e.bindNICID = e.regNICID + e.BindNICID = e.RegisterNICID return nil } @@ -1032,9 +1081,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -1048,9 +1097,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -1080,6 +1129,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if int(hdr.Length()) > vv.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } @@ -1087,11 +1137,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() + e.stats.PacketsReceived.Increment() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } @@ -1130,6 +1189,20 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index be46e6d4e..b227e353b 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -72,7 +72,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } } @@ -93,13 +93,13 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } @@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) { // Our saved state had a port, but we don't actually have a // reservation. We need to remove the port from our state, but still // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id := e.ID + e.ID.LocalPort = 0 + e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 2d0bc5221..d399ec722 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -79,10 +79,10 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, return nil, err } - ep.id = r.id + ep.ID = r.id ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort - ep.regNICID = r.route.NICID() + ep.RegisterNICID = r.route.NICID() ep.state = StateConnected diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index faa728b68..4ada73475 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -384,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h) + c.injectV4Packet(payload, &h, true /* valid */) } else { - c.injectV6Packet(payload, &h) + c.injectV6Packet(payload, &h, true /* valid */) } } // injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -409,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) + l := uint16(header.UDPMinimumSize + len(payload)) + if !valid { + // Change the UDP payload length to corrupt the header + // as requested by the caller. + l++ + } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), + Length: l, }) // Calculate the UDP pseudo-header checksum. @@ -426,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -// injectV6Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -536,7 +546,7 @@ func TestBindToDeviceOption(t *testing.T) { // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its // correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) { +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -547,6 +557,9 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) c.wq.EventRegister(&we, waiter.EventIn) defer c.wq.EventUnregister(&we) + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + var addr tcpip.FullAddress v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { @@ -563,6 +576,11 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) } } + if expectReadError && err != nil { + c.checkEndpointReadStats(1, epstats, err) + return + } + if err != nil { c.t.Fatal("Read failed:", err) } @@ -581,6 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it @@ -588,15 +607,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) // its correctness. func testRead(c *testContext, flow testFlow) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) } // testFailingRead sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then tries to read it from the UDP // endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow) { +func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */) + testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) } func TestBindEphemeralPort(t *testing.T) { @@ -771,8 +790,8 @@ func TestReadOnBoundToMulticast(t *testing.T) { // Check that we receive multicast packets but not unicast or broadcast // ones. testRead(c, flow) - testFailingRead(c, broadcast) - testFailingRead(c, unicastV4) + testFailingRead(c, broadcast, false /* expectReadError */) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } @@ -795,7 +814,7 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { // Check that we receive broadcast packets but not unicast ones. testRead(c, flow) - testFailingRead(c, unicastV4) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } @@ -826,7 +845,8 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // and verifies it fails with the provided error code. func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { c.t.Helper() - + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) @@ -834,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) + c.checkEndpointWriteStats(1, epstats, gotErr) if gotErr != wantErr { c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } @@ -859,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() writeOpts := tcpip.WriteOptions{} if setDest { @@ -876,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - + c.checkEndpointWriteStats(1, epstats, err) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -945,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Write to V4 mapped address. testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + const want = 1 + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { + c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -1453,3 +1480,117 @@ func TestV6UnknownDestination(t *testing.T) { }) } } + +// TestIncrementMalformedPacketsReceived verifies if the malformed received +// global and endpoint stats get incremented. +func TestIncrementMalformedPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + payload := newPayload() + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + c.injectV6Packet(payload, &h, false /* !valid */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } +} + +// TestShutdownRead verifies endpoint read shutdown and error +// stats increment on packet receive. +func TestShutdownRead(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingRead(c, unicastV6, true /* expectReadError */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { + t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) + } +} + +// TestShutdownWrite verifies endpoint write shutdown and error +// stats increment on packet write. +func TestShutdownWrite(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) +} + +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil: + want.PacketsSent.IncrementBy(incr) + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + want.WriteErrors.InvalidArgs.IncrementBy(incr) + case tcpip.ErrClosedForSend: + want.WriteErrors.WriteClosed.IncrementBy(incr) + case tcpip.ErrInvalidEndpointState: + want.WriteErrors.InvalidEndpointState.IncrementBy(incr) + case tcpip.ErrNoLinkAddress: + want.SendErrors.NoLinkAddr.IncrementBy(incr) + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + want.SendErrors.NoRoute.IncrementBy(incr) + default: + want.SendErrors.SendToNetworkFailed.IncrementBy(incr) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} + +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil, tcpip.ErrWouldBlock: + case tcpip.ErrClosedForReceive: + want.ReadErrors.ReadClosed.IncrementBy(incr) + default: + c.t.Errorf("Endpoint error missing stats update err %v", err) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} -- cgit v1.2.3