diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/rcv.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 173 |
1 files changed, 81 insertions, 92 deletions
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index bc6793fc6..fc11b4ba9 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // receiver holds the state necessary to receive TCP segments and turn them @@ -29,26 +30,15 @@ import ( // // +stateify savable type receiver struct { + stack.TCPReceiverState ep *endpoint - rcvNxt seqnum.Value - - // rcvAcc is one beyond the last acceptable sequence number. That is, - // the "largest" sequence value that the receiver has announced to the - // its peer that it's willing to accept. This may be different than - // rcvNxt + rcvWnd if the receive window is reduced; in that case we - // have to reduce the window as we receive more data instead of - // shrinking it. - rcvAcc seqnum.Value - // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size - // rcvWUP is the rcvNxt value at the last window update sent. + // rcvWUP is the RcvNxt value at the last window update sent. rcvWUP seqnum.Value - rcvWndScale uint8 - // prevBufused is the snapshot of endpoint rcvBufUsed taken when we // advertise a receive window. prevBufUsed int @@ -58,9 +48,6 @@ type receiver struct { // pendingRcvdSegments is bounded by the receive buffer size of the // endpoint. pendingRcvdSegments segmentHeap - // pendingBufUsed tracks the total number of bytes (including segment - // overhead) currently queued in pendingRcvdSegments. - pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` @@ -68,12 +55,14 @@ type receiver struct { func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), + ep: ep, + TCPReceiverState: stack.TCPReceiverState{ + RcvNxt: irs + 1, + RcvAcc: irs.Add(rcvWnd + 1), + RcvWndScale: rcvWndScale, + }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } } @@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // r.rcvWnd could be much larger than the window size we advertised in our // outgoing packets, we should use what we have advertised for acceptability // test. - scaledWindowSize := r.rcvWnd >> r.rcvWndScale + scaledWindowSize := r.rcvWnd >> r.RcvWndScale if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. scaledWindowSize = math.MaxUint16 } - advertisedWindowSize := scaledWindowSize << r.rcvWndScale - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) + advertisedWindowSize := scaledWindowSize << r.RcvWndScale + return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize)) } // currentWindow returns the available space in the window that was advertised // last to our peer. func (r *receiver) currentWindow() (curWnd seqnum.Size) { endOfWnd := r.rcvWUP.Add(r.rcvWnd) - if endOfWnd.LessThan(r.rcvNxt) { - // return 0 if r.rcvNxt is past the end of the previously advertised window. + if endOfWnd.LessThan(r.RcvNxt) { + // return 0 if r.RcvNxt is past the end of the previously advertised window. // This can happen because we accept a large segment completely even if // accepting it causes it to partially exceed the advertised window. return 0 } - return r.rcvNxt.Size(endOfWnd) + return r.RcvNxt.Size(endOfWnd) } // getSendParams returns the parameters needed by the sender when building // segments to send. -func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { +func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() - unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt)) bufUsed := r.ep.receiveBufferUsed() // Grow the right edge of the window only for payloads larger than the @@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // edge, as we are still advertising a window that we think can be serviced. toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed - // Update rcvAcc only if new window is > previously advertised window. We + // Update RcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. // ==================================================== sequence space. // ^ ^ ^ ^ - // rcvWUP rcvNxt rcvAcc new rcvAcc + // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { - // If the new window moves the right edge, then update rcvAcc. - r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) + if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + // If the new window moves the right edge, then update RcvAcc. + r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - r.rcvWUP = r.rcvNxt + r.rcvWUP = r.RcvNxt r.prevBufUsed = bufUsed - scaledWnd := r.rcvWnd >> r.rcvWndScale + scaledWnd := r.rcvWnd >> r.RcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() @@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Ensure that the stashed receive window always reflects what // is being advertised. - r.rcvWnd = scaledWnd << r.rcvWndScale + r.rcvWnd = scaledWnd << r.RcvWndScale } - return r.rcvNxt, scaledWnd + return r.RcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; @@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // If the segment doesn't include the seqnum we're expecting to // consume now, we're missing a segment. We cannot proceed until // we receive that segment though. - if !r.rcvNxt.InWindow(segSeq, segLen) { + if !r.RcvNxt.InWindow(segSeq, segLen) { return false } // Trim segment to eliminate already acknowledged data. - if segSeq.LessThan(r.rcvNxt) { - diff := segSeq.Size(r.rcvNxt) + if segSeq.LessThan(r.RcvNxt) { + diff := segSeq.Size(r.RcvNxt) segLen -= diff segSeq.UpdateForward(diff) s.sequenceNumber.UpdateForward(diff) @@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Move segment to ready-to-deliver list. Wakeup any waiters. r.ep.readyToRead(s) - } else if segSeq != r.rcvNxt { + } else if segSeq != r.RcvNxt { return false } // Update the segment that we're expecting to consume. - r.rcvNxt = segSeq.Add(segLen) + r.RcvNxt = segSeq.Add(segLen) // In cases of a misbehaving sender which could send more than the // advertised window, we could end up in a situation where we get a // segment that exceeds the window advertised. Instead of partially // accepting the segment and discarding bytes beyond the advertised - // window, we accept the whole segment and make sure r.rcvAcc is moved - // forward to match r.rcvNxt to indicate that the window is now closed. + // window, we accept the whole segment and make sure r.RcvAcc is moved + // forward to match r.RcvNxt to indicate that the window is now closed. // // In absence of this check the r.acceptable() check fails and accepts // segments that should be dropped because rcvWnd is calculated as - // the size of the interval (rcvNxt, rcvAcc] which becomes extremely - // large if rcvAcc is ever less than rcvNxt. - if r.rcvAcc.LessThan(r.rcvNxt) { - r.rcvAcc = r.rcvNxt + // the size of the interval (RcvNxt, RcvAcc] which becomes extremely + // large if RcvAcc is ever less than RcvNxt. + if r.RcvAcc.LessThan(r.RcvNxt) { + r.RcvAcc = r.RcvNxt } // Trim SACK Blocks to remove any SACK information that covers // sequence numbers that have been consumed. - TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { - r.rcvNxt++ + r.RcvNxt++ // Send ACK immediately. r.ep.snd.sendAck() @@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { - r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() + r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() // Note that slice truncation does not allow garbage collection of @@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -323,40 +312,40 @@ func (r *receiver) updateRTT() { // estimate the round-trip time by observing the time between when a byte // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. - r.ep.rcvListMu.Lock() - if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { // New measurement. - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { - r.ep.rcvListMu.Unlock() + if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) { + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. - if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { - r.ep.rcvAutoParams.rtt = rtt + if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { + r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { - r.ep.rcvListMu.Lock() - rcvClosed := r.ep.rcvClosed || r.closed - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { case StateCloseWait, StateClosing, StateLastAck: - if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + if !s.sequenceNumber.LessThanEq(r.RcvNxt) { // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // The ESTABLISHED state processing is here where if the ACK check // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + if r.ep.snd.SndNxt.LessThan(s.ackNumber) { r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., - // SHUT_RD) then any data past the rcvNxt should + // SHUT_RD) then any data past the RcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) { return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { @@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If it's a retransmission of an old data segment // or a pure ACK then allow it. - if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) || s.logicalLen() == 0 { break } @@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // then the only acceptable segment is a // FIN. Since FIN can technically also carry // data we verify that the segment carrying a - // FIN ends at exactly e.rcvNxt+1. + // FIN ends at exactly e.RcvNxt+1. // // From RFC793 page 25. // @@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // end has closed and the peer is yet to send a FIN. Hence we // compare only the payload. segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) { return true, nil } return false, nil @@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { - r.ep.rcvListMu.Lock() - r.pendingBufUsed += s.segMemSize() - r.ep.rcvListMu.Unlock() + if r.ep.receiveBufferAvailable() > 0 && r.PendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed += s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) - UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt) } // Immediately send an ack so that the peer knows it may @@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { segSeq := s.sequenceNumber // Skip segment altogether if it has already been acknowledged. - if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) && !r.consumeSegment(s, segSeq, segLen) { break } heap.Pop(&r.pendingRcvdSegments) - r.ep.rcvListMu.Lock() - r.pendingBufUsed -= s.segMemSize() - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed -= s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.decRef() } return false, nil @@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } @@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn } // Update Timestamp if required. See RFC7323, section-4.3. - if r.ep.sendTSOk && s.parsedOptions.TS { - r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + if r.ep.SendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() @@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // carries data then just send an ACK. This is according to RFC 793, // page 37. // - // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. - if segSeq != r.rcvNxt || segLen != 0 { + // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt. + if segSeq != r.RcvNxt || segLen != 0 { r.ep.snd.sendAck() } return false, false |