diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/snd.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 628 |
1 files changed, 628 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go new file mode 100644 index 000000000..ad94aecd8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/snd.go @@ -0,0 +1,628 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "math" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +const ( + // minRTO is the minimum allowed value for the retransmit timeout. + minRTO = 200 * time.Millisecond + + // InitialCwnd is the initial congestion window. + InitialCwnd = 10 +) + +// sender holds the state necessary to send TCP segments. +type sender struct { + ep *endpoint + + // lastSendTime is the timestamp when the last packet was sent. + lastSendTime time.Time + + // dupAckCount is the number of duplicated acks received. It is used for + // fast retransmit. + dupAckCount int + + // fr holds state related to fast recovery. + fr fastRecovery + + // sndCwnd is the congestion window, in packets. + sndCwnd int + + // sndSsthresh is the threshold between slow start and congestion + // avoidance. + sndSsthresh int + + // sndCAAckCount is the number of packets acknowledged during congestion + // avoidance. When enough packets have been ack'd (typically cwnd + // packets), the congestion window is incremented by one. + sndCAAckCount int + + // outstanding is the number of outstanding packets, that is, packets + // that have been sent but not yet acknowledged. + outstanding int + + // sndWnd is the send window size. + sndWnd seqnum.Size + + // sndUna is the next unacknowledged sequence number. + sndUna seqnum.Value + + // sndNxt is the sequence number of the next segment to be sent. + sndNxt seqnum.Value + + // sndNxtList is the sequence number of the next segment to be added to + // the send list. + sndNxtList seqnum.Value + + // rttMeasureSeqNum is the sequence number being used for the latest RTT + // measurement. + rttMeasureSeqNum seqnum.Value + + // rttMeasureTime is the time when the rttMeasureSeqNum was sent. + rttMeasureTime time.Time + + closed bool + writeNext *segment + writeList segmentList + resendTimer timer `state:"nosave"` + resendWaker sleep.Waker `state:"nosave"` + + // srtt, rttvar & rto are the "smoothed round-trip time", "round-trip + // time variation" and "retransmit timeout", as defined in section 2 of + // RFC 6298. + srtt time.Duration + rttvar time.Duration + rto time.Duration + srttInited bool + + // maxPayloadSize is the maximum size of the payload of a given segment. + // It is initialized on demand. + maxPayloadSize int + + // sndWndScale is the number of bits to shift left when reading the send + // window size from a segment. + sndWndScale uint8 + + // maxSentAck is the maxium acknowledgement actually sent. + maxSentAck seqnum.Value +} + +// fastRecovery holds information related to fast recovery from a packet loss. +type fastRecovery struct { + // active whether the endpoint is in fast recovery. The following fields + // are only meaningful when active is true. + active bool + + // first and last represent the inclusive sequence number range being + // recovered. + first seqnum.Value + last seqnum.Value + + // maxCwnd is the maximum value the congestion window may be inflated to + // due to duplicate acks. This exists to avoid attacks where the + // receiver intentionally sends duplicate acks to artificially inflate + // the sender's cwnd. + maxCwnd int +} + +func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { + s := &sender{ + ep: ep, + sndCwnd: InitialCwnd, + sndSsthresh: math.MaxInt64, + sndWnd: sndWnd, + sndUna: iss + 1, + sndNxt: iss + 1, + sndNxtList: iss + 1, + rto: 1 * time.Second, + rttMeasureSeqNum: iss + 1, + lastSendTime: time.Now(), + maxPayloadSize: int(mss), + maxSentAck: irs + 1, + fr: fastRecovery{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + last: iss, + }, + } + + // A negative sndWndScale means that no scaling is in use, otherwise we + // store the scaling value. + if sndWndScale > 0 { + s.sndWndScale = uint8(sndWndScale) + } + + s.updateMaxPayloadSize(int(ep.route.MTU()), 0) + + s.resendTimer.init(&s.resendWaker) + + return s +} + +// updateMaxPayloadSize updates the maximum payload size based on the given +// MTU. If this is in response to "packet too big" control packets (indicated +// by the count argument), it also reduces the number of oustanding packets and +// attempts to retransmit the first packet above the MTU size. +func (s *sender) updateMaxPayloadSize(mtu, count int) { + m := mtu - header.TCPMinimumSize + + // Calculate the maximum option size. + var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock + options := s.ep.makeOptions(maxSackBlocks[:]) + m -= len(options) + putOptions(options) + + // We don't adjust up for now. + if m >= s.maxPayloadSize { + return + } + + // Make sure we can transmit at least one byte. + if m <= 0 { + m = 1 + } + + s.maxPayloadSize = m + + s.outstanding -= count + if s.outstanding < 0 { + s.outstanding = 0 + } + + // Rewind writeNext to the first segment exceeding the MTU. Do nothing + // if it is already before such a packet. + for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + if seg == s.writeNext { + // We got to writeNext before we could find a segment + // exceeding the MTU. + break + } + + if seg.data.Size() > m { + // We found a segment exceeding the MTU. Rewind + // writeNext and try to retransmit it. + s.writeNext = seg + break + } + } + + // Since we likely reduced the number of outstanding packets, we may be + // ready to send some more. + s.sendData() +} + +// sendAck sends an ACK segment. +func (s *sender) sendAck() { + s.sendSegment(nil, flagAck, s.sndNxt) +} + +// updateRTO updates the retransmit timeout when a new roud-trip time is +// available. This is done in accordance with section 2 of RFC 6298. +func (s *sender) updateRTO(rtt time.Duration) { + if !s.srttInited { + s.rttvar = rtt / 2 + s.srtt = rtt + s.srttInited = true + } else { + diff := s.srtt - rtt + if diff < 0 { + diff = -diff + } + s.rttvar = (3*s.rttvar + diff) / 4 + s.srtt = (7*s.srtt + rtt) / 8 + } + + s.rto = s.srtt + 4*s.rttvar + if s.rto < minRTO { + s.rto = minRTO + } +} + +// resendSegment resends the first unacknowledged segment. +func (s *sender) resendSegment() { + // Don't use any segments we already sent to measure RTT as they may + // have been affected by packets being lost. + s.rttMeasureSeqNum = s.sndNxt + + // Resend the segment. + if seg := s.writeList.Front(); seg != nil { + s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber) + } +} + +// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, +// page 6, eq. 4. It is called when we detect congestion in the network. +func (s *sender) reduceSlowStartThreshold() { + s.sndSsthresh = s.outstanding / 2 + if s.sndSsthresh < 2 { + s.sndSsthresh = 2 + } +} + +// retransmitTimerExpired is called when the retransmit timer expires, and +// unacknowledged segments are assumed lost, and thus need to be resent. +// Returns true if the connection is still usable, or false if the connection +// is deemed lost. +func (s *sender) retransmitTimerExpired() bool { + // Check if the timer actually expired or if it's a spurious wake due + // to a previously orphaned runtime timer. + if !s.resendTimer.checkExpiration() { + return true + } + + // Give up if we've waited more than a minute since the last resend. + if s.rto >= 60*time.Second { + return false + } + + // Set new timeout. The timer will be restarted by the call to sendData + // below. + s.rto *= 2 + + if s.fr.active { + // We were attempting fast recovery but were not successful. + // Leave the state. We don't need to update ssthresh because it + // has already been updated when entered fast-recovery. + s.leaveFastRecovery() + } + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. + // We store the highest sequence number transmitted in cases where + // we were not in fast recovery. + s.fr.last = s.sndNxt - 1 + + // We lost a packet, so reduce ssthresh. + s.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + s.sndCwnd = 1 + + // Mark the next segment to be sent as the first unacknowledged one and + // start sending again. Set the number of outstanding packets to 0 so + // that we'll be able to retransmit. + // + // We'll keep on transmitting (or retransmitting) as we get acks for + // the data we transmit. + s.outstanding = 0 + s.writeNext = s.writeList.Front() + s.sendData() + + return true +} + +// sendData sends new data segments. It is called when data becomes available or +// when the send window opens up. +func (s *sender) sendData() { + limit := s.maxPayloadSize + + // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. + // "A TCP SHOULD set cwnd to no more than RW before beginning + // transmission if the TCP has not sent data in the interval exceeding + // the retrasmission timeout." + if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { + if s.sndCwnd > InitialCwnd { + s.sndCwnd = InitialCwnd + } + } + + // TODO: We currently don't merge multiple send buffers + // into one segment if they happen to fit. We should do that + // eventually. + var seg *segment + end := s.sndUna.Add(s.sndWnd) + for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + // We abuse the flags field to determine if we have already + // assigned a sequence number to this segment. + if seg.flags == 0 { + seg.sequenceNumber = s.sndNxt + seg.flags = flagAck | flagPsh + } + + var segEnd seqnum.Value + if seg.data.Size() == 0 { + seg.flags = flagAck + + s.ep.rcvListMu.Lock() + rcvBufUsed := s.ep.rcvBufUsed + s.ep.rcvListMu.Unlock() + + s.ep.mu.Lock() + // We're sending a FIN by default + fl := flagFin + if (s.ep.shutdownFlags&tcpip.ShutdownRead) != 0 && rcvBufUsed > 0 { + // If there is unread data we must send a RST. + // For more information see RFC 2525 section 2.17. + fl = flagRst + } + s.ep.mu.Unlock() + seg.flags |= uint8(fl) + + segEnd = seg.sequenceNumber.Add(1) + } else { + // We're sending a non-FIN segment. + if !seg.sequenceNumber.LessThan(end) { + break + } + + available := int(seg.sequenceNumber.Size(end)) + if available > limit { + available = limit + } + + if seg.data.Size() > available { + // Split this segment up. + nSeg := seg.clone() + nSeg.data.TrimFront(available) + nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) + s.writeList.InsertAfter(seg, nSeg) + seg.data.CapLength(available) + } + + s.outstanding++ + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + } + + s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber) + + // Update sndNxt if we actually sent new data (as opposed to + // retransmitting some previously sent data). + if s.sndNxt.LessThan(segEnd) { + s.sndNxt = segEnd + } + } + + // Remember the next segment we'll write. + s.writeNext = seg + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } +} + +func (s *sender) enterFastRecovery() { + // Save state to reflect we're now in fast recovery. + s.reduceSlowStartThreshold() + // Save state to reflect we're now in fast recovery. + // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. + // We inflat the cwnd by 3 to account for the 3 packets which triggered + // the 3 duplicate ACKs and are now not in flight. + s.sndCwnd = s.sndSsthresh + 3 + s.fr.first = s.sndUna + s.fr.last = s.sndNxt - 1 + s.fr.maxCwnd = s.sndCwnd + s.outstanding + s.fr.active = true +} + +func (s *sender) leaveFastRecovery() { + s.fr.active = false + s.fr.first = 0 + s.fr.last = s.sndNxt - 1 + s.fr.maxCwnd = 0 + s.dupAckCount = 0 + + // Deflate cwnd. It had been artificially inflated when new dups arrived. + s.sndCwnd = s.sndSsthresh +} + +// checkDuplicateAck is called when an ack is received. It manages the state +// related to duplicate acks and determines if a retransmit is needed according +// to the rules in RFC 6582 (NewReno). +func (s *sender) checkDuplicateAck(seg *segment) bool { + ack := seg.ackNumber + if s.fr.active { + // We are in fast recovery mode. Ignore the ack if it's out of + // range. + if !ack.InRange(s.sndUna, s.sndNxt+1) { + return false + } + + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if s.fr.last.LessThan(ack) { + s.leaveFastRecovery() + return false + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if seg.logicalLen() != 0 || s.sndWnd != seg.window { + return false + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if ack == s.fr.first { + // We received a dup, inflate the congestion window by 1 + // packet if we're not at the max yet. + if s.sndCwnd < s.fr.maxCwnd { + s.sndCwnd++ + } + return false + } + + // A partial ack was received. Retransmit this packet and + // remember it so that we don't retransmit it again. We don't + // inflate the window because we're putting the same packet back + // onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + s.fr.first = ack + return true + } + + // We're not in fast recovery yet. A segment is considered a duplicate + // only if it doesn't carry any data and doesn't update the send window, + // because if it does, it wasn't sent in response to an out-of-order + // segment. + if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt { + s.dupAckCount = 0 + return false + } + + // Enter fast recovery when we reach 3 dups. + s.dupAckCount++ + if s.dupAckCount != 3 { + return false + } + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2 + // + // We only do the check here, the incrementing of last to the highest + // sequence number transmitted till now is done when enterFastRecovery + // is invoked. + if !s.fr.last.LessThan(seg.ackNumber) { + s.dupAckCount = 0 + return false + } + s.enterFastRecovery() + s.dupAckCount = 0 + return true +} + +// updateCwnd updates the congestion window based on the number of packets that +// were acknowledged. +func (s *sender) updateCwnd(packetsAcked int) { + if s.sndCwnd < s.sndSsthresh { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := s.sndCwnd + packetsAcked + if newcwnd >= s.sndSsthresh { + newcwnd = s.sndSsthresh + s.sndCAAckCount = 0 + } + + packetsAcked -= newcwnd - s.sndCwnd + s.sndCwnd = newcwnd + if packetsAcked == 0 { + // We've consumed all ack'd packets. + return + } + } + + // Consume the packets in congestion avoidance mode. + s.sndCAAckCount += packetsAcked + if s.sndCAAckCount >= s.sndCwnd { + s.sndCwnd += s.sndCAAckCount / s.sndCwnd + s.sndCAAckCount = s.sndCAAckCount % s.sndCwnd + } +} + +// handleRcvdSegment is called when a segment is received; it is responsible for +// updating the send-related state. +func (s *sender) handleRcvdSegment(seg *segment) { + // Check if we can extract an RTT measurement from this ack. + if s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + s.updateRTO(time.Now().Sub(s.rttMeasureTime)) + s.rttMeasureSeqNum = s.sndNxt + } + + // Update Timestamp if required. See RFC7323, section-4.3. + s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + + // Count the duplicates and do the fast retransmit if needed. + rtx := s.checkDuplicateAck(seg) + + // Stash away the current window size. + s.sndWnd = seg.window + + // Ignore ack if it doesn't acknowledge any new data. + ack := seg.ackNumber + if (ack - 1).InRange(s.sndUna, s.sndNxt) { + // When an ack is received we must reset the timer. We stop it + // here and it will be restarted later if needed. + s.resendTimer.disable() + + // Remove all acknowledged data from the write list. + acked := s.sndUna.Size(ack) + s.sndUna = ack + + ackLeft := acked + originalOutstanding := s.outstanding + for ackLeft > 0 { + // We use logicalLen here because we can have FIN + // segments (which are always at the end of list) that + // have no data, but do consume a sequence number. + seg := s.writeList.Front() + datalen := seg.logicalLen() + + if datalen > ackLeft { + seg.data.TrimFront(int(ackLeft)) + break + } + + if s.writeNext == seg { + s.writeNext = seg.Next() + } + s.writeList.Remove(seg) + s.outstanding-- + seg.decRef() + ackLeft -= datalen + } + + // Update the send buffer usage and notify potential waiters. + s.ep.updateSndBufferUsage(int(acked)) + + // If we are not in fast recovery then update the congestion + // window based on the number of acknowledged packets. + if !s.fr.active { + s.updateCwnd(originalOutstanding - s.outstanding) + } + + // It is possible for s.outstanding to drop below zero if we get + // a retransmit timeout, reset outstanding to zero but later + // get an ack that cover previously sent data. + if s.outstanding < 0 { + s.outstanding = 0 + } + } + + // Now that we've popped all acknowledged data from the retransmit + // queue, retransmit if needed. + if rtx { + s.resendSegment() + } + + // Send more data now that some of the pending data has been ack'd, or + // that the window opened up, or the congestion window was inflated due + // to a duplicate ack during fast recovery. This will also re-enable + // the retransmit timer if needed. + s.sendData() +} + +// sendSegment sends a new segment containing the given payload, flags and +// sequence number. +func (s *sender) sendSegment(data *buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { + s.lastSendTime = time.Now() + if seq == s.rttMeasureSeqNum { + s.rttMeasureTime = s.lastSendTime + } + + rcvNxt, rcvWnd := s.ep.rcv.getSendParams() + + // Remember the max sent ack. + s.maxSentAck = rcvNxt + + if data == nil { + return s.ep.sendRaw(nil, flags, seq, rcvNxt, rcvWnd) + } + + if len(data.Views()) > 1 { + panic("send path does not support views with multiple buffers") + } + + return s.ep.sendRaw(data.First(), flags, seq, rcvNxt, rcvWnd) +} |