summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/tcpip.go11
-rw-r--r--pkg/tcpip/transport/tcp/accept.go8
-rw-r--r--pkg/tcpip/transport/tcp/connect.go84
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go22
-rw-r--r--pkg/tcpip/transport/tcp/segment.go2
-rw-r--r--tools/go_generics/defs.bzl8
7 files changed, 86 insertions, 82 deletions
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index cf25a086d..fdeba6bc4 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -85,17 +85,6 @@ var (
errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
)
-// ErrSaveRejection indicates a failed save due to unsupported networking state.
-// This type of errors is only used for save logic.
-type ErrSaveRejection struct {
- Err error
-}
-
-// Error returns a sensible description of the save rejection error.
-func (e ErrSaveRejection) Error() string {
- return "save rejected due to unsupported networking state: " + e.Err.Error()
-}
-
// A Clock provides the current time.
//
// Times returned by a Clock should always be used for application-visible
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index ac213e310..a71cb444f 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -349,17 +349,13 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// to the endpoint.
e.mu.Lock()
e.state = stateClosed
+ e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
// Do cleanup if needed.
- e.completeWorkerLocked()
-
- if e.drainDone != nil {
- close(e.drainDone)
- }
- e.mu.Unlock()
+ e.completeWorker()
}()
e.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 66904856c..5115dabe6 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -690,7 +690,7 @@ func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value,
return err
}
-func (e *endpoint) handleWrite() *tcpip.Error {
+func (e *endpoint) handleWrite() bool {
// Move packets from send queue to send list. The queue is accessible
// from other goroutines and protected by the send mutex, while the send
// list is only accessible from the handler goroutine, so it needs no
@@ -714,42 +714,47 @@ func (e *endpoint) handleWrite() *tcpip.Error {
// Push out any new packets.
e.snd.sendData()
- return nil
+ return true
}
-func (e *endpoint) handleClose() *tcpip.Error {
+func (e *endpoint) handleClose() bool {
// Drain the send queue.
e.handleWrite()
// Mark send side as closed.
e.snd.closed = true
- return nil
+ return true
}
-// resetConnectionLocked sends a RST segment and puts the endpoint in an error
-// state with the given error code. This method must only be called from the
-// protocol goroutine.
-func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+// resetConnection sends a RST segment and puts the endpoint in an error state
+// with the given error code.
+// This method must only be called from the protocol goroutine.
+func (e *endpoint) resetConnection(err *tcpip.Error) {
e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ e.mu.Lock()
e.state = stateError
e.hardError = err
+ e.mu.Unlock()
}
-// completeWorkerLocked is called by the worker goroutine when it's about to
-// exit. It marks the worker as completed and performs cleanup work if requested
-// by Close().
-func (e *endpoint) completeWorkerLocked() {
+// completeWorker is called by the worker goroutine when it's about to exit. It
+// marks the worker as completed and performs cleanup work if requested by
+// Close().
+func (e *endpoint) completeWorker() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
e.workerRunning = false
if e.workerCleanup {
- e.cleanupLocked()
+ e.cleanup()
}
}
// handleSegments pulls segments from the queue and processes them. It returns
-// no error if the protocol loop should continue, an error otherwise.
-func (e *endpoint) handleSegments() *tcpip.Error {
+// true if the protocol loop should continue, false otherwise.
+func (e *endpoint) handleSegments() bool {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
s := e.segmentQueue.dequeue()
@@ -770,7 +775,11 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
s.decRef()
- return tcpip.ErrConnectionReset
+ e.mu.Lock()
+ e.state = stateError
+ e.hardError = tcpip.ErrConnectionReset
+ e.mu.Unlock()
+ return false
}
} else if s.flagIsSet(flagAck) {
// Patch the window size in the segment according to the
@@ -807,7 +816,7 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
- return nil
+ return true
}
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
@@ -818,10 +827,9 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
var closeWaker sleep.Waker
defer func() {
- // e.mu is expected to be held upon entering this section.
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.completeWorkerLocked()
+ e.completeWorker()
if e.snd != nil {
e.snd.resendTimer.cleanup()
@@ -830,12 +838,6 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
if closeTimer != nil {
closeTimer.Stop()
}
-
- if e.drainDone != nil {
- close(e.drainDone)
- }
-
- e.mu.Unlock()
}()
if !passive {
@@ -854,7 +856,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.mu.Lock()
e.state = stateError
e.hardError = err
- // Lock released in deferred statement.
+ drained := e.drainDone != nil
+ e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
return err
}
@@ -885,7 +892,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
// wakes up.
funcs := []struct {
w *sleep.Waker
- f func() *tcpip.Error
+ f func() bool
}{
{
w: &e.sndWaker,
@@ -901,22 +908,24 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
},
{
w: &closeWaker,
- f: func() *tcpip.Error {
- return tcpip.ErrConnectionAborted
+ f: func() bool {
+ e.resetConnection(tcpip.ErrConnectionAborted)
+ return false
},
},
{
w: &e.snd.resendWaker,
- f: func() *tcpip.Error {
+ f: func() bool {
if !e.snd.retransmitTimerExpired() {
- return tcpip.ErrTimeout
+ e.resetConnection(tcpip.ErrTimeout)
+ return false
}
- return nil
+ return true
},
},
{
w: &e.notificationWaker,
- f: func() *tcpip.Error {
+ f: func() bool {
n := e.fetchNotifications()
if n&notifyNonZeroReceiveWindow != 0 {
e.rcv.nonZeroWindow()
@@ -943,7 +952,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
closeWaker.Assert()
})
}
- return nil
+ return true
},
},
}
@@ -960,10 +969,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
- if err := funcs[v].f(); err != nil {
- e.mu.Lock()
- e.resetConnectionLocked(err)
- // Lock released in deferred statement.
+ if !funcs[v].f() {
return nil
}
}
@@ -971,7 +977,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
// Mark endpoint as closed.
e.mu.Lock()
e.state = stateClosed
- // Lock released in deferred statement.
+ e.mu.Unlock()
return nil
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 3f87c4cac..f26b28632 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -94,7 +94,7 @@ type endpoint struct {
state endpointState
isPortReserved bool `state:"manual"`
isRegistered bool
- boundNICID tcpip.NICID `state:"manual"`
+ boundNICID tcpip.NICID
route stack.Route `state:"manual"`
v6only bool
isConnectNotified bool
@@ -118,7 +118,7 @@ type endpoint struct {
// workerCleanup specifies if the worker goroutine must perform cleanup
// before exitting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
- workerCleanup bool `state:"zerovalue"`
+ workerCleanup bool
// sendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per
@@ -326,7 +326,13 @@ func (e *endpoint) Close() {
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ // While we hold the lock, determine if the cleanup should happen
+ // inline or if we should tell the worker (if any) to do the cleanup.
e.mu.Lock()
+ worker := e.workerRunning
+ if worker {
+ e.workerCleanup = true
+ }
// We always release ports inline so that they are immediately available
// for reuse after Close() is called. If also registered, it means this
@@ -342,32 +348,29 @@ func (e *endpoint) Close() {
}
}
- // Either perform the local cleanup or kick the worker to make sure it
- // knows it needs to cleanup.
- if !e.workerRunning {
- e.cleanupLocked()
+ e.mu.Unlock()
+
+ // Now that we don't hold the lock anymore, either perform the local
+ // cleanup or kick the worker to make sure it knows it needs to cleanup.
+ if !worker {
+ e.cleanup()
} else {
- e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
-
- e.mu.Unlock()
}
-// cleanupLocked frees all resources associated with the endpoint. It is called
-// after Close() is called and the worker goroutine (if any) is done with its
-// work.
-func (e *endpoint) cleanupLocked() {
+// cleanup frees all resources associated with the endpoint. It is called after
+// Close() is called and the worker goroutine (if any) is done with its work.
+func (e *endpoint) cleanup() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
close(e.acceptedChan)
for n := range e.acceptedChan {
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.resetConnection(tcpip.ErrConnectionAborted)
n.Close()
}
}
- e.workerCleanup = false
if e.isRegistered {
e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b1e249bff..deef670d5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -12,6 +12,17 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
)
+// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint
+// state.
+type ErrSaveRejection struct {
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported endpoint state: " + e.Err.Error()
+}
+
func (e *endpoint) drainSegmentLocked() {
// Drain only up to once.
if e.drainDone != nil {
@@ -37,7 +48,8 @@ func (e *endpoint) beforeSave() {
defer e.mu.Unlock()
switch e.state {
- case stateInitial, stateBound:
+ case stateInitial:
+ case stateBound:
case stateListen:
if !e.segmentQueue.empty() {
e.drainSegmentLocked()
@@ -50,11 +62,9 @@ func (e *endpoint) beforeSave() {
fallthrough
case stateConnected:
// FIXME
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
- case stateClosed, stateError:
- if e.workerRunning {
- panic(fmt.Sprintf("endpoint still has worker running in closed or error state"))
- }
+ panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ case stateClosed:
+ case stateError:
default:
panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 07e4bfd73..c742fc394 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -30,7 +30,7 @@ type segment struct {
segmentEntry
refCnt int32
id stack.TransportEndpointID
- route stack.Route `state:"manual"`
+ route stack.Route
data buffer.VectorisedView
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index 214e0dfca..631bd11d3 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -64,22 +64,22 @@ def _go_template_instance_impl(ctx):
if t not in ctx.attr.types:
fail("Missing value for type %s in %s" % (t, ctx.attr.template.label))
- # Check that all defined types are expected by the template.
+ # Check that all defined types are expected by the template.
for t in ctx.attr.types:
if (t not in template.types) and (t not in template.opt_types):
fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label))
- # Check that all required consts are defined.
+ # Check that all required consts are defined.
for t in template.consts:
if t not in ctx.attr.consts:
fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label))
- # Check that all defined consts are expected by the template.
+ # Check that all defined consts are expected by the template.
for t in ctx.attr.consts:
if (t not in template.consts) and (t not in template.opt_consts):
fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label))
- # Build the argument list.
+ # Build the argument list.
args = ["-i=%s" % template.file.path, "-o=%s" % output.path]
args += ["-p=%s" % ctx.attr.package]