diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 308 |
1 files changed, 241 insertions, 67 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a1b784b49..fe629aa40 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -121,6 +122,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -287,6 +293,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -319,6 +326,11 @@ type endpoint struct { state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` @@ -330,6 +342,11 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -411,7 +428,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -458,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -502,6 +525,36 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS } // StopWork halts packet processing. Only to be used in tests. @@ -550,6 +603,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption @@ -572,6 +626,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptInt(tcpip.DelayOption, 1) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -659,6 +723,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -671,14 +742,18 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -704,9 +779,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -732,16 +805,19 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -752,7 +828,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -1133,16 +1211,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.sndBufMu.Unlock() return nil - default: - return nil - } -} - -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 - switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { atomic.StoreUint32(&e.delay, 0) @@ -1154,6 +1222,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { case tcpip.CorkOption: if v == 0 { atomic.StoreUint32(&e.cork, 0) @@ -1184,9 +1262,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -1206,7 +1284,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil @@ -1262,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1319,6 +1403,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1345,6 +1451,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: e.sndBufMu.Lock() v := e.sndBufSize @@ -1357,8 +1464,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -1379,13 +1494,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1496,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 @@ -1530,6 +1644,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1602,7 +1722,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1611,11 +1731,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1634,7 +1754,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -1649,7 +1769,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) if err != nil { return err } @@ -1664,7 +1784,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1678,15 +1798,18 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // reusePort is false below because connect cannot reuse a port even if // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: + // Port picking successful. Save the details of + // the selected port. e.ID = id + e.boundBindToDevice = e.bindToDevice return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1702,14 +1825,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1729,6 +1852,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -1749,9 +1873,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1766,6 +1889,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1784,14 +1908,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1799,11 +1920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } @@ -1844,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } + if e.state == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() @@ -1851,7 +1990,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { return err } @@ -1895,12 +2034,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -1908,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.bindLocked(addr) +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. @@ -1932,26 +2070,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = flags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.bindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -2001,8 +2146,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -2025,6 +2170,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() @@ -2037,7 +2186,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2327,11 +2476,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { return s } -func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityGSO == 0 { - return - } - +func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} switch e.route.NetProto { case header.IPv4ProtocolNumber: @@ -2349,6 +2494,18 @@ func (e *endpoint) initGSO() { e.gso = gso } +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { @@ -2371,6 +2528,23 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } |