diff options
-rw-r--r-- | pkg/tcpip/tcpip.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 84 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 2 | ||||
-rw-r--r-- | tools/go_generics/defs.bzl | 8 |
7 files changed, 86 insertions, 82 deletions
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index cf25a086d..fdeba6bc4 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -85,17 +85,6 @@ var ( errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask") ) -// ErrSaveRejection indicates a failed save due to unsupported networking state. -// This type of errors is only used for save logic. -type ErrSaveRejection struct { - Err error -} - -// Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { - return "save rejected due to unsupported networking state: " + e.Err.Error() -} - // A Clock provides the current time. // // Times returned by a Clock should always be used for application-visible diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index ac213e310..a71cb444f 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -349,17 +349,13 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // to the endpoint. e.mu.Lock() e.state = stateClosed + e.mu.Unlock() // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) // Do cleanup if needed. - e.completeWorkerLocked() - - if e.drainDone != nil { - close(e.drainDone) - } - e.mu.Unlock() + e.completeWorker() }() e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 66904856c..5115dabe6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -690,7 +690,7 @@ func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, return err } -func (e *endpoint) handleWrite() *tcpip.Error { +func (e *endpoint) handleWrite() bool { // Move packets from send queue to send list. The queue is accessible // from other goroutines and protected by the send mutex, while the send // list is only accessible from the handler goroutine, so it needs no @@ -714,42 +714,47 @@ func (e *endpoint) handleWrite() *tcpip.Error { // Push out any new packets. e.snd.sendData() - return nil + return true } -func (e *endpoint) handleClose() *tcpip.Error { +func (e *endpoint) handleClose() bool { // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - return nil + return true } -// resetConnectionLocked sends a RST segment and puts the endpoint in an error -// state with the given error code. This method must only be called from the -// protocol goroutine. -func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { +// resetConnection sends a RST segment and puts the endpoint in an error state +// with the given error code. +// This method must only be called from the protocol goroutine. +func (e *endpoint) resetConnection(err *tcpip.Error) { e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + e.mu.Lock() e.state = stateError e.hardError = err + e.mu.Unlock() } -// completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). -func (e *endpoint) completeWorkerLocked() { +// completeWorker is called by the worker goroutine when it's about to exit. It +// marks the worker as completed and performs cleanup work if requested by +// Close(). +func (e *endpoint) completeWorker() { + e.mu.Lock() + defer e.mu.Unlock() + e.workerRunning = false if e.workerCleanup { - e.cleanupLocked() + e.cleanup() } } // handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// true if the protocol loop should continue, false otherwise. +func (e *endpoint) handleSegments() bool { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { s := e.segmentQueue.dequeue() @@ -770,7 +775,11 @@ func (e *endpoint) handleSegments() *tcpip.Error { // validated by checking their SEQ-fields." So // we only process it if it's acceptable. s.decRef() - return tcpip.ErrConnectionReset + e.mu.Lock() + e.state = stateError + e.hardError = tcpip.ErrConnectionReset + e.mu.Unlock() + return false } } else if s.flagIsSet(flagAck) { // Patch the window size in the segment according to the @@ -807,7 +816,7 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - return nil + return true } // protocolMainLoop is the main loop of the TCP protocol. It runs in its own @@ -818,10 +827,9 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { var closeWaker sleep.Waker defer func() { - // e.mu is expected to be held upon entering this section. // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.completeWorkerLocked() + e.completeWorker() if e.snd != nil { e.snd.resendTimer.cleanup() @@ -830,12 +838,6 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { if closeTimer != nil { closeTimer.Stop() } - - if e.drainDone != nil { - close(e.drainDone) - } - - e.mu.Unlock() }() if !passive { @@ -854,7 +856,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.mu.Lock() e.state = stateError e.hardError = err - // Lock released in deferred statement. + drained := e.drainDone != nil + e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } return err } @@ -885,7 +892,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // wakes up. funcs := []struct { w *sleep.Waker - f func() *tcpip.Error + f func() bool }{ { w: &e.sndWaker, @@ -901,22 +908,24 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { }, { w: &closeWaker, - f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + f: func() bool { + e.resetConnection(tcpip.ErrConnectionAborted) + return false }, }, { w: &e.snd.resendWaker, - f: func() *tcpip.Error { + f: func() bool { if !e.snd.retransmitTimerExpired() { - return tcpip.ErrTimeout + e.resetConnection(tcpip.ErrTimeout) + return false } - return nil + return true }, }, { w: &e.notificationWaker, - f: func() *tcpip.Error { + f: func() bool { n := e.fetchNotifications() if n¬ifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -943,7 +952,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { closeWaker.Assert() }) } - return nil + return true }, }, } @@ -960,10 +969,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() - if err := funcs[v].f(); err != nil { - e.mu.Lock() - e.resetConnectionLocked(err) - // Lock released in deferred statement. + if !funcs[v].f() { return nil } } @@ -971,7 +977,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() e.state = stateClosed - // Lock released in deferred statement. + e.mu.Unlock() return nil } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3f87c4cac..f26b28632 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -94,7 +94,7 @@ type endpoint struct { state endpointState isPortReserved bool `state:"manual"` isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` v6only bool isConnectNotified bool @@ -118,7 +118,7 @@ type endpoint struct { // workerCleanup specifies if the worker goroutine must perform cleanup // before exitting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. - workerCleanup bool `state:"zerovalue"` + workerCleanup bool // sendTSOk is used to indicate when the TS Option has been negotiated. // When sendTSOk is true every non-RST segment should carry a TS as per @@ -326,7 +326,13 @@ func (e *endpoint) Close() { // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + // While we hold the lock, determine if the cleanup should happen + // inline or if we should tell the worker (if any) to do the cleanup. e.mu.Lock() + worker := e.workerRunning + if worker { + e.workerCleanup = true + } // We always release ports inline so that they are immediately available // for reuse after Close() is called. If also registered, it means this @@ -342,32 +348,29 @@ func (e *endpoint) Close() { } } - // Either perform the local cleanup or kick the worker to make sure it - // knows it needs to cleanup. - if !e.workerRunning { - e.cleanupLocked() + e.mu.Unlock() + + // Now that we don't hold the lock anymore, either perform the local + // cleanup or kick the worker to make sure it knows it needs to cleanup. + if !worker { + e.cleanup() } else { - e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } - - e.mu.Unlock() } -// cleanupLocked frees all resources associated with the endpoint. It is called -// after Close() is called and the worker goroutine (if any) is done with its -// work. -func (e *endpoint) cleanupLocked() { +// cleanup frees all resources associated with the endpoint. It is called after +// Close() is called and the worker goroutine (if any) is done with its work. +func (e *endpoint) cleanup() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { close(e.acceptedChan) for n := range e.acceptedChan { - n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.resetConnection(tcpip.ErrConnectionAborted) n.Close() } } - e.workerCleanup = false if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b1e249bff..deef670d5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -12,6 +12,17 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" ) +// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint +// state. +type ErrSaveRejection struct { + Err error +} + +// Error returns a sensible description of the save rejection error. +func (e ErrSaveRejection) Error() string { + return "save rejected due to unsupported endpoint state: " + e.Err.Error() +} + func (e *endpoint) drainSegmentLocked() { // Drain only up to once. if e.drainDone != nil { @@ -37,7 +48,8 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial, stateBound: + case stateInitial: + case stateBound: case stateListen: if !e.segmentQueue.empty() { e.drainSegmentLocked() @@ -50,11 +62,9 @@ func (e *endpoint) beforeSave() { fallthrough case stateConnected: // FIXME - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) - case stateClosed, stateError: - if e.workerRunning { - panic(fmt.Sprintf("endpoint still has worker running in closed or error state")) - } + panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + case stateClosed: + case stateError: default: panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 07e4bfd73..c742fc394 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -30,7 +30,7 @@ type segment struct { segmentEntry refCnt int32 id stack.TransportEndpointID - route stack.Route `state:"manual"` + route stack.Route data buffer.VectorisedView // views is used as buffer for data when its length is large // enough to store a VectorisedView. diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index 214e0dfca..631bd11d3 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -64,22 +64,22 @@ def _go_template_instance_impl(ctx): if t not in ctx.attr.types: fail("Missing value for type %s in %s" % (t, ctx.attr.template.label)) - # Check that all defined types are expected by the template. + # Check that all defined types are expected by the template. for t in ctx.attr.types: if (t not in template.types) and (t not in template.opt_types): fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label)) - # Check that all required consts are defined. + # Check that all required consts are defined. for t in template.consts: if t not in ctx.attr.consts: fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label)) - # Check that all defined consts are expected by the template. + # Check that all defined consts are expected by the template. for t in ctx.attr.consts: if (t not in template.consts) and (t not in template.opt_consts): fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label)) - # Build the argument list. + # Build the argument list. args = ["-i=%s" % template.file.path, "-o=%s" % output.path] args += ["-p=%s" % ctx.attr.package] |