diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2019-05-03 10:49:58 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-05-03 10:51:18 -0700 |
commit | 458fe955a74bca6c33cb321901d771cf146f5cc6 (patch) | |
tree | 97641fe6228cf7444c32d97dbf39e04bca245b5d | |
parent | 95614bbefa2f4657c77b2040630088fdec7f5dd1 (diff) |
Implement support for SACK based recovery(RFC 6675).
PiperOrigin-RevId: 246536003
Change-Id: I118b745f45040be9c70cb6a1028acdb06c78d8c9
-rw-r--r-- | pkg/tcpip/stack/stack.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/sack_scoreboard.go | 73 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/sack_scoreboard_test.go | 105 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 723 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_sack_test.go | 216 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 32 |
10 files changed, 964 insertions, 227 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c82822ee2..9d8e8cda5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -102,6 +102,18 @@ type TCPFastRecoveryState struct { // MaxCwnd is the maximum value we are permitted to grow the congestion // window during recovery. This is set at the time we enter recovery. MaxCwnd int + + // HighRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. + // See: RFC 6675 Section 2 for details. + HighRxt seqnum.Value + + // RescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. + // See: RFC 6675 Section 2 for details. + RescueRxt seqnum.Value } // TCPReceiverState holds a copy of the internal state of the receiver for @@ -1024,7 +1036,7 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra // AddTCPProbe installs a probe function that will be invoked on every segment // received by a given TCP endpoint. The probe function is passed a copy of the -// TCP endpoint state. +// TCP endpoint state before and after processing of the segment. // // NOTE: TCPProbe is added only to endpoints created after this call. Endpoints // created prior to this call will not call the probe function. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index eaa67aeb7..3b927d82e 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -790,7 +790,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // seg.seq = snd.nxt-1. e.keepalive.unacked++ e.keepalive.Unlock() - e.snd.sendSegment(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) e.resetKeepaliveTimer(false) return nil } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 982f491cc..00962a63e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1673,10 +1673,12 @@ func (e *endpoint) completeState() stack.TCPEndpointState { LastSendTime: e.snd.lastSendTime, DupAckCount: e.snd.dupAckCount, FastRecovery: stack.TCPFastRecoveryState{ - Active: e.snd.fr.active, - First: e.snd.fr.first, - Last: e.snd.fr.last, - MaxCwnd: e.snd.fr.maxCwnd, + Active: e.snd.fr.active, + First: e.snd.fr.first, + Last: e.snd.fr.last, + MaxCwnd: e.snd.fr.maxCwnd, + HighRxt: e.snd.fr.highRxt, + RescueRxt: e.snd.fr.rescueRxt, }, SndCwnd: e.snd.sndCwnd, Ssthresh: e.snd.sndSsthresh, diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go index 99560d5b4..1c5766a42 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -38,6 +38,13 @@ const ( // // +stateify savable type SACKScoreboard struct { + // smss is defined in RFC5681 as following: + // + // The SMSS is the size of the largest segment that the sender can + // transmit. This value can be based on the maximum transmission unit + // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm, + // RMSS (see next item), or other factors. The size does not include + // the TCP/IP headers and options. smss uint16 maxSACKED seqnum.Value sacked seqnum.Size `state:"nosave"` @@ -138,6 +145,10 @@ func (s *SACKScoreboard) Insert(r header.SACKBlock) { // IsSACKED returns true if the a given range of sequence numbers denoted by r // are already covered by SACK information in the scoreboard. func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { + if s.Empty() { + return false + } + found := false s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool { sacked := i.(header.SACKBlock) @@ -205,17 +216,46 @@ func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum return sackBlocks, s.maxSACKED } -// IsLost implements the IsLost(SeqNum) operation defined in RFC 3517 section 4. -// -// This routine returns whether the given sequence number is considered to be -// lost. The routine returns true when either nDupAckThreshold discontiguous -// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS) -// bytes with sequence numbers greater than 'SeqNum' have been SACKed. -// Otherwise, the routine returns false. -func (s *SACKScoreboard) IsLost(r header.SACKBlock) bool { +// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675 +// section 4 but operates on a range of sequence numbers and returns true if +// there are at least nDupAckThreshold SACK blocks greater than the range being +// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED +// with sequence numbers greater than the block being checked. +func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool { + if s.Empty() { + return false + } nDupSACK := 0 nDupSACKBytes := seqnum.Size(0) isLost := false + + // We need to check if the immediate lower (if any) sacked + // range contains or partially overlaps with r. + searchMore := true + s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool { + sacked := i.(header.SACKBlock) + if sacked.Contains(r) { + searchMore = false + return false + } + if sacked.End.LessThanEq(r.Start) { + // all sequence numbers covered by sacked are below + // r so we continue searching. + return false + } + // There is a partial overlap. In this case we r.Start is + // between sacked.Start & sacked.End and r.End extends beyond + // sacked.End. + // Move r.Start to sacked.End and continuing searching blocks + // above r.Start. + r.Start = sacked.End + return false + }) + + if !searchMore { + return isLost + } + s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool { sacked := i.(header.SACKBlock) if sacked.Contains(r) { @@ -232,6 +272,18 @@ func (s *SACKScoreboard) IsLost(r header.SACKBlock) bool { return isLost } +// IsLost implements the IsLost(SeqNum) operation defined in RFC3517 section +// 4. +// +// This routine returns whether the given sequence number is considered to be +// lost. The routine returns true when either nDupAckThreshold discontiguous +// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS) +// bytes with sequence numbers greater than 'SeqNum' have been SACKed. +// Otherwise, the routine returns false. +func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool { + return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)}) +} + // Empty returns true if the SACK scoreboard has no entries, false otherwise. func (s *SACKScoreboard) Empty() bool { return s.ranges.Len() == 0 @@ -247,3 +299,8 @@ func (s *SACKScoreboard) Sacked() seqnum.Size { func (s *SACKScoreboard) MaxSACKED() seqnum.Value { return s.maxSACKED } + +// SMSS returns the sender's MSS as held by the SACK scoreboard. +func (s *SACKScoreboard) SMSS() uint16 { + return s.smss +} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go index 8f6890cdf..b59eedc9d 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go @@ -97,31 +97,120 @@ func TestSACKScoreboardIsSACKED(t *testing.T) { } } -func TestSACKScoreboardIsLost(t *testing.T) { +func TestSACKScoreboardIsRangeLost(t *testing.T) { s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 50}) + s.Insert(header.SACKBlock{1, 25}) + s.Insert(header.SACKBlock{25, 50}) s.Insert(header.SACKBlock{51, 100}) s.Insert(header.SACKBlock{111, 120}) s.Insert(header.SACKBlock{101, 110}) s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{145, 146}) + s.Insert(header.SACKBlock{147, 148}) + s.Insert(header.SACKBlock{149, 150}) + s.Insert(header.SACKBlock{153, 154}) + s.Insert(header.SACKBlock{155, 156}) testCases := []struct { block header.SACKBlock lost bool }{ + // Block not covered by SACK block and has more than + // nDupAckThreshold discontiguous SACK blocks after it as well + // as (nDupAckThreshold -1) * 10 (smss) bytes that have been + // SACKED above the sequence number covered by this block. {block: header.SACKBlock{0, 1}, lost: true}, + + // These blocks have all been SACKed and should not be + // considered lost. {block: header.SACKBlock{1, 2}, lost: false}, + {block: header.SACKBlock{25, 26}, lost: false}, {block: header.SACKBlock{1, 45}, lost: false}, + + // Same as the first case above. {block: header.SACKBlock{50, 51}, lost: true}, - // This one should return true because there are - // > (nDupAckThreshold - 1) * 10 (smss) bytes that have been sacked above - // this sequence number. - {block: header.SACKBlock{119, 120}, lost: true}, + + // This block has been SACKed and should not be considered lost. + {block: header.SACKBlock{119, 120}, lost: false}, + + // This one should return true because there are > + // (nDupAckThreshold - 1) * 10 (smss) bytes that have been + // sacked above this sequence number. {block: header.SACKBlock{120, 121}, lost: true}, + + // This block has been SACKed and should not be considered lost. {block: header.SACKBlock{125, 126}, lost: false}, + + // This block has not been SACKed and there are nDupAckThreshold + // number of SACKed blocks after it. + {block: header.SACKBlock{141, 145}, lost: true}, + + // This block has not been SACKed and there are less than + // nDupAckThreshold SACKed sequences after it. + {block: header.SACKBlock{151, 152}, lost: false}, + } + for _, tc := range testCases { + if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { + t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) + } + } +} + +func TestSACKScoreboardIsLost(t *testing.T) { + s := tcp.NewSACKScoreboard(10, 0) + s.Insert(header.SACKBlock{1, 25}) + s.Insert(header.SACKBlock{25, 50}) + s.Insert(header.SACKBlock{51, 100}) + s.Insert(header.SACKBlock{111, 120}) + s.Insert(header.SACKBlock{101, 110}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{145, 146}) + s.Insert(header.SACKBlock{147, 148}) + s.Insert(header.SACKBlock{149, 150}) + s.Insert(header.SACKBlock{153, 154}) + s.Insert(header.SACKBlock{155, 156}) + testCases := []struct { + seq seqnum.Value + lost bool + }{ + // Sequence number not covered by SACK block and has more than + // nDupAckThreshold discontiguous SACK blocks after it as well + // as (nDupAckThreshold -1) * 10 (smss) bytes that have been + // SACKED above the sequence number. + {seq: 0, lost: true}, + + // These sequence numbers have all been SACKed and should not be + // considered lost. + {seq: 1, lost: false}, + {seq: 25, lost: false}, + {seq: 45, lost: false}, + + // Same as first case above. + {seq: 50, lost: true}, + + // This block has been SACKed and should not be considered lost. + {seq: 119, lost: false}, + + // This one should return true because there are > + // (nDupAckThreshold - 1) * 10 (smss) bytes that have been + // sacked above this sequence number. + {seq: 120, lost: true}, + + // This sequence number has been SACKed and should not be + // considered lost. + {seq: 125, lost: false}, + + // This sequence number has not been SACKed and there are + // nDupAckThreshold number of SACKed blocks after it. + {seq: 141, lost: true}, + + // This sequence number has not been SACKed and there are less + // than nDupAckThreshold SACKed sequences after it. + {seq: 151, lost: false}, } for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.block); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.block, got, want) + if want, got := tc.lost, s.IsLost(tc.seq); got != want { + t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) } } } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 187effb6b..450d9fbc1 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -179,3 +179,8 @@ func (s *segment) parse() bool { s.window = seqnum.Size(h.WindowSize()) return true } + +// sackBlock returns a header.SACKBlock that represents this segment. +func (s *segment) sackBlock() header.SACKBlock { + return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 50743670e..afc1d0a55 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -172,6 +172,18 @@ type fastRecovery struct { // receiver intentionally sends duplicate acks to artificially inflate // the sender's cwnd. maxCwnd int + + // highRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. + // See: RFC 6675 Section 2 for details. + highRxt seqnum.Value + + // rescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. + // See: RFC 6675 Section 2 for details. + rescueRxt seqnum.Value } func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { @@ -195,7 +207,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint maxSentAck: irs + 1, fr: fastRecovery{ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. - last: iss, + last: iss, + highRxt: iss, + rescueRxt: iss, }, gso: ep.gso != nil, } @@ -212,12 +226,15 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.sndWndScale = uint8(sndWndScale) } - // Initialize SACK Scoreboard. - s.ep.scoreboard = NewSACKScoreboard(mss, iss) s.resendTimer.init(&s.resendWaker) s.updateMaxPayloadSize(int(ep.route.MTU()), 0) + // Initialize SACK Scoreboard after updating max payload size as we use + // the maxPayloadSize as the smss when determining if a segment is lost + // etc. + s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + return s } @@ -256,6 +273,16 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { s.ep.gso.MSS = uint16(m) } + if count == 0 { + // updateMaxPayloadSize is also called when the sender is created. + // and there is no data to send in such cases. Return immediately. + return + } + + // Update the scoreboard's smss to reflect the new lowered + // maxPayloadSize. + s.ep.scoreboard.smss = uint16(m) + s.outstanding -= count if s.outstanding < 0 { s.outstanding = 0 @@ -285,7 +312,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // sendAck sends an ACK segment. func (s *sender) sendAck() { - s.sendSegment(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) + s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) } // updateRTO updates the retransmit timeout when a new roud-trip time is @@ -350,17 +377,20 @@ func (s *sender) resendSegment() { // Resend the segment. if seg := s.writeList.Front(); seg != nil { if seg.data.Size() > s.maxPayloadSize { - available := s.maxPayloadSize - // Split this segment up. - nSeg := seg.clone() - nSeg.data.TrimFront(available) - nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) - s.writeList.InsertAfter(seg, nSeg) - seg.data.CapLength(available) - } - s.sendSegment(seg.data, seg.flags, seg.sequenceNumber) + s.splitSeg(seg, s.maxPayloadSize) + } + + // See: RFC 6675 section 5 Step 4.3 + // + // To prevent retransmission, set both the HighRXT and RescueRXT + // to the highest sequence number in the retransmitted segment. + s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() - s.ep.stack.Stats().TCP.Retransmits.Increment() + + // Run SetPipe() as per RFC 6675 section 5 Step 4.4 + s.SetPipe() } } @@ -386,6 +416,14 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. + // + // Retransmit timeouts: + // After a retransmit timeout, record the highest sequence number + // transmitted in the variable recover, and exit the fast recovery + // procedure if applicable. + s.fr.last = s.sndNxt - 1 + if s.fr.active { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it @@ -393,11 +431,6 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveFastRecovery() } - // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. - // We store the highest sequence number transmitted in cases where - // we were not in fast recovery. - s.fr.last = s.sndNxt - 1 - s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -439,152 +472,323 @@ func (s *sender) pCount(seg *segment) int { return (size-1)/s.maxPayloadSize + 1 } -// sendData sends new data segments. It is called when data becomes available or -// when the send window opens up. -func (s *sender) sendData() { - limit := s.maxPayloadSize - if s.gso { - limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) - } - // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. - // "A TCP SHOULD set cwnd to no more than RW before beginning - // transmission if the TCP has not sent data in the interval exceeding - // the retrasmission timeout." - if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { - if s.sndCwnd > InitialCwnd { - s.sndCwnd = InitialCwnd - } +// splitSeg splits a given segment at the size specified and inserts the +// remainder as a new segment after the current one in the write list. +func (s *sender) splitSeg(seg *segment, size int) { + if seg.data.Size() <= size { + return } + // Split this segment up. + nSeg := seg.clone() + nSeg.data.TrimFront(size) + nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) + s.writeList.InsertAfter(seg, nSeg) + seg.data.CapLength(size) +} - seg := s.writeNext - end := s.sndUna.Add(s.sndWnd) - var dataSent bool - for ; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize - if cwndLimit < limit { - limit = cwndLimit - } - // We abuse the flags field to determine if we have already - // assigned a sequence number to this segment. - if seg.flags == 0 { - // Merge segments if allowed. - if seg.data.Size() != 0 { - available := int(s.sndNxt.Size(end)) - if available > limit { - available = limit +// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that +// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2 +// is handled by the normal send logic. +func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { + var s3 *segment + var s4 *segment + smss := s.ep.scoreboard.SMSS() + // Step 1. + for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + if !s.isAssignedSequenceNumber(seg) { + break + } + segSeq := seg.sequenceNumber + if seg.data.Size() > int(smss) { + s.splitSeg(seg, int(smss)) + } + // See RFC 6675 Section 4 + // + // 1. If there exists a smallest unSACKED sequence number + // 'S2' that meets the following 3 criteria for determinig + // loss, the sequence range of one segment of up to SMSS + // octects starting with S2 MUST be returned. + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + // NextSeg(): + // + // (1.a) S2 is greater than HighRxt + // (1.b) S2 is less than highest octect covered by + // any received SACK. + if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { + // NextSeg(): + // (1.c) IsLost(S2) returns true. + if s.ep.scoreboard.IsLost(segSeq) { + return seg, s3, s4 } - - // nextTooBig indicates that the next segment was too - // large to entirely fit in the current segment. It would - // be possible to split the next segment and merge the - // portion that fits, but unexpectedly splitting segments - // can have user visible side-effects which can break - // applications. For example, RFC 7766 section 8 says - // that the length and data of a DNS response should be - // sent in the same TCP segment to avoid triggering bugs - // in poorly written DNS implementations. - var nextTooBig bool - - for seg.Next() != nil && seg.Next().data.Size() != 0 { - if seg.data.Size()+seg.Next().data.Size() > available { - nextTooBig = true - break - } - - seg.data.Append(seg.Next().data) - - // Consume the segment that we just merged in. - s.writeList.Remove(seg.Next()) + // NextSeg(): + // + // (3): If the conditions for rules (1) and (2) + // fail, but there exists an unSACKed sequence + // number S3 that meets the criteria for + // detecting loss given in steps 1.a and 1.b + // above (specifically excluding (1.c)) then one + // segment of upto SMSS octets starting with S3 + // SHOULD be returned. + if s3 == nil { + s3 = seg } - - if !nextTooBig && seg.data.Size() < available { - // Segment is not full. - if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { - // Nagle's algorithm. From Wikipedia: - // Nagle's algorithm works by combining a number of - // small outgoing messages and sending them all at - // once. Specifically, as long as there is a sent - // packet for which the sender has received no - // acknowledgment, the sender should keep buffering - // its output until it has a full packet's worth of - // output, thus allowing output to be sent all at - // once. - break - } - if atomic.LoadUint32(&s.ep.cork) != 0 { - // Hold back the segment until full. - break + } + // NextSeg(): + // + // (4) If the conditions for (1), (2) and (3) fail, + // but there exists outstanding unSACKED data, we + // provide the opportunity for a single "rescue" + // retransmission per entry into loss recovery. If + // HighACK is greater than RescueRxt, the one + // segment of upto SMSS octects that MUST include + // the highest outstanding unSACKed sequence number + // SHOULD be returned. + if s.fr.rescueRxt.LessThan(s.sndUna - 1) { + if s4 != nil { + if s4.sequenceNumber.LessThan(segSeq) { + s4 = seg } + } else { + s4 = seg } + s.fr.rescueRxt = s.fr.last } - - // Assign flags. We don't do it above so that we can merge - // additional data if Nagle holds the segment. - seg.sequenceNumber = s.sndNxt - seg.flags = header.TCPFlagAck | header.TCPFlagPsh } + } - var segEnd seqnum.Value - if seg.data.Size() == 0 { - if s.writeList.Back() != seg { - panic("FIN segments must be the final segment in the write list.") - } - seg.flags = header.TCPFlagAck | header.TCPFlagFin - segEnd = seg.sequenceNumber.Add(1) - } else { - // We're sending a non-FIN segment. - if seg.flags&header.TCPFlagFin != 0 { - panic("Netstack queues FIN segments without data.") - } - - if !seg.sequenceNumber.LessThan(end) { - break - } + return nil, s3, s4 +} +// maybeSendSegment tries to send the specified segment and either coalesces +// other segments into this one or splits the specified segment based on the +// lower of the specified limit value or the receivers window size specified by +// end. +func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) { + // We abuse the flags field to determine if we have already + // assigned a sequence number to this segment. + if !s.isAssignedSequenceNumber(seg) { + // Merge segments if allowed. + if seg.data.Size() != 0 { available := int(seg.sequenceNumber.Size(end)) if available > limit { available = limit } - if seg.data.Size() > available { - // Split this segment up. - nSeg := seg.clone() - nSeg.data.TrimFront(available) - nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) - s.writeList.InsertAfter(seg, nSeg) - seg.data.CapLength(available) + // nextTooBig indicates that the next segment was too + // large to entirely fit in the current segment. It + // would be possible to split the next segment and merge + // the portion that fits, but unexpectedly splitting + // segments can have user visible side-effects which can + // break applications. For example, RFC 7766 section 8 + // says that the length and data of a DNS response + // should be sent in the same TCP segment to avoid + // triggering bugs in poorly written DNS + // implementations. + var nextTooBig bool + for seg.Next() != nil && seg.Next().data.Size() != 0 { + if seg.data.Size()+seg.Next().data.Size() > available { + nextTooBig = true + break + } + seg.data.Append(seg.Next().data) + + // Consume the segment that we just merged in. + s.writeList.Remove(seg.Next()) + } + if !nextTooBig && seg.data.Size() < available { + // Segment is not full. + if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { + // Nagle's algorithm. From Wikipedia: + // Nagle's algorithm works by + // combining a number of small + // outgoing messages and sending them + // all at once. Specifically, as long + // as there is a sent packet for which + // the sender has received no + // acknowledgment, the sender should + // keep buffering its output until it + // has a full packet's worth of + // output, thus allowing output to be + // sent all at once. + return false + } + if atomic.LoadUint32(&s.ep.cork) != 0 { + // Hold back the segment until full. + return false + } } + } + + // Assign flags. We don't do it above so that we can merge + // additional data if Nagle holds the segment. + seg.sequenceNumber = s.sndNxt + seg.flags = header.TCPFlagAck | header.TCPFlagPsh + } - s.outstanding += s.pCount(seg) - segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + var segEnd seqnum.Value + if seg.data.Size() == 0 { + if s.writeList.Back() != seg { + panic("FIN segments must be the final segment in the write list.") + } + seg.flags = header.TCPFlagAck | header.TCPFlagFin + segEnd = seg.sequenceNumber.Add(1) + } else { + // We're sending a non-FIN segment. + if seg.flags&header.TCPFlagFin != 0 { + panic("Netstack queues FIN segments without data.") } - if !dataSent { - dataSent = true - // We are sending data, so we should stop the keepalive timer to - // ensure that no keepalives are sent while there is pending data. - s.ep.disableKeepaliveTimer() + if !seg.sequenceNumber.LessThan(end) { + return false + } + + available := int(seg.sequenceNumber.Size(end)) + if available == 0 { + return false } + if available > limit { + available = limit + } + + if seg.data.Size() > available { + s.splitSeg(seg, available) + } + + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + } + + s.sendSegment(seg) - if !seg.xmitTime.IsZero() { - s.ep.stack.Stats().TCP.Retransmits.Increment() - if s.sndCwnd < s.sndSsthresh { - s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() + // Update sndNxt if we actually sent new data (as opposed to + // retransmitting some previously sent data). + if s.sndNxt.LessThan(segEnd) { + s.sndNxt = segEnd + } + + return true +} + +// handleSACKRecovery implements the loss recovery phase as described in RFC6675 +// section 5, step C. +func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { + s.SetPipe() + for s.outstanding < s.sndCwnd { + nextSeg, s3, s4 := s.NextSeg() + if nextSeg == nil { + // NextSeg(): + // + // Step (2): "If no sequence number 'S2' per rule (1) + // exists but there exists available unsent data and the + // receiver's advertised window allows, the sequence + // range of one segment of up to SMSS octets of + // previously unsent data starting with sequence number + // HighData+1 MUST be returned." + for seg := s.writeNext; seg != nil; seg = seg.Next() { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + continue + } + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. + // + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding++ + s.writeNext = seg.Next() + nextSeg = seg + break + } + if nextSeg != nil { + continue } } + rescueRtx := false + if nextSeg == nil && s3 != nil { + nextSeg = s3 + } + if nextSeg == nil && s4 != nil { + nextSeg = s4 + rescueRtx = true + } + if nextSeg == nil { + break + } + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + s.fr.highRxt = segEnd - 1 + } + + // RFC 6675, Step C.4. + // + // "The estimate of the amount of data outstanding in the network + // must be updated by incrementing pipe by the number of octets + // transmitted in (C.1)." + s.outstanding++ + dataSent = true + s.sendSegment(nextSeg) + } + return dataSent +} + +// sendData sends new data segments. It is called when data becomes available or +// when the send window opens up. +func (s *sender) sendData() { + limit := s.maxPayloadSize + if s.gso { + limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) + } + end := s.sndUna.Add(s.sndWnd) + + // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. + // "A TCP SHOULD set cwnd to no more than RW before beginning + // transmission if the TCP has not sent data in the interval exceeding + // the retrasmission timeout." + if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { + if s.sndCwnd > InitialCwnd { + s.sndCwnd = InitialCwnd + } + } - seg.xmitTime = time.Now() - s.sendSegment(seg.data, seg.flags, seg.sequenceNumber) + var dataSent bool - // Update sndNxt if we actually sent new data (as opposed to - // retransmitting some previously sent data). - if s.sndNxt.LessThan(segEnd) { - s.sndNxt = segEnd + // RFC 6675 recovery algorithm step C 1-5. + if s.fr.active && s.ep.sackPermitted { + dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) + } else { + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } + if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + continue + } + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding++ + s.writeNext = seg.Next() } } - // Remember the next segment we'll write. - s.writeNext = seg + if dataSent { + // We sent data, so we should stop the keepalive timer to ensure + // that no keepalives are sent while there is pending data. + s.ep.disableKeepaliveTimer() + } // Enable the timer if we have pending data and it's not enabled yet. if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { @@ -599,91 +803,176 @@ func (s *sender) sendData() { func (s *sender) enterFastRecovery() { s.fr.active = true // Save state to reflect we're now in fast recovery. + // // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. - // We inflat the cwnd by 3 to account for the 3 packets which triggered + // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. s.sndCwnd = s.sndSsthresh + 3 s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding + if s.ep.sackPermitted { + s.ep.stack.Stats().TCP.SACKRecovery.Increment() + return + } s.ep.stack.Stats().TCP.FastRecovery.Increment() } func (s *sender) leaveFastRecovery() { s.fr.active = false - s.fr.first = 0 - s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = 0 s.dupAckCount = 0 // Deflate cwnd. It had been artificially inflated when new dups arrived. s.sndCwnd = s.sndSsthresh - // As recovery is now complete, delete all SACK information for acked - // data. - s.ep.scoreboard.Delete(s.sndUna) s.cc.PostRecovery() } -// checkDuplicateAck is called when an ack is received. It manages the state -// related to duplicate acks and determines if a retransmit is needed according -// to the rules in RFC 6582 (NewReno). -func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { +func (s *sender) handleFastRecovery(seg *segment) (rtx bool) { ack := seg.ackNumber - if s.fr.active { - // We are in fast recovery mode. Ignore the ack if it's out of - // range. - if !ack.InRange(s.sndUna, s.sndNxt+1) { - return false - } + // We are in fast recovery mode. Ignore the ack if it's out of + // range. + if !ack.InRange(s.sndUna, s.sndNxt+1) { + return false + } - // Leave fast recovery if it acknowledges all the data covered by - // this fast recovery session. - if s.fr.last.LessThan(ack) { - s.leaveFastRecovery() - return false - } + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if s.fr.last.LessThan(ack) { + s.leaveFastRecovery() + return false + } - // Don't count this as a duplicate if it is carrying data or - // updating the window. - if seg.logicalLen() != 0 || s.sndWnd != seg.window { - return false + if s.ep.sackPermitted { + // When SACK is enabled we let retransmission be governed by + // the SACK logic. + return false + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if seg.logicalLen() != 0 || s.sndWnd != seg.window { + return false + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if ack == s.fr.first { + // We received a dup, inflate the congestion window by 1 packet + // if we're not at the max yet. Only inflate the window if + // regular FastRecovery is in use, RFC6675 does not require + // inflating cwnd on duplicate ACKs. + if s.sndCwnd < s.fr.maxCwnd { + s.sndCwnd++ } + return false + } + + // A partial ack was received. Retransmit this packet and + // remember it so that we don't retransmit it again. We don't + // inflate the window because we're putting the same packet back + // onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + s.fr.first = ack + s.dupAckCount = 0 + return true +} + +// isAssignedSequenceNumber relies on the fact that we only set flags once a +// sequencenumber is assigned and that is only done right before we send the +// segment. As a result any segment that has a non-zero flag has a valid +// sequence number assigned to it. +func (s *sender) isAssignedSequenceNumber(seg *segment) bool { + return seg.flags != 0 +} - // Inflate the congestion window if we're getting duplicate acks - // for the packet we retransmitted. - if ack == s.fr.first { - // We received a dup, inflate the congestion window by 1 - // packet if we're not at the max yet. - if s.sndCwnd < s.fr.maxCwnd { - s.sndCwnd++ +// SetPipe implements the SetPipe() function described in RFC6675. Netstack +// maintains the congestion window in number of packets and not bytes, so +// SetPipe() here measures number of outstanding packets rather than actual +// outstanding bytes in the network. +func (s *sender) SetPipe() { + // If SACK isn't permitted or it is permitted but recovery is not active + // then ignore pipe calculations. + if !s.ep.sackPermitted || !s.fr.active { + return + } + pipe := 0 + smss := seqnum.Size(s.ep.scoreboard.SMSS()) + for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() { + // With GSO each segment can be much larger than SMSS. So check the segment + // in SMSS sized ranges. + segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size())) + for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) { + endSeq := startSeq.Add(smss) + if segEnd.LessThan(endSeq) { + endSeq = segEnd + } + sb := header.SACKBlock{startSeq, endSeq} + // SetPipe(): + // + // After initializing pipe to zero, the following steps are + // taken for each octet 'S1' in the sequence space between + // HighACK and HighData that has not been SACKed: + if !s1.sequenceNumber.LessThan(s.sndNxt) { + break + } + if s.ep.scoreboard.IsSACKED(sb) { + continue + } + + // SetPipe(): + // + // (a) If IsLost(S1) returns false, Pipe is incremened by 1. + // + // NOTE: here we mark the whole segment as lost. We do not try + // and test every byte in our write buffer as we maintain our + // pipe in terms of oustanding packets and not bytes. + if !s.ep.scoreboard.IsRangeLost(sb) { + pipe++ + } + // SetPipe(): + // (b) If S1 <= HighRxt, Pipe is incremented by 1. + if s1.sequenceNumber.LessThanEq(s.fr.highRxt) { + pipe++ } - return false } + } + s.outstanding = pipe +} - // A partial ack was received. Retransmit this packet and - // remember it so that we don't retransmit it again. We don't - // inflate the window because we're putting the same packet back - // onto the wire. - // - // N.B. The retransmit timer will be reset by the caller. - s.fr.first = ack - s.dupAckCount = 0 - return true +// checkDuplicateAck is called when an ack is received. It manages the state +// related to duplicate acks and determines if a retransmit is needed according +// to the rules in RFC 6582 (NewReno). +func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { + ack := seg.ackNumber + if s.fr.active { + return s.handleFastRecovery(seg) } // We're not in fast recovery yet. A segment is considered a duplicate // only if it doesn't carry any data and doesn't update the send window, // because if it does, it wasn't sent in response to an out-of-order - // segment. + // segment. If SACK is enabled then we have an additional check to see + // if the segment carries new SACK information. If it does then it is + // considered a duplicate ACK as per RFC6675. if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt { - s.dupAckCount = 0 - return false + if !s.ep.sackPermitted || !seg.hasNewSACKInfo { + s.dupAckCount = 0 + return false + } } s.dupAckCount++ - // Do not enter fast recovery until we reach nDupAckThreshold. - if s.dupAckCount < nDupAckThreshold { + + // Do not enter fast recovery until we reach nDupAckThreshold or the + // first unacknowledged byte is considered lost as per SACK scoreboard. + if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) { + // RFC 6675 Step 3. + s.fr.highRxt = s.sndUna - 1 + // Do run SetPipe() to calculate the outstanding segments. + s.SetPipe() return false } @@ -696,7 +985,6 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { s.dupAckCount = 0 return false } - s.cc.HandleNDupAcks() s.enterFastRecovery() s.dupAckCount = 0 @@ -737,6 +1025,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { seg.hasNewSACKInfo = true } } + s.SetPipe() } // Count the duplicates and do the fast retransmit if needed. @@ -749,9 +1038,6 @@ func (s *sender) handleRcvdSegment(seg *segment) { ack := seg.ackNumber if (ack - 1).InRange(s.sndUna, s.sndNxt) { s.dupAckCount = 0 - // When an ack is received we must reset the timer. We stop it - // here and it will be restarted later if needed. - s.resendTimer.disable() // See : https://tools.ietf.org/html/rfc1323#section-3.3. // Specifically we should only update the RTO using TSEcr if the @@ -767,6 +1053,11 @@ func (s *sender) handleRcvdSegment(seg *segment) { elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond s.updateRTO(elapsed) } + + // When an ack is received we must rearm the timer. + // RFC 6298 5.2 + s.resendTimer.enable(s.rto) + // Remove all acknowledged data from the write list. acked := s.sndUna.Size(ack) s.sndUna = ack @@ -792,7 +1083,13 @@ func (s *sender) handleRcvdSegment(seg *segment) { s.writeNext = seg.Next() } s.writeList.Remove(seg) - s.outstanding -= s.pCount(seg) + + // if SACK is enabled then Only reduce outstanding if + // the segment was not previously SACKED as these have + // already been accounted for in SetPipe(). + if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + s.outstanding -= s.pCount(seg) + } seg.decRef() ackLeft -= datalen } @@ -815,8 +1112,16 @@ func (s *sender) handleRcvdSegment(seg *segment) { if s.outstanding < 0 { s.outstanding = 0 } - } + s.SetPipe() + + // If all outstanding data was acknowledged the disable the timer. + // RFC 6298 Rule 5.3 + if s.sndUna == s.sndNxt { + s.outstanding = 0 + s.resendTimer.disable() + } + } // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. if rtx { @@ -827,12 +1132,26 @@ func (s *sender) handleRcvdSegment(seg *segment) { // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - s.sendData() + if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo { + s.sendData() + } } -// sendSegment sends a new segment containing the given payload, flags and -// sequence number. -func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { +// sendSegment sends the specified segment. +func (s *sender) sendSegment(seg *segment) *tcpip.Error { + if !seg.xmitTime.IsZero() { + s.ep.stack.Stats().TCP.Retransmits.Increment() + if s.sndCwnd < s.sndSsthresh { + s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() + } + } + seg.xmitTime = time.Now() + return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) +} + +// sendSegmentFromView sends a new segment containing the given payload, flags +// and sequence number. +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime @@ -843,5 +1162,19 @@ func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum. // Remember the max sent ack. s.maxSentAck = rcvNxt + // Every time a packet containing data is sent (including a + // retransmission), if SACK is enabled then use the conservative timer + // described in RFC6675 Section 4.0, otherwise follow the standard time + // described in RFC6298 Section 5.2. + if data.Size() != 0 { + if s.ep.sackPermitted { + s.resendTimer.enable(s.rto) + } else { + if !s.resendTimer.enabled() { + s.resendTimer.enable(s.rto) + } + } + } + return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index dbfbd5c4f..025d133be 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -16,22 +16,33 @@ package tcp_test import ( "fmt" + "log" "reflect" "testing" + "time" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context" ) -// createConnectWithSACKPermittedOption creates and connects c.ep with the +// createConnectedWithSACKPermittedOption creates and connects c.ep with the // SACKPermitted option enabled if the stack in the context has the SACK support // enabled. func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) } +// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS +// option enabled if the stack in the context has SACK and TS enabled. +func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { + return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) +} + func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { @@ -348,3 +359,206 @@ func TestTrimSackBlockList(t *testing.T) { } } } + +func TestSACKRecovery(t *testing.T) { + const maxPayload = 10 + // See: tcp.makeOptions for why tsOptionSize is set to 12 here. + const tsOptionSize = 12 + // Enabling SACK means the payload size is reduced to account + // for the extra space required for the TCP options. + // + // We increase the MTU by 40 bytes to account for SACK and Timestamp + // options. + const maxTCPOptionSize = 40 + + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + // We use log.Printf instead of t.Logf here because this probe + // can fire even when the test function has finished. This is + // because closing the endpoint in cleanup() does not mean the + // actual worker loop terminates immediately as it still has to + // do a full TCP shutdown. But this test can finish running + // before the shutdown is done. Using t.Logf in such a case + // causes the test to panic due to logging after test finished. + log.Printf("state: %+v\n", s) + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Send 3 duplicate acks. This should force an immediate retransmit of + // the pending packet and put the sender into fast recovery. + rtxOffset := bytesRead - maxPayload*expected + start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) + end := start.Add(10) + for i := 0; i < 3; i++ { + c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) + end = end.Add(10) + } + + // Receive the retransmitted packet. + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) + + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) + } + } + + // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause + // window inflation and sending of packets is completely handled by the + // SACK Recovery algorithm. We should see no packets being released, as + // the cwnd at this point after entering recovery should be half of the + // outstanding number of packets in flight. + for i := 0; i < 7; i++ { + c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) + end = end.Add(10) + } + + recover := bytesRead + + // Ensure no new packets arrive. + c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", + 50*time.Millisecond) + + // Acknowledge half of the pending data. This along with the 10 sacked + // segments above should reduce the outstanding below the current + // congestion window allowing the sender to transmit data. + rtxOffset = bytesRead - expected*maxPayload/2 + + // Now send a partial ACK w/ a SACK block that indicates that the next 3 + // segments are lost and we have received 6 segments after the lost + // segments. This should cause the sender to immediately transmit all 3 + // segments in response to this ACK unlike in FastRecovery where only 1 + // segment is retransmitted per ACK. + start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) + end = start.Add(60) + c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) + + // At this point, we acked expected/2 packets and we SACKED 6 packets and + // 3 segments were considered lost due to the SACK block we sent. + // + // So total packets outstanding can be calculated as follows after 7 + // iterations of slow start -> 10/20/40/80/160/320/640. So expected + // should be 640 at start, then we went to recover at which point the + // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the + // network). + // Outstanding at this point after acking half the window + // (320 packets) will be: + // outstanding = 640-320-6(due to SACK block)-3 = 311 + // + // The last 3 is due to the fact that the first 3 packets after + // rtxOffset will be considered lost due to the SACK blocks sent. + // Receive the retransmit due to partial ack. + + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) + // Receive the 2 extra packets that should have been retransmitted as + // those should be considered lost and immediately retransmitted based + // on the SACK information in the previous ACK sent above. + for i := 0; i < 2; i++ { + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) + } + + // Now we should get 9 more new unsent packets as the cwnd is 323 and + // outstanding is 311. + for i := 0; i < 9; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + // In SACK recovery only the first segment is fast retransmitted when + // entering recovery. + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + } + + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { + t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + } + + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) + + // Acknowledge all pending data to recover point. + c.SendAck(790, recover) + + // At this point, the cwnd should reset to expected/2 and there are 9 + // packets outstanding. + // + // Now in the first iteration since there are 9 packets outstanding. + // We would expect to get expected/2 - 9 packets. But subsequent + // iterations will send us expected/2 + 1 (per iteration). + expected = expected/2 - 9 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // In cogestion avoidance, the packets trains increase by 1 in + // each iteration. + if i == 0 { + // After the first iteration we expect to get the full + // congestion window worth of packets in every + // iteration. + expected += 9 + } + expected++ + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a8b290dae..6e3ba5922 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1792,12 +1792,12 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Receive SYN packet. b := c.GetPacket() - + mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -1812,7 +1812,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCPFlags(header.TCPFlagSyn), checker.SrcPort(tcp.SourcePort()), checker.SeqNum(tcp.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -2737,7 +2737,8 @@ func TestFastRecovery(t *testing.T) { // A partial ACK during recovery should reduce congestion window by the // number acked. Since we had "expected" packets outstanding before sending // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 7. Which means the sender should not send any more packets + // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered + // fast recovery). Which means the sender should not send any more packets // till we ack this one. c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) @@ -2843,7 +2844,7 @@ func TestRetransmit(t *testing.T) { } if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index fa721a7f8..e08eb6533 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -355,13 +355,27 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // SendAck sends an ACK packet. func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { + c.SendAckWithSACK(seq, bytesReceived, nil) +} + +// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. +func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { + options := make([]byte, 40) + offset := 0 + if len(sackBlocks) > 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) + } + c.SendPacket(nil, &Headers{ SrcPort: TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(testInitialSequenceNumber).Add(1), + SeqNum: seq, AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), RcvWnd: 30000, + TCPOpts: options[:offset], }) } @@ -369,9 +383,17 @@ func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { // verifies that the packet packet payload of packet matches the slice // of data indicated by offset & size. func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { + c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) +} + +// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint +// and verifies that the packet packet payload of packet matches the slice of +// data indicated by offset & size and skips optlen bytes in addition to the IP +// TCP headers when comparing the data. +func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { b := c.GetPacket() checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), + checker.PayloadLen(size+header.TCPMinimumSize+optlen), checker.TCP( checker.DstPort(TestPort), checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), @@ -381,7 +403,7 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { ) pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) } } @@ -683,12 +705,14 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * b := c.GetPacket() // Validate that the syn has the timestamp option and a valid // TS value. + mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) + checker.IPv4(c.t, b, checker.TCP( checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagSyn), checker.TCPSynOptions(header.TCPSynOptions{ - MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize), + MSS: mss, TS: true, WS: defaultWindowScale, SACKPermitted: c.SACKEnabled(), |