diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 445 |
1 files changed, 294 insertions, 151 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 65c346046..b706438bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -17,13 +17,14 @@ package tcp import ( "crypto/sha1" "encoding/binary" + "fmt" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -48,17 +49,14 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 -) -var ( - // SynRcvdCountThreshold is the global maximum number of connections - // that are allowed to be in SYN-RCVD state before TCP starts using SYN - // cookies to accept connections. - // - // It is an exported variable only for testing, and should not otherwise - // be used by importers of this package. + // SynRcvdCountThreshold is the default global maximum number of + // connections that are allowed to be in SYN-RCVD state before TCP + // starts using SYN cookies to accept connections. SynRcvdCountThreshold uint64 = 1000 +) +var ( // mssTable is a slice containing the possible MSS values that we // encode in the SYN cookie with two bits. mssTable = []uint16{536, 1300, 1440, 1460} @@ -73,29 +71,42 @@ func encodeMSS(mss uint16) uint32 { return 0 } -// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is -// protected by a mutex so that we can increment only when it's guaranteed not -// to go above a threshold. -var synRcvdCount struct { - sync.Mutex - value uint64 - pending sync.WaitGroup -} - // listenContext is used by a listening endpoint to store state used while // listening for connections. This struct is allocated by the listen goroutine // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack - rcvWnd seqnum.Size - nonce [2][sha1.BlockSize]byte + stack *stack.Stack + + // synRcvdCount is a reference to the stack level synRcvdCount. + synRcvdCount *synRcvdCounter + + // rcvWnd is the receive window that is sent by this listening context + // in the initial SYN-ACK. + rcvWnd seqnum.Size + + // nonce are random bytes that are initialized once when the context + // is created and used to seed the hash function when generating + // the SYN cookie. + nonce [2][sha1.BlockSize]byte + + // listenEP is a reference to the listening endpoint associated with + // this context. Can be nil if the context is created by the forwarder. listenEP *endpoint + // hasherMu protects hasher. hasherMu sync.Mutex - hasher hash.Hash - v6only bool + // hasher is the hash function used to generate a SYN cookie. + hasher hash.Hash + + // v6Only is true if listenEP is a dual stack socket and has the + // IPV6_V6ONLY option set. + v6Only bool + + // netProto indicates the network protocol(IPv4/v6) for the listening + // endpoint. netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed // by the listening endpoint's worker goroutine. // @@ -114,55 +125,22 @@ func timeStamp() uint32 { return uint32(time.Now().Unix()>>6) & tsMask } -// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD -// state. It succeeds if the increment doesn't make the count go beyond the -// threshold, and fails otherwise. -func incSynRcvdCount() bool { - synRcvdCount.Lock() - - if synRcvdCount.value >= SynRcvdCountThreshold { - synRcvdCount.Unlock() - return false - } - - synRcvdCount.pending.Add(1) - synRcvdCount.value++ - - synRcvdCount.Unlock() - return true -} - -// decSynRcvdCount atomically decrements the global number of endpoints in -// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount -// succeeded. -func decSynRcvdCount() { - synRcvdCount.Lock() - - synRcvdCount.value-- - synRcvdCount.pending.Done() - synRcvdCount.Unlock() -} - -// synCookiesInUse() returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. -func synCookiesInUse() bool { - synRcvdCount.Lock() - v := synRcvdCount.value - synRcvdCount.Unlock() - return v >= SynRcvdCountThreshold -} - // newListenContext creates a new listen context. -func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ stack: stk, rcvWnd: rcvWnd, hasher: sha1.New(), - v6only: v6only, + v6Only: v6Only, netProto: netProto, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } + p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) + if !ok { + panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) + } + l.synRcvdCount = p.SynRcvdCounter() rand.Read(l.nonce[0][:]) rand.Read(l.nonce[1][:]) @@ -221,85 +199,119 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) - n.v6only = l.v6only + n := newEndpoint(l.stack, netProto, queue) + n.v6only = l.v6Only n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) - n.amss = mssForRoute(&n.route) + n.amss = calculateAdvertisedMSS(n.userMSS, n.route) + n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) n.initGSO() - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { - n.Close() - return nil, err - } - - n.isRegistered = true - - // Create sender and receiver. - // - // The receiver at least temporarily has a zero receive window scale, - // but the caller may change it (before starting the protocol loop). - n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) - n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize())) // Bootstrap the auto tuning algorithm. Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n, nil + return n } -// createEndpoint creates a new endpoint in connected state and then performs -// the TCP 3-way handshake. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. +// +// The new endpoint is returned with e.mu held. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) - if err != nil { - return nil, err - } + isn := generateSecureISN(s.id, l.stack.Seed()) + ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + ep.mu.Lock() + ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() + // Ensure we release any registrations done by the newly + // created endpoint. + ep.mu.Unlock() + ep.Close() + return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + l.listenEP.propagateInheritableOptionsLocked(ep) + + if !ep.reserveTupleLocked() { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + + return nil, tcpip.ErrConnectionAborted + } + + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } - // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) + // Register new endpoint so that packets are routed to it. + if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } - h.resetToSynRcvd(cookie, irs, opts) + ep.drainClosingSegmentQueue() + + return nil, err + } + + ep.isRegistered = true + + // Perform the 3-way handshake. + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { + ep.mu.Unlock() ep.Close() + ep.notifyAborted() + if l.listenEP != nil { l.removePendingEndpoint(ep) } + + ep.drainClosingSegmentQueue() + return nil, err } - ep.mu.Lock() - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished - ep.mu.Unlock() + ep.isConnectNotified = true // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -333,23 +345,78 @@ func (l *listenContext) closeAllPendingEndpoints() { } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state, the new endpoint is closed -// instead. +// endpoint has transitioned out of the listen state (acceptedChan is nil), +// the new endpoint is closed instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state e.pendingAccepted.Add(1) - defer e.pendingAccepted.Done() - acceptedChan := e.acceptedChan e.mu.Unlock() - if state == StateListen { - acceptedChan <- n - e.waiterQueue.Notify(waiter.EventIn) - } else { - n.Close() + defer e.pendingAccepted.Done() + + e.acceptMu.Lock() + for { + if e.acceptedChan == nil { + e.acceptMu.Unlock() + n.notifyProtocolGoroutine(notifyReset) + return + } + select { + case e.acceptedChan <- n: + e.acceptMu.Unlock() + e.waiterQueue.Notify(waiter.EventIn) + return + default: + e.acceptCond.Wait() + } } } +// propagateInheritableOptionsLocked propagates any options set on the listening +// endpoint to the newly created endpoint. +// +// Precondition: e.mu and n.mu must be held. +func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { + n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags + n.userMSS = e.userMSS +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. @@ -357,53 +424,68 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer decSynRcvdCount() - defer e.decSynRcvdCount() + defer ctx.synRcvdCount.dec() + defer func() { + e.mu.Lock() + e.decSynRcvdCount() + e.mu.Unlock() + }() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) + n.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } func (e *endpoint) incSynRcvdCount() bool { - e.mu.Lock() - if e.synRcvdCount >= cap(e.acceptedChan) { - e.mu.Unlock() - return false + e.acceptMu.Lock() + canInc := e.synRcvdCount < cap(e.acceptedChan) + e.acceptMu.Unlock() + if canInc { + e.synRcvdCount++ } - e.synRcvdCount++ - e.mu.Unlock() - return true + return canInc } func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() e.synRcvdCount-- - e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { - e.mu.Lock() - if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - e.mu.Unlock() - return true - } - e.mu.Unlock() - return false + e.acceptMu.Lock() + full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + e.acceptMu.Unlock() + return full } // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - switch s.flags { - case header.TCPFlagSyn: + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() + if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + // If the endpoint is shutdown, reply with reset. + // + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + replyWithReset(s, e.sendTOS, e.ttl) + return + } + + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) - if incSynRcvdCount() { + if ctx.synRcvdCount.inc() { // Only handle the syn if the following conditions hold // - accept queue is not full. // - number of connections in synRcvd state is less than the @@ -413,7 +495,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. return } - decSynRcvdCount() + ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -430,23 +512,33 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) // Send SYN without window scaling because we currently - // dont't encode this information in the cookie. + // don't encode this information in the cookie. // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(timeStampOffset()), + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: calculateAdvertisedMSS(e.userMSS, s.route), } - e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + }, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -459,7 +551,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - if !synCookiesInUse() { + if !ctx.synRcvdCount.synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. @@ -467,10 +567,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s) + replyWithReset(s, e.sendTOS, e.ttl) return } + iss := s.ackNumber - 1 + irs := s.sequenceNumber - 1 + // Since SYN cookies are in use this is potentially an ACK to a // SYN-ACK we sent but don't have a half open connection state // as cookies are being used to protect against a potential SYN @@ -481,7 +584,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // when under a potential syn flood attack. // // Validate the cookie. - data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) + data, ok := ctx.isCookieValid(s.id, iss, irs) if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -506,13 +609,35 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) - if err != nil { + n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + + n.mu.Lock() + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + + if !n.reserveTupleLocked() { + n.mu.Unlock() + n.Close() + + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return + } + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } + n.isRegistered = true + // clear the tsOffset for the newly created // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was @@ -520,8 +645,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. - n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished + n.isConnectNotified = true + n.transitionToStateEstablishedLocked(&handshake{ + ep: n, + iss: iss, + ackNum: irs + 1, + rcvWnd: seqnum.Size(n.initialReceiveWindow()), + sndWnd: s.window, + rcvWndScale: e.rcvWndScaleForHandshake(), + sndWndScale: rcvdSynOptions.WS, + mss: rcvdSynOptions.MSS, + }) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -532,6 +666,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + n.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } @@ -540,16 +678,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) + v6Only := e.v6only + ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections // to the endpoint. - e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() @@ -562,15 +698,20 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { } e.mu.Unlock() + e.drainClosingSegmentQueue() + // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { - switch index, _ := s.Fetch(true); index { + e.mu.Unlock() + index, _ := s.Fetch(true) + e.mu.Lock() + switch index { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { @@ -583,7 +724,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.decRef() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } case wakerForNewSegment: |