summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go90
-rw-r--r--pkg/tcpip/transport/tcp/connect.go78
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go495
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go5
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go38
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go6
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go8
-rw-r--r--pkg/tcpip/transport/tcp/snd.go3
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go41
10 files changed, 402 insertions, 370 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 85049e54e..4d7602d54 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -221,7 +221,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments.
+// the connection parameters given by the arguments. The endpoint is returned
+// with n.mu held.
func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
@@ -243,21 +244,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
- // Now inherit any socket options that should be inherited from the
- // listening endpoint.
- // In case of Forwarder listenEP will be nil and hence this check.
- if l.listenEP != nil {
- l.listenEP.propagateInheritableOptions(n)
- }
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
// Create sender and receiver.
//
// The receiver at least temporarily has a zero receive window scale,
@@ -269,11 +255,27 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ n.mu.Lock()
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+ return nil, err
+ }
+
+ n.isRegistered = true
+
return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
// and then performs the TCP 3-way handshake.
+//
+// The new endpoint is returned with e.mu held.
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
@@ -289,9 +291,25 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
+ // Ensure we release any registrations done by the newly
+ // created endpoint.
+ ep.mu.Unlock()
+ ep.Close()
+
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ l.listenEP.propagateInheritableOptionsLocked(ep)
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
@@ -299,6 +317,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
+ ep.mu.Unlock()
ep.Close()
// Wake up any waiters. This is strictly not required normally
// as a socket that was never accepted can't really have any
@@ -312,9 +331,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
}
return nil, err
}
- ep.mu.Lock()
ep.isConnectNotified = true
- ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
@@ -366,12 +383,12 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
}
-// propagateInheritableOptions propagates any options set on the listening
+// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
-func (e *endpoint) propagateInheritableOptions(n *endpoint) {
- e.mu.Lock()
+//
+// Precondition: e.mu and n.mu must be held.
+func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
- e.mu.Unlock()
}
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -382,7 +399,11 @@ func (e *endpoint) propagateInheritableOptions(n *endpoint) {
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer decSynRcvdCount()
- defer e.decSynRcvdCount()
+ defer func() {
+ e.mu.Lock()
+ e.decSynRcvdCount()
+ e.mu.Unlock()
+ }()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
@@ -399,29 +420,21 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
func (e *endpoint) incSynRcvdCount() bool {
- e.mu.Lock()
- if e.synRcvdCount >= cap(e.acceptedChan) {
- e.mu.Unlock()
+ if e.synRcvdCount >= (cap(e.acceptedChan)) {
return false
}
e.synRcvdCount++
- e.mu.Unlock()
return true
}
func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
e.synRcvdCount--
- e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
- e.mu.Lock()
if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
- e.mu.Unlock()
return true
}
- e.mu.Unlock()
return false
}
@@ -559,6 +572,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
@@ -593,14 +610,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
- e.mu.Unlock()
ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
- e.mu.Lock()
e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
@@ -622,7 +637,10 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
- switch index, _ := s.Fetch(true); index {
+ e.mu.Unlock()
+ index, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch index {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
@@ -635,7 +653,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.decRef()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
case wakerForNewSegment:
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index be86af502..edb37a549 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -61,6 +61,9 @@ const (
)
// handshake holds the state used during a TCP 3-way handshake.
+//
+// NOTE: handshake.ep.mu is held during handshake processing. It is released if
+// we are going to block and reacquired when we start processing an event.
type handshake struct {
ep *endpoint
state handshakeState
@@ -209,9 +212,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.deferAccept = deferAccept
- h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
- h.ep.mu.Unlock()
}
// checkAck checks if the ACK number, if present, of a segment received during
@@ -241,9 +242,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
// was acceptable then signal the user "error: connection reset", drop
// the segment, enter CLOSED state, delete TCB, and return."
- h.ep.mu.Lock()
h.ep.workerCleanup = true
- h.ep.mu.Unlock()
// Although the RFC above calls out ECONNRESET, Linux actually returns
// ECONNREFUSED here so we do as well.
return tcpip.ErrConnectionRefused
@@ -281,9 +280,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
- h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
- h.ep.mu.Unlock()
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -293,11 +290,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// but resend our own SYN and wait for it to be acknowledged in the
// SYN-RCVD state.
h.state = handshakeSynRcvd
- h.ep.mu.Lock()
ttl := h.ep.ttl
amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
- h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
@@ -357,10 +352,6 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- h.ep.mu.RLock()
- amss := h.ep.amss
- h.ep.mu.RUnlock()
-
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -368,7 +359,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
- MSS: amss,
+ MSS: h.ep.amss,
}
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -399,15 +390,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
}
h.state = handshakeCompleted
- h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+
// If the segment has data then requeue it for the receiver
// to process it again once main loop is started.
if s.data.Size() > 0 {
s.incRef()
h.ep.enqueueSegment(s)
}
- h.ep.mu.Unlock()
return nil
}
@@ -493,7 +483,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
if n&notifyDrain != 0 {
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
}
}
@@ -535,7 +527,6 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
- h.ep.mu.Lock()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
@@ -546,7 +537,6 @@ func (h *handshake) execute() *tcpip.Error {
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
- h.ep.mu.Unlock()
// Execute is also called in a listen context so we want to make sure we
// only send the TS/SACK option when we received the TS/SACK in the
@@ -563,7 +553,11 @@ func (h *handshake) execute() *tcpip.Error {
h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted {
- switch index, _ := s.Fetch(true); index {
+ h.ep.mu.Unlock()
+ index, _ := s.Fetch(true)
+ h.ep.mu.Lock()
+ switch index {
+
case wakerForResend:
timeOut *= 2
if timeOut > MaxRTO {
@@ -600,7 +594,9 @@ func (h *handshake) execute() *tcpip.Error {
}
}
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
}
case wakerForNewSegment:
@@ -1016,7 +1012,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- e.mu.Lock()
switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
@@ -1040,11 +1035,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
case StateCloseWait:
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
- e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
- e.mu.Unlock()
// RFC 793, page 37 states that "in all states
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
@@ -1157,9 +1150,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
- e.mu.RLock()
state := e.state
- e.mu.RUnlock()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
@@ -1182,9 +1173,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
- e.mu.RLock()
userTimeout := e.userTimeout
- e.mu.RUnlock()
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
@@ -1248,6 +1237,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// goroutine and is responsible for sending segments and handling received
// segments.
func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+ e.mu.Lock()
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1269,7 +1259,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.mu.Unlock()
- e.workMu.Unlock()
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1280,16 +1269,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// completion.
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
- e.mu.Lock()
h.ep.setEndpointState(StateSynSent)
- e.mu.Unlock()
if err := h.execute(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
- e.mu.Lock()
e.setEndpointState(StateError)
e.HardError = err
@@ -1302,9 +1288,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- e.mu.Lock()
drained := e.drainDone != nil
- e.mu.Unlock()
if drained {
close(e.drainDone)
<-e.undrain
@@ -1330,10 +1314,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// This means the socket is being closed due
// to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
- e.mu.Lock()
e.transitionToStateCloseLocked()
e.workerCleanup = true
- e.mu.Unlock()
return nil
},
},
@@ -1388,7 +1370,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyClose != 0 && closeTimer == nil {
- e.mu.Lock()
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
@@ -1397,7 +1378,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
})
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
- e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1417,7 +1397,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
}
@@ -1460,7 +1442,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.rcvListMu.Unlock()
- e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
@@ -1468,7 +1449,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
- e.mu.Lock()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1480,16 +1460,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
loop:
for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
- e.workMu.Unlock()
v, _ := s.Fetch(true)
- e.workMu.Lock()
+ e.mu.Lock()
- // We need to double check here because the notification maybe
+ // We need to double check here because the notification may be
// stale by the time we got around to processing it.
- //
- // NOTE: since we now hold the workMu the processors cannot
- // change the state of the endpoint so it's safe to proceed
- // after this check.
switch e.EndpointState() {
case StateError:
// If the endpoint has already transitioned to an ERROR
@@ -1502,21 +1477,17 @@ loop:
case StateTimeWait:
fallthrough
case StateClose:
- e.mu.Lock()
break loop
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
return nil
}
- e.mu.Lock()
}
}
- state := e.EndpointState()
- e.mu.Unlock()
var reuseTW func()
- if state == StateTimeWait {
+ if e.EndpointState() == StateTimeWait {
// Disable close timer as we now entering real TIME_WAIT.
if closeTimer != nil {
closeTimer.Stop()
@@ -1526,14 +1497,11 @@ loop:
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.mu.Lock()
e.workerCleanup = true
- e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
- e.mu.Lock()
if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1649,9 +1617,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
defer timeWaitTimer.Stop()
for {
- e.workMu.Unlock()
+ e.mu.Unlock()
v, _ := s.Fetch(true)
- e.workMu.Lock()
+ e.mu.Lock()
switch v {
case newSegment:
extendTimeWait, reuseTW := e.handleTimeWaitSegments()
@@ -1674,7 +1642,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
e.handleTimeWaitSegments()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
return nil
}
case timeWaitDone:
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index d792b07d6..90ac956a9 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -128,7 +128,7 @@ func (p *processor) handleSegments() {
continue
}
- if !ep.workMu.TryLock() {
+ if !ep.mu.TryLock() {
ep.newSegmentWaker.Assert()
continue
}
@@ -138,12 +138,10 @@ func (p *processor) handleSegments() {
if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
// Send any active resets if required.
if err != nil {
- ep.mu.Lock()
ep.resetConnectionLocked(err)
- ep.mu.Unlock()
}
ep.notifyProtocolGoroutine(notifyTickleWorker)
- ep.workMu.Unlock()
+ ep.mu.Unlock()
continue
}
@@ -151,7 +149,7 @@ func (p *processor) handleSegments() {
p.epQ.enqueue(ep)
}
- ep.workMu.Unlock()
+ ep.mu.Unlock()
}
}
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 5187a5e25..eb8a9d73e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"fmt"
"math"
+ "runtime"
"strings"
"sync/atomic"
"time"
@@ -33,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tmutex"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -283,6 +283,37 @@ func (*EndpointInfo) IsEndpointInfo() {}
// synchronized. The protocol implementation, however, runs in a single
// goroutine.
//
+// Each endpoint has a few mutexes:
+//
+// e.mu -> Primary mutex for an endpoint must be held for all operations except
+// in e.Readiness where acquiring it will result in a deadlock in epoll
+// implementation.
+//
+// The following three mutexes can be acquired independent of e.mu but if
+// acquired with e.mu then e.mu must be acquired first.
+//
+// e.rcvListMu -> Protects the rcvList and associated fields.
+// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.lastErrorMu -> Protects the lastError field.
+//
+// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
+// based on the context in which the lock is acquired. In the syscall context
+// e.LockUser/e.UnlockUser should be used and when doing background processing
+// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below
+// in brief.
+//
+// The reason for this locking behaviour is to avoid wakeups to handle packets.
+// In cases where the endpoint is already locked the background processor can
+// queue the packet up and go its merry way and the lock owner will eventually
+// process the backlog when releasing the lock. Similarly when acquiring the
+// lock from say a syscall goroutine we can implement a bit of spinning if we
+// know that the lock is not held by another syscall goroutine. Background
+// processors should never hold the lock for long and we can avoid an expensive
+// sleep/wakeup by spinning for a shortwhile.
+//
+// For more details please see the detailed documentation on
+// e.LockUser/e.UnlockUser methods.
+//
// +stateify savable
type endpoint struct {
EndpointInfo
@@ -299,12 +330,6 @@ type endpoint struct {
// Precondition: epQueue.mu must be held to read/write this field..
pendingProcessing bool `state:"nosave"`
- // workMu is used to arbitrate which goroutine may perform protocol
- // work. Only the main protocol goroutine is expected to call Lock() on
- // it, but other goroutines (e.g., send) may call TryLock() to eagerly
- // perform work without having to wait for the main one to wake up.
- workMu tmutex.Mutex `state:"nosave"`
-
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
@@ -330,15 +355,11 @@ type endpoint struct {
rcvBufSize int
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
- // zeroWindow indicates that the window was closed due to receive buffer
- // space being filled up. This is set by the worker goroutine before
- // moving a segment to the rcvList. This setting is cleared by the
- // endpoint when a Read() call reads enough data for the new window to
- // be non-zero.
- zeroWindow bool
- // The following fields are protected by the mutex.
- mu sync.RWMutex `state:"nosave"`
+ // mu protects all endpoint fields unless documented otherwise. mu must
+ // be acquired before interacting with the endpoint fields.
+ mu sync.Mutex `state:"nosave"`
+ ownedByUser uint32
// state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
@@ -583,14 +604,93 @@ func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
return maxMSS
}
+// LockUser tries to lock e.mu and if it fails it will check if the lock is held
+// by another syscall goroutine. If yes, then it will goto sleep waiting for the
+// lock to be released, if not then it will spin till it acquires the lock or
+// another syscall goroutine acquires it in which case it will goto sleep as
+// described above.
+//
+// The assumption behind spinning here being that background packet processing
+// should not be holding the lock for long and spinning reduces latency as we
+// avoid an expensive sleep/wakeup of of the syscall goroutine).
+func (e *endpoint) LockUser() {
+ for {
+ // Try first if the sock is locked then check if it's owned
+ // by another user goroutine if not then we spin, otherwise
+ // we just goto sleep on the Lock() and wait.
+ if !e.mu.TryLock() {
+ // If socket is owned by the user then just goto sleep
+ // as the lock could be held for a reasonably long time.
+ if atomic.LoadUint32(&e.ownedByUser) == 1 {
+ e.mu.Lock()
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+ // Spin but yield the processor since the lower half
+ // should yield the lock soon.
+ runtime.Gosched()
+ continue
+ }
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+}
+
+// UnlockUser will check if there are any segments already queued for processing
+// and process any such segments before unlocking e.mu. This is required because
+// we when packets arrive and endpoint lock is already held then such packets
+// are queued up to be processed. If the lock is held by the endpoint goroutine
+// then it will process these packets but if the lock is instead held by the
+// syscall goroutine then we can have the syscall goroutine process the backlog
+// before unlocking.
+//
+// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the
+// endpoint. It's also required eventually when we get rid of the endpoint
+// protocol goroutine altogether.
+//
+// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
+func (e *endpoint) UnlockUser() {
+ // Lock segment queue before checking so that we avoid a race where
+ // segments can be queued between the time we check if queue is empty
+ // and actually unlock the endpoint mutex.
+ for {
+ e.segmentQueue.mu.Lock()
+ if e.segmentQueue.emptyLocked() {
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ e.segmentQueue.mu.Unlock()
+ return
+ }
+ e.segmentQueue.mu.Unlock()
+
+ switch e.EndpointState() {
+ case StateEstablished:
+ if err := e.handleSegments(true /* fastPath */); err != nil {
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ default:
+ // Since we are waking the endpoint goroutine here just unlock
+ // and let it process the queued segments.
+ e.newSegmentWaker.Assert()
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ return
+ }
+ }
+}
+
// StopWork halts packet processing. Only to be used in tests.
func (e *endpoint) StopWork() {
- e.workMu.Lock()
+ e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
func (e *endpoint) ResumeWork() {
- e.workMu.Unlock()
+ e.mu.Unlock()
}
// setEndpointState updates the state of the endpoint to state atomically. This
@@ -709,8 +809,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
- e.workMu.Lock()
e.tsOffset = timeStampOffset()
return e
@@ -721,9 +819,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
- e.mu.RLock()
- defer e.mu.RUnlock()
-
switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -823,20 +918,22 @@ func (e *endpoint) Abort() {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
- e.mu.Lock()
- closed := e.closed
- e.closed = true
- e.mu.Unlock()
- if closed {
+ e.LockUser()
+ defer e.UnlockUser()
+ if e.closed {
return
}
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
- e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
-
- e.mu.Lock()
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdownLocked()
+}
+// closeNoShutdown closes the endpoint without doing a full shutdown. This is
+// used when a connection needs to be aborted with a RST and we want to skip
+// a full 4 way TCP shutdown.
+func (e *endpoint) closeNoShutdownLocked() {
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
@@ -853,6 +950,8 @@ func (e *endpoint) Close() {
e.boundPortFlags = ports.Flags{}
}
+ // Mark endpoint as closed.
+ e.closed = true
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
switch e.EndpointState() {
@@ -873,8 +972,6 @@ func (e *endpoint) Close() {
// goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
}
-
- e.mu.Unlock()
}
// closePendingAcceptableConnections closes all connections that have completed
@@ -909,7 +1006,6 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
-
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
@@ -954,18 +1050,18 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
- e.mu.RLock()
+ e.LockUser()
+ defer e.UnlockUser()
+
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
- e.mu.RUnlock()
return
}
now := time.Now()
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
e.rcvAutoParams.copied += copied
e.rcvListMu.Unlock()
- e.mu.RUnlock()
return
}
prevRTTCopied := e.rcvAutoParams.copied + copied
@@ -1021,7 +1117,6 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvAutoParams.measureTime = now
e.rcvAutoParams.copied = 0
e.rcvListMu.Unlock()
- e.mu.RUnlock()
}
// IPTables implements tcpip.Endpoint.IPTables.
@@ -1031,7 +1126,7 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
+ e.LockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1041,7 +1136,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.mu.RUnlock()
+ e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1051,7 +1146,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1124,13 +1219,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
- e.mu.RLock()
+ e.LockUser()
e.sndBufMu.Lock()
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
@@ -1142,113 +1237,68 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// are copying data in.
if !opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
// Fetch data.
v, perr := p.Payload(avail)
if perr != nil || len(v) == 0 {
- if opts.Atomic { // See above.
+ // Note that perr may be nil if len(v) == 0.
+ if opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
- // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- if opts.Atomic {
+ queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
s := newSegmentFromView(&e.route, e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
- }
- if e.workMu.TryLock() {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
-
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
-
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- }
// Do the work inline.
e.handleWrite()
- e.workMu.Unlock()
- } else {
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
+ if opts.Atomic {
+ // Locks released in queueAndSend()
+ return queueAndSend()
+ }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
- }
- // Let the protocol goroutine do the work.
- e.sndWaker.Assert()
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
- return int64(len(v)), nil, nil
+ // Locks released in queueAndSend()
+ return queueAndSend()
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
@@ -1339,6 +1389,9 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
// SetSockOptBool sets a socket option.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+
switch opt {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
@@ -1346,9 +1399,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
// We only allow this to be set when we're in the initial state.
if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
@@ -1379,7 +1429,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
mask := uint32(notifyReceiveWindowChanged)
- e.mu.RLock()
+ e.LockUser()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1409,8 +1459,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
+
e.rcvListMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
e.notifyProtocolGoroutine(mask)
return nil
@@ -1466,15 +1517,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.ReuseAddressOption:
- e.mu.Lock()
+ e.LockUser()
e.reuseAddr = v != 0
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.ReusePortOption:
- e.mu.Lock()
+ e.LockUser()
e.reusePort = v != 0
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.BindToDeviceOption:
@@ -1482,9 +1533,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
- e.mu.Lock()
+ e.LockUser()
e.bindToDevice = id
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.QuickAckOption:
@@ -1500,16 +1551,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
return tcpip.ErrInvalidOptionValue
}
- e.mu.Lock()
+ e.LockUser()
e.userMSS = uint16(userMSS)
- e.mu.Unlock()
+ e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
case tcpip.TTLOption:
- e.mu.Lock()
+ e.LockUser()
e.ttl = uint8(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.KeepaliveEnabledOption:
@@ -1541,15 +1592,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.TCPUserTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
e.userTimeout = time.Duration(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.BroadcastOption:
- e.mu.Lock()
+ e.LockUser()
e.broadcast = v != 0
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.CongestionControlOption:
@@ -1563,22 +1614,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
availCC := strings.Split(string(avail), " ")
for _, cc := range availCC {
if v == tcpip.CongestionControlOption(cc) {
- // Acquire the work mutex as we may need to
- // reinitialize the congestion control state.
- e.mu.Lock()
+ e.LockUser()
state := e.EndpointState()
e.cc = v
- e.mu.Unlock()
switch state {
case StateEstablished:
- e.workMu.Lock()
- e.mu.Lock()
if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
- e.mu.Unlock()
- e.workMu.Unlock()
}
+ e.UnlockUser()
return nil
}
}
@@ -1588,23 +1633,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrNoSuchFile
case tcpip.IPv4TOSOption:
- e.mu.Lock()
+ e.LockUser()
// TODO(gvisor.dev/issue/995): ECN is not currently supported,
// ignore the bits for now.
e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
+ e.LockUser()
// TODO(gvisor.dev/issue/995): ECN is not currently supported,
// ignore the bits for now.
e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.TCPLingerTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
if v < 0 {
// Same as effectively disabling TCPLinger timeout.
v = 0
@@ -1622,16 +1667,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
v = stkTCPLingerTimeout
}
e.tcpLingerTimeout = time.Duration(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.TCPDeferAcceptOption:
- e.mu.Lock()
+ e.LockUser()
if time.Duration(v) > MaxRTO {
v = tcpip.TCPDeferAcceptOption(MaxRTO)
}
e.deferAccept = time.Duration(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
default:
@@ -1641,8 +1686,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// readyReceiveSize returns the number of bytes ready to be received.
func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint cannot be in listen state.
if e.EndpointState() == StateListen {
@@ -1664,9 +1709,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return false, tcpip.ErrUnknownProtocolOption
}
- e.mu.Lock()
+ e.LockUser()
v := e.v6only
- e.mu.Unlock()
+ e.UnlockUser()
return v, nil
}
@@ -1730,9 +1775,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.ReuseAddressOption:
- e.mu.RLock()
+ e.LockUser()
v := e.reuseAddr
- e.mu.RUnlock()
+ e.UnlockUser()
*o = 0
if v {
@@ -1741,9 +1786,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.ReusePortOption:
- e.mu.RLock()
+ e.LockUser()
v := e.reusePort
- e.mu.RUnlock()
+ e.UnlockUser()
*o = 0
if v {
@@ -1752,9 +1797,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.BindToDeviceOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
+ e.UnlockUser()
return nil
case *tcpip.QuickAckOption:
@@ -1765,16 +1810,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.TTLOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
- e.mu.RLock()
+ e.LockUser()
snd := e.snd
- e.mu.RUnlock()
+ e.UnlockUser()
if snd != nil {
snd.rtt.Lock()
o.RTT = snd.rtt.srtt
@@ -1813,9 +1858,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.TCPUserTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.OutOfBandInlineOption:
@@ -1824,9 +1869,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.BroadcastOption:
- e.mu.Lock()
+ e.LockUser()
v := e.broadcast
- e.mu.Unlock()
+ e.UnlockUser()
*o = 0
if v {
@@ -1835,33 +1880,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.CongestionControlOption:
- e.mu.Lock()
+ e.LockUser()
*o = e.cc
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.IPv4TOSOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
+ e.UnlockUser()
return nil
case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
+ e.UnlockUser()
return nil
case *tcpip.TCPLingerTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.TCPDeferAcceptOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
default:
@@ -1901,8 +1946,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// yet accepted by the app, they are restored without running the main goroutine
// here.
func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
connectingAddr := addr.Addr
@@ -2071,9 +2116,13 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.Lock()
+ e.LockUser()
+ defer e.UnlockUser()
+ return e.shutdownLocked(flags)
+}
+
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
- finQueued := false
switch {
case e.EndpointState().connected():
// Close for read.
@@ -2087,24 +2136,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.mu.Unlock()
- // Try to send an active reset immediately if the
- // work mutex is available.
- if e.workMu.TryLock() {
- e.mu.Lock()
- // We need to double check here to make
- // sure worker has not transitioned the
- // endpoint out of a connected state
- // before trying to send a reset.
- if e.EndpointState().connected() {
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.notifyProtocolGoroutine(notifyTickleWorker)
- }
- e.mu.Unlock()
- e.workMu.Unlock()
- } else {
- e.notifyProtocolGoroutine(notifyReset)
- }
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ // Wake up worker to terminate loop.
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
}
@@ -2116,42 +2150,32 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// Already closed.
e.sndBufMu.Unlock()
if e.EndpointState() == StateTimeWait {
- e.mu.Unlock()
return tcpip.ErrNotConnected
}
- break
+ return nil
}
// Queue fin segment.
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
- finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
+ e.handleClose()
}
+ return nil
case e.EndpointState() == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
}
+ return nil
+
default:
- e.mu.Unlock()
return tcpip.ErrNotConnected
}
- e.mu.Unlock()
- if finQueued {
- if e.workMu.TryLock() {
- e.handleClose()
- e.workMu.Unlock()
- } else {
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
- }
- }
- return nil
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
@@ -2166,8 +2190,8 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error {
}
func (e *endpoint) listen(backlog int) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
// Allow the backlog to be adjusted if the endpoint is not shutting down.
// When the endpoint shuts down, it sets workerCleanup to true, and from
@@ -2229,7 +2253,6 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop() {
- e.mu.Lock()
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2240,8 +2263,8 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// Endpoint must be in listen state before it can accept connections.
if e.EndpointState() != StateListen {
@@ -2260,8 +2283,8 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
// Bind binds the endpoint to a specific local port and optionally address.
func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
return e.bindLocked(addr)
}
@@ -2339,8 +2362,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// GetLocalAddress returns the address to which the endpoint is bound.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
return tcpip.FullAddress{
Addr: e.ID.LocalAddress,
@@ -2351,8 +2374,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
// GetRemoteAddress returns the address to which the endpoint is connected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
@@ -2419,7 +2442,6 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
- e.mu.RLock()
e.rcvListMu.Lock()
if s != nil {
s.incRef()
@@ -2434,7 +2456,6 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
- e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventIn)
}
@@ -2578,9 +2599,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.SegTime = time.Now()
// Copy EndpointID.
- e.mu.Lock()
s.ID = stack.TCPEndpointID(e.ID)
- e.mu.Unlock()
// Copy endpoint rcv state.
e.rcvListMu.Lock()
@@ -2710,10 +2729,10 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
+ e.LockUser()
// Make a copy of the endpoint info.
ret := e.EndpointInfo
- e.mu.RUnlock()
+ e.UnlockUser()
return &ret
}
@@ -2728,9 +2747,9 @@ func (e *endpoint) Wait() {
e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
defer e.waiterQueue.EventUnregister(&waitEntry)
for {
- e.mu.Lock()
+ e.LockUser()
running := e.workerRunning
- e.mu.Unlock()
+ e.UnlockUser()
if !running {
break
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 4a46f0ec5..9175de441 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -162,8 +162,8 @@ func (e *endpoint) loadState(state EndpointState) {
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
- // as the endpoint is still being loaded and the stack reference to increment
- // metrics is not yet initialized.
+ // as the endpoint is still being loaded and the stack reference is not
+ // yet initialized.
atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
@@ -180,7 +180,6 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
state := e.origEndpointState
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 73098d904..b0f918bb4 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -95,7 +95,7 @@ const (
)
type protocol struct {
- mu sync.Mutex
+ mu sync.RWMutex
sackEnabled bool
delayEnabled bool
sendBufferSize SendBufferSizeOption
@@ -273,57 +273,57 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = SACKEnabled(p.sackEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *DelayEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = DelayEnabled(p.delayEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *SendBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.sendBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *ReceiveBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.recvBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.CongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.CongestionControlOption(p.congestionControl)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.AvailableCongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.ModerateReceiveBufferOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.TCPLingerTimeoutOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index d80aff1b6..caf8977b3 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -168,7 +168,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
- r.ep.mu.Lock()
switch r.ep.EndpointState() {
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
@@ -183,7 +182,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateFinWait2:
r.ep.setEndpointState(StateTimeWait)
}
- r.ep.mu.Unlock()
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
@@ -208,7 +206,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
- r.ep.mu.Lock()
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -222,7 +219,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
- r.ep.mu.Unlock()
}
return true
@@ -336,10 +332,8 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
- r.ep.mu.RLock()
state := r.ep.EndpointState()
closed := r.ep.closed
- r.ep.mu.RUnlock()
if state != StateEstablished {
drop, err := r.handleRcvdSegmentClosing(s, state, closed)
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index bd20a7ee9..48a257137 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -28,10 +28,16 @@ type segmentQueue struct {
used int
}
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *segmentQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
// empty determines if the queue is empty.
func (q *segmentQueue) empty() bool {
q.mu.Lock()
- r := q.used == 0
+ r := q.emptyLocked()
q.mu.Unlock()
return r
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 657c3146e..17fed4ec5 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -455,9 +455,7 @@ func (s *sender) retransmitTimerExpired() bool {
// Give up if we've waited more than a minute since the last resend or
// if a user time out is set and we have exceeded the user specified
// timeout since the first retransmission.
- s.ep.mu.RLock()
uto := s.ep.userTimeout
- s.ep.mu.RUnlock()
if s.firstRetransmittedSegXmitTime.IsZero() {
// We store the original xmitTime of the segment that we are
@@ -713,7 +711,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
default:
s.ep.setEndpointState(StateFinWait1)
}
-
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5b2b16afa..39d36d2ba 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2236,9 +2236,17 @@ func TestSegmentMerging(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- // Prevent the endpoint from processing packets.
- test.stop(c.EP)
+ // Send 10 1 byte segments to fill up InitialWindow but don't
+ // ACK. That should prevent anymore packets from going out.
+ for i := 0; i < 10; i++ {
+ view := buffer.NewViewFromBytes([]byte{0})
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %v", i+1, err)
+ }
+ }
+ // Now send the segments that should get merged as the congestion
+ // window is full and we won't be able to send any more packets.
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
@@ -2248,8 +2256,29 @@ func TestSegmentMerging(t *testing.T) {
}
}
- // Let the endpoint process the segments that we just sent.
- test.resume(c.EP)
+ // Check that we get 10 packets of 1 byte each.
+ for i := 0; i < 10; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize+1),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
+ RcvWnd: 30000,
+ })
// Check that data is received.
b := c.GetPacket()
@@ -2257,7 +2286,7 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+11),
checker.AckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
@@ -2273,7 +2302,7 @@ func TestSegmentMerging(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))),
+ AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
RcvWnd: 30000,
})
})