summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2019-05-03 10:49:58 -0700
committerShentubot <shentubot@google.com>2019-05-03 10:51:18 -0700
commit458fe955a74bca6c33cb321901d771cf146f5cc6 (patch)
tree97641fe6228cf7444c32d97dbf39e04bca245b5d /pkg/tcpip/transport
parent95614bbefa2f4657c77b2040630088fdec7f5dd1 (diff)
Implement support for SACK based recovery(RFC 6675).
PiperOrigin-RevId: 246536003 Change-Id: I118b745f45040be9c70cb6a1028acdb06c78d8c9
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/tcp/connect.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go73
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard_test.go105
-rw-r--r--pkg/tcpip/transport/tcp/segment.go5
-rw-r--r--pkg/tcpip/transport/tcp/snd.go723
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go216
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go11
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go32
9 files changed, 951 insertions, 226 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index eaa67aeb7..3b927d82e 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -790,7 +790,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
// seg.seq = snd.nxt-1.
e.keepalive.unacked++
e.keepalive.Unlock()
- e.snd.sendSegment(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
e.resetKeepaliveTimer(false)
return nil
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 982f491cc..00962a63e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1673,10 +1673,12 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
LastSendTime: e.snd.lastSendTime,
DupAckCount: e.snd.dupAckCount,
FastRecovery: stack.TCPFastRecoveryState{
- Active: e.snd.fr.active,
- First: e.snd.fr.first,
- Last: e.snd.fr.last,
- MaxCwnd: e.snd.fr.maxCwnd,
+ Active: e.snd.fr.active,
+ First: e.snd.fr.first,
+ Last: e.snd.fr.last,
+ MaxCwnd: e.snd.fr.maxCwnd,
+ HighRxt: e.snd.fr.highRxt,
+ RescueRxt: e.snd.fr.rescueRxt,
},
SndCwnd: e.snd.sndCwnd,
Ssthresh: e.snd.sndSsthresh,
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
index 99560d5b4..1c5766a42 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -38,6 +38,13 @@ const (
//
// +stateify savable
type SACKScoreboard struct {
+ // smss is defined in RFC5681 as following:
+ //
+ // The SMSS is the size of the largest segment that the sender can
+ // transmit. This value can be based on the maximum transmission unit
+ // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm,
+ // RMSS (see next item), or other factors. The size does not include
+ // the TCP/IP headers and options.
smss uint16
maxSACKED seqnum.Value
sacked seqnum.Size `state:"nosave"`
@@ -138,6 +145,10 @@ func (s *SACKScoreboard) Insert(r header.SACKBlock) {
// IsSACKED returns true if the a given range of sequence numbers denoted by r
// are already covered by SACK information in the scoreboard.
func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+
found := false
s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
sacked := i.(header.SACKBlock)
@@ -205,17 +216,46 @@ func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum
return sackBlocks, s.maxSACKED
}
-// IsLost implements the IsLost(SeqNum) operation defined in RFC 3517 section 4.
-//
-// This routine returns whether the given sequence number is considered to be
-// lost. The routine returns true when either nDupAckThreshold discontiguous
-// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS)
-// bytes with sequence numbers greater than 'SeqNum' have been SACKed.
-// Otherwise, the routine returns false.
-func (s *SACKScoreboard) IsLost(r header.SACKBlock) bool {
+// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675
+// section 4 but operates on a range of sequence numbers and returns true if
+// there are at least nDupAckThreshold SACK blocks greater than the range being
+// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED
+// with sequence numbers greater than the block being checked.
+func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
nDupSACK := 0
nDupSACKBytes := seqnum.Size(0)
isLost := false
+
+ // We need to check if the immediate lower (if any) sacked
+ // range contains or partially overlaps with r.
+ searchMore := true
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ searchMore = false
+ return false
+ }
+ if sacked.End.LessThanEq(r.Start) {
+ // all sequence numbers covered by sacked are below
+ // r so we continue searching.
+ return false
+ }
+ // There is a partial overlap. In this case we r.Start is
+ // between sacked.Start & sacked.End and r.End extends beyond
+ // sacked.End.
+ // Move r.Start to sacked.End and continuing searching blocks
+ // above r.Start.
+ r.Start = sacked.End
+ return false
+ })
+
+ if !searchMore {
+ return isLost
+ }
+
s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
sacked := i.(header.SACKBlock)
if sacked.Contains(r) {
@@ -232,6 +272,18 @@ func (s *SACKScoreboard) IsLost(r header.SACKBlock) bool {
return isLost
}
+// IsLost implements the IsLost(SeqNum) operation defined in RFC3517 section
+// 4.
+//
+// This routine returns whether the given sequence number is considered to be
+// lost. The routine returns true when either nDupAckThreshold discontiguous
+// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS)
+// bytes with sequence numbers greater than 'SeqNum' have been SACKed.
+// Otherwise, the routine returns false.
+func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool {
+ return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)})
+}
+
// Empty returns true if the SACK scoreboard has no entries, false otherwise.
func (s *SACKScoreboard) Empty() bool {
return s.ranges.Len() == 0
@@ -247,3 +299,8 @@ func (s *SACKScoreboard) Sacked() seqnum.Size {
func (s *SACKScoreboard) MaxSACKED() seqnum.Value {
return s.maxSACKED
}
+
+// SMSS returns the sender's MSS as held by the SACK scoreboard.
+func (s *SACKScoreboard) SMSS() uint16 {
+ return s.smss
+}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
index 8f6890cdf..b59eedc9d 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
@@ -97,31 +97,120 @@ func TestSACKScoreboardIsSACKED(t *testing.T) {
}
}
-func TestSACKScoreboardIsLost(t *testing.T) {
+func TestSACKScoreboardIsRangeLost(t *testing.T) {
s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 50})
+ s.Insert(header.SACKBlock{1, 25})
+ s.Insert(header.SACKBlock{25, 50})
s.Insert(header.SACKBlock{51, 100})
s.Insert(header.SACKBlock{111, 120})
s.Insert(header.SACKBlock{101, 110})
s.Insert(header.SACKBlock{121, 141})
+ s.Insert(header.SACKBlock{145, 146})
+ s.Insert(header.SACKBlock{147, 148})
+ s.Insert(header.SACKBlock{149, 150})
+ s.Insert(header.SACKBlock{153, 154})
+ s.Insert(header.SACKBlock{155, 156})
testCases := []struct {
block header.SACKBlock
lost bool
}{
+ // Block not covered by SACK block and has more than
+ // nDupAckThreshold discontiguous SACK blocks after it as well
+ // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
+ // SACKED above the sequence number covered by this block.
{block: header.SACKBlock{0, 1}, lost: true},
+
+ // These blocks have all been SACKed and should not be
+ // considered lost.
{block: header.SACKBlock{1, 2}, lost: false},
+ {block: header.SACKBlock{25, 26}, lost: false},
{block: header.SACKBlock{1, 45}, lost: false},
+
+ // Same as the first case above.
{block: header.SACKBlock{50, 51}, lost: true},
- // This one should return true because there are
- // > (nDupAckThreshold - 1) * 10 (smss) bytes that have been sacked above
- // this sequence number.
- {block: header.SACKBlock{119, 120}, lost: true},
+
+ // This block has been SACKed and should not be considered lost.
+ {block: header.SACKBlock{119, 120}, lost: false},
+
+ // This one should return true because there are >
+ // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
+ // sacked above this sequence number.
{block: header.SACKBlock{120, 121}, lost: true},
+
+ // This block has been SACKed and should not be considered lost.
{block: header.SACKBlock{125, 126}, lost: false},
+
+ // This block has not been SACKed and there are nDupAckThreshold
+ // number of SACKed blocks after it.
+ {block: header.SACKBlock{141, 145}, lost: true},
+
+ // This block has not been SACKed and there are less than
+ // nDupAckThreshold SACKed sequences after it.
+ {block: header.SACKBlock{151, 152}, lost: false},
+ }
+ for _, tc := range testCases {
+ if want, got := tc.lost, s.IsRangeLost(tc.block); got != want {
+ t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want)
+ }
+ }
+}
+
+func TestSACKScoreboardIsLost(t *testing.T) {
+ s := tcp.NewSACKScoreboard(10, 0)
+ s.Insert(header.SACKBlock{1, 25})
+ s.Insert(header.SACKBlock{25, 50})
+ s.Insert(header.SACKBlock{51, 100})
+ s.Insert(header.SACKBlock{111, 120})
+ s.Insert(header.SACKBlock{101, 110})
+ s.Insert(header.SACKBlock{121, 141})
+ s.Insert(header.SACKBlock{121, 141})
+ s.Insert(header.SACKBlock{145, 146})
+ s.Insert(header.SACKBlock{147, 148})
+ s.Insert(header.SACKBlock{149, 150})
+ s.Insert(header.SACKBlock{153, 154})
+ s.Insert(header.SACKBlock{155, 156})
+ testCases := []struct {
+ seq seqnum.Value
+ lost bool
+ }{
+ // Sequence number not covered by SACK block and has more than
+ // nDupAckThreshold discontiguous SACK blocks after it as well
+ // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
+ // SACKED above the sequence number.
+ {seq: 0, lost: true},
+
+ // These sequence numbers have all been SACKed and should not be
+ // considered lost.
+ {seq: 1, lost: false},
+ {seq: 25, lost: false},
+ {seq: 45, lost: false},
+
+ // Same as first case above.
+ {seq: 50, lost: true},
+
+ // This block has been SACKed and should not be considered lost.
+ {seq: 119, lost: false},
+
+ // This one should return true because there are >
+ // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
+ // sacked above this sequence number.
+ {seq: 120, lost: true},
+
+ // This sequence number has been SACKed and should not be
+ // considered lost.
+ {seq: 125, lost: false},
+
+ // This sequence number has not been SACKed and there are
+ // nDupAckThreshold number of SACKed blocks after it.
+ {seq: 141, lost: true},
+
+ // This sequence number has not been SACKed and there are less
+ // than nDupAckThreshold SACKed sequences after it.
+ {seq: 151, lost: false},
}
for _, tc := range testCases {
- if want, got := tc.lost, s.IsLost(tc.block); got != want {
- t.Errorf("s.IsLost(%v) = %v, want %v", tc.block, got, want)
+ if want, got := tc.lost, s.IsLost(tc.seq); got != want {
+ t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want)
}
}
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 187effb6b..450d9fbc1 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -179,3 +179,8 @@ func (s *segment) parse() bool {
s.window = seqnum.Size(h.WindowSize())
return true
}
+
+// sackBlock returns a header.SACKBlock that represents this segment.
+func (s *segment) sackBlock() header.SACKBlock {
+ return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 50743670e..afc1d0a55 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -172,6 +172,18 @@ type fastRecovery struct {
// receiver intentionally sends duplicate acks to artificially inflate
// the sender's cwnd.
maxCwnd int
+
+ // highRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ highRxt seqnum.Value
+
+ // rescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ rescueRxt seqnum.Value
}
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
@@ -195,7 +207,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
maxSentAck: irs + 1,
fr: fastRecovery{
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
- last: iss,
+ last: iss,
+ highRxt: iss,
+ rescueRxt: iss,
},
gso: ep.gso != nil,
}
@@ -212,12 +226,15 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.sndWndScale = uint8(sndWndScale)
}
- // Initialize SACK Scoreboard.
- s.ep.scoreboard = NewSACKScoreboard(mss, iss)
s.resendTimer.init(&s.resendWaker)
s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+ // Initialize SACK Scoreboard after updating max payload size as we use
+ // the maxPayloadSize as the smss when determining if a segment is lost
+ // etc.
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+
return s
}
@@ -256,6 +273,16 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
s.ep.gso.MSS = uint16(m)
}
+ if count == 0 {
+ // updateMaxPayloadSize is also called when the sender is created.
+ // and there is no data to send in such cases. Return immediately.
+ return
+ }
+
+ // Update the scoreboard's smss to reflect the new lowered
+ // maxPayloadSize.
+ s.ep.scoreboard.smss = uint16(m)
+
s.outstanding -= count
if s.outstanding < 0 {
s.outstanding = 0
@@ -285,7 +312,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// sendAck sends an ACK segment.
func (s *sender) sendAck() {
- s.sendSegment(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
}
// updateRTO updates the retransmit timeout when a new roud-trip time is
@@ -350,17 +377,20 @@ func (s *sender) resendSegment() {
// Resend the segment.
if seg := s.writeList.Front(); seg != nil {
if seg.data.Size() > s.maxPayloadSize {
- available := s.maxPayloadSize
- // Split this segment up.
- nSeg := seg.clone()
- nSeg.data.TrimFront(available)
- nSeg.sequenceNumber.UpdateForward(seqnum.Size(available))
- s.writeList.InsertAfter(seg, nSeg)
- seg.data.CapLength(available)
- }
- s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
+ s.splitSeg(seg, s.maxPayloadSize)
+ }
+
+ // See: RFC 6675 section 5 Step 4.3
+ //
+ // To prevent retransmission, set both the HighRXT and RescueRXT
+ // to the highest sequence number in the retransmitted segment.
+ s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.sendSegment(seg)
s.ep.stack.Stats().TCP.FastRetransmit.Increment()
- s.ep.stack.Stats().TCP.Retransmits.Increment()
+
+ // Run SetPipe() as per RFC 6675 section 5 Step 4.4
+ s.SetPipe()
}
}
@@ -386,6 +416,14 @@ func (s *sender) retransmitTimerExpired() bool {
// below.
s.rto *= 2
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+ //
+ // Retransmit timeouts:
+ // After a retransmit timeout, record the highest sequence number
+ // transmitted in the variable recover, and exit the fast recovery
+ // procedure if applicable.
+ s.fr.last = s.sndNxt - 1
+
if s.fr.active {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
@@ -393,11 +431,6 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveFastRecovery()
}
- // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
- // We store the highest sequence number transmitted in cases where
- // we were not in fast recovery.
- s.fr.last = s.sndNxt - 1
-
s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
@@ -439,152 +472,323 @@ func (s *sender) pCount(seg *segment) int {
return (size-1)/s.maxPayloadSize + 1
}
-// sendData sends new data segments. It is called when data becomes available or
-// when the send window opens up.
-func (s *sender) sendData() {
- limit := s.maxPayloadSize
- if s.gso {
- limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
- }
- // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
- // "A TCP SHOULD set cwnd to no more than RW before beginning
- // transmission if the TCP has not sent data in the interval exceeding
- // the retrasmission timeout."
- if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
- if s.sndCwnd > InitialCwnd {
- s.sndCwnd = InitialCwnd
- }
+// splitSeg splits a given segment at the size specified and inserts the
+// remainder as a new segment after the current one in the write list.
+func (s *sender) splitSeg(seg *segment, size int) {
+ if seg.data.Size() <= size {
+ return
}
+ // Split this segment up.
+ nSeg := seg.clone()
+ nSeg.data.TrimFront(size)
+ nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
+ s.writeList.InsertAfter(seg, nSeg)
+ seg.data.CapLength(size)
+}
- seg := s.writeNext
- end := s.sndUna.Add(s.sndWnd)
- var dataSent bool
- for ; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
- cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
- if cwndLimit < limit {
- limit = cwndLimit
- }
- // We abuse the flags field to determine if we have already
- // assigned a sequence number to this segment.
- if seg.flags == 0 {
- // Merge segments if allowed.
- if seg.data.Size() != 0 {
- available := int(s.sndNxt.Size(end))
- if available > limit {
- available = limit
+// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that
+// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2
+// is handled by the normal send logic.
+func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
+ var s3 *segment
+ var s4 *segment
+ smss := s.ep.scoreboard.SMSS()
+ // Step 1.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if !s.isAssignedSequenceNumber(seg) {
+ break
+ }
+ segSeq := seg.sequenceNumber
+ if seg.data.Size() > int(smss) {
+ s.splitSeg(seg, int(smss))
+ }
+ // See RFC 6675 Section 4
+ //
+ // 1. If there exists a smallest unSACKED sequence number
+ // 'S2' that meets the following 3 criteria for determinig
+ // loss, the sequence range of one segment of up to SMSS
+ // octects starting with S2 MUST be returned.
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ // NextSeg():
+ //
+ // (1.a) S2 is greater than HighRxt
+ // (1.b) S2 is less than highest octect covered by
+ // any received SACK.
+ if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ // NextSeg():
+ // (1.c) IsLost(S2) returns true.
+ if s.ep.scoreboard.IsLost(segSeq) {
+ return seg, s3, s4
}
-
- // nextTooBig indicates that the next segment was too
- // large to entirely fit in the current segment. It would
- // be possible to split the next segment and merge the
- // portion that fits, but unexpectedly splitting segments
- // can have user visible side-effects which can break
- // applications. For example, RFC 7766 section 8 says
- // that the length and data of a DNS response should be
- // sent in the same TCP segment to avoid triggering bugs
- // in poorly written DNS implementations.
- var nextTooBig bool
-
- for seg.Next() != nil && seg.Next().data.Size() != 0 {
- if seg.data.Size()+seg.Next().data.Size() > available {
- nextTooBig = true
- break
- }
-
- seg.data.Append(seg.Next().data)
-
- // Consume the segment that we just merged in.
- s.writeList.Remove(seg.Next())
+ // NextSeg():
+ //
+ // (3): If the conditions for rules (1) and (2)
+ // fail, but there exists an unSACKed sequence
+ // number S3 that meets the criteria for
+ // detecting loss given in steps 1.a and 1.b
+ // above (specifically excluding (1.c)) then one
+ // segment of upto SMSS octets starting with S3
+ // SHOULD be returned.
+ if s3 == nil {
+ s3 = seg
}
-
- if !nextTooBig && seg.data.Size() < available {
- // Segment is not full.
- if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
- // Nagle's algorithm. From Wikipedia:
- // Nagle's algorithm works by combining a number of
- // small outgoing messages and sending them all at
- // once. Specifically, as long as there is a sent
- // packet for which the sender has received no
- // acknowledgment, the sender should keep buffering
- // its output until it has a full packet's worth of
- // output, thus allowing output to be sent all at
- // once.
- break
- }
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
- break
+ }
+ // NextSeg():
+ //
+ // (4) If the conditions for (1), (2) and (3) fail,
+ // but there exists outstanding unSACKED data, we
+ // provide the opportunity for a single "rescue"
+ // retransmission per entry into loss recovery. If
+ // HighACK is greater than RescueRxt, the one
+ // segment of upto SMSS octects that MUST include
+ // the highest outstanding unSACKed sequence number
+ // SHOULD be returned.
+ if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s4 != nil {
+ if s4.sequenceNumber.LessThan(segSeq) {
+ s4 = seg
}
+ } else {
+ s4 = seg
}
+ s.fr.rescueRxt = s.fr.last
}
-
- // Assign flags. We don't do it above so that we can merge
- // additional data if Nagle holds the segment.
- seg.sequenceNumber = s.sndNxt
- seg.flags = header.TCPFlagAck | header.TCPFlagPsh
}
+ }
- var segEnd seqnum.Value
- if seg.data.Size() == 0 {
- if s.writeList.Back() != seg {
- panic("FIN segments must be the final segment in the write list.")
- }
- seg.flags = header.TCPFlagAck | header.TCPFlagFin
- segEnd = seg.sequenceNumber.Add(1)
- } else {
- // We're sending a non-FIN segment.
- if seg.flags&header.TCPFlagFin != 0 {
- panic("Netstack queues FIN segments without data.")
- }
-
- if !seg.sequenceNumber.LessThan(end) {
- break
- }
+ return nil, s3, s4
+}
+// maybeSendSegment tries to send the specified segment and either coalesces
+// other segments into this one or splits the specified segment based on the
+// lower of the specified limit value or the receivers window size specified by
+// end.
+func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) {
+ // We abuse the flags field to determine if we have already
+ // assigned a sequence number to this segment.
+ if !s.isAssignedSequenceNumber(seg) {
+ // Merge segments if allowed.
+ if seg.data.Size() != 0 {
available := int(seg.sequenceNumber.Size(end))
if available > limit {
available = limit
}
- if seg.data.Size() > available {
- // Split this segment up.
- nSeg := seg.clone()
- nSeg.data.TrimFront(available)
- nSeg.sequenceNumber.UpdateForward(seqnum.Size(available))
- s.writeList.InsertAfter(seg, nSeg)
- seg.data.CapLength(available)
+ // nextTooBig indicates that the next segment was too
+ // large to entirely fit in the current segment. It
+ // would be possible to split the next segment and merge
+ // the portion that fits, but unexpectedly splitting
+ // segments can have user visible side-effects which can
+ // break applications. For example, RFC 7766 section 8
+ // says that the length and data of a DNS response
+ // should be sent in the same TCP segment to avoid
+ // triggering bugs in poorly written DNS
+ // implementations.
+ var nextTooBig bool
+ for seg.Next() != nil && seg.Next().data.Size() != 0 {
+ if seg.data.Size()+seg.Next().data.Size() > available {
+ nextTooBig = true
+ break
+ }
+ seg.data.Append(seg.Next().data)
+
+ // Consume the segment that we just merged in.
+ s.writeList.Remove(seg.Next())
+ }
+ if !nextTooBig && seg.data.Size() < available {
+ // Segment is not full.
+ if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ // Nagle's algorithm. From Wikipedia:
+ // Nagle's algorithm works by
+ // combining a number of small
+ // outgoing messages and sending them
+ // all at once. Specifically, as long
+ // as there is a sent packet for which
+ // the sender has received no
+ // acknowledgment, the sender should
+ // keep buffering its output until it
+ // has a full packet's worth of
+ // output, thus allowing output to be
+ // sent all at once.
+ return false
+ }
+ if atomic.LoadUint32(&s.ep.cork) != 0 {
+ // Hold back the segment until full.
+ return false
+ }
}
+ }
+
+ // Assign flags. We don't do it above so that we can merge
+ // additional data if Nagle holds the segment.
+ seg.sequenceNumber = s.sndNxt
+ seg.flags = header.TCPFlagAck | header.TCPFlagPsh
+ }
- s.outstanding += s.pCount(seg)
- segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ var segEnd seqnum.Value
+ if seg.data.Size() == 0 {
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
+ }
+ seg.flags = header.TCPFlagAck | header.TCPFlagFin
+ segEnd = seg.sequenceNumber.Add(1)
+ } else {
+ // We're sending a non-FIN segment.
+ if seg.flags&header.TCPFlagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
}
- if !dataSent {
- dataSent = true
- // We are sending data, so we should stop the keepalive timer to
- // ensure that no keepalives are sent while there is pending data.
- s.ep.disableKeepaliveTimer()
+ if !seg.sequenceNumber.LessThan(end) {
+ return false
+ }
+
+ available := int(seg.sequenceNumber.Size(end))
+ if available == 0 {
+ return false
}
+ if available > limit {
+ available = limit
+ }
+
+ if seg.data.Size() > available {
+ s.splitSeg(seg, available)
+ }
+
+ segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ }
+
+ s.sendSegment(seg)
- if !seg.xmitTime.IsZero() {
- s.ep.stack.Stats().TCP.Retransmits.Increment()
- if s.sndCwnd < s.sndSsthresh {
- s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+ // Update sndNxt if we actually sent new data (as opposed to
+ // retransmitting some previously sent data).
+ if s.sndNxt.LessThan(segEnd) {
+ s.sndNxt = segEnd
+ }
+
+ return true
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ s.SetPipe()
+ for s.outstanding < s.sndCwnd {
+ nextSeg, s3, s4 := s.NextSeg()
+ if nextSeg == nil {
+ // NextSeg():
+ //
+ // Step (2): "If no sequence number 'S2' per rule (1)
+ // exists but there exists available unsent data and the
+ // receiver's advertised window allows, the sequence
+ // range of one segment of up to SMSS octets of
+ // previously unsent data starting with sequence number
+ // HighData+1 MUST be returned."
+ for seg := s.writeNext; seg != nil; seg = seg.Next() {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ continue
+ }
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ nextSeg = seg
+ break
+ }
+ if nextSeg != nil {
+ continue
}
}
+ rescueRtx := false
+ if nextSeg == nil && s3 != nil {
+ nextSeg = s3
+ }
+ if nextSeg == nil && s4 != nil {
+ nextSeg = s4
+ rescueRtx = true
+ }
+ if nextSeg == nil {
+ break
+ }
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ s.fr.highRxt = segEnd - 1
+ }
+
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ s.outstanding++
+ dataSent = true
+ s.sendSegment(nextSeg)
+ }
+ return dataSent
+}
+
+// sendData sends new data segments. It is called when data becomes available or
+// when the send window opens up.
+func (s *sender) sendData() {
+ limit := s.maxPayloadSize
+ if s.gso {
+ limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
+ }
+ end := s.sndUna.Add(s.sndWnd)
+
+ // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
+ // "A TCP SHOULD set cwnd to no more than RW before beginning
+ // transmission if the TCP has not sent data in the interval exceeding
+ // the retrasmission timeout."
+ if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
+ if s.sndCwnd > InitialCwnd {
+ s.sndCwnd = InitialCwnd
+ }
+ }
- seg.xmitTime = time.Now()
- s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
+ var dataSent bool
- // Update sndNxt if we actually sent new data (as opposed to
- // retransmitting some previously sent data).
- if s.sndNxt.LessThan(segEnd) {
- s.sndNxt = segEnd
+ // RFC 6675 recovery algorithm step C 1-5.
+ if s.fr.active && s.ep.sackPermitted {
+ dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
+ } else {
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ continue
+ }
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
}
}
- // Remember the next segment we'll write.
- s.writeNext = seg
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
// Enable the timer if we have pending data and it's not enabled yet.
if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
@@ -599,91 +803,176 @@ func (s *sender) sendData() {
func (s *sender) enterFastRecovery() {
s.fr.active = true
// Save state to reflect we're now in fast recovery.
+ //
// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
- // We inflat the cwnd by 3 to account for the 3 packets which triggered
+ // We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
+ if s.ep.sackPermitted {
+ s.ep.stack.Stats().TCP.SACKRecovery.Increment()
+ return
+ }
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
func (s *sender) leaveFastRecovery() {
s.fr.active = false
- s.fr.first = 0
- s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = 0
s.dupAckCount = 0
// Deflate cwnd. It had been artificially inflated when new dups arrived.
s.sndCwnd = s.sndSsthresh
- // As recovery is now complete, delete all SACK information for acked
- // data.
- s.ep.scoreboard.Delete(s.sndUna)
s.cc.PostRecovery()
}
-// checkDuplicateAck is called when an ack is received. It manages the state
-// related to duplicate acks and determines if a retransmit is needed according
-// to the rules in RFC 6582 (NewReno).
-func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
ack := seg.ackNumber
- if s.fr.active {
- // We are in fast recovery mode. Ignore the ack if it's out of
- // range.
- if !ack.InRange(s.sndUna, s.sndNxt+1) {
- return false
- }
+ // We are in fast recovery mode. Ignore the ack if it's out of
+ // range.
+ if !ack.InRange(s.sndUna, s.sndNxt+1) {
+ return false
+ }
- // Leave fast recovery if it acknowledges all the data covered by
- // this fast recovery session.
- if s.fr.last.LessThan(ack) {
- s.leaveFastRecovery()
- return false
- }
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if s.fr.last.LessThan(ack) {
+ s.leaveFastRecovery()
+ return false
+ }
- // Don't count this as a duplicate if it is carrying data or
- // updating the window.
- if seg.logicalLen() != 0 || s.sndWnd != seg.window {
- return false
+ if s.ep.sackPermitted {
+ // When SACK is enabled we let retransmission be governed by
+ // the SACK logic.
+ return false
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+ return false
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if ack == s.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if s.sndCwnd < s.fr.maxCwnd {
+ s.sndCwnd++
}
+ return false
+ }
+
+ // A partial ack was received. Retransmit this packet and
+ // remember it so that we don't retransmit it again. We don't
+ // inflate the window because we're putting the same packet back
+ // onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ s.fr.first = ack
+ s.dupAckCount = 0
+ return true
+}
+
+// isAssignedSequenceNumber relies on the fact that we only set flags once a
+// sequencenumber is assigned and that is only done right before we send the
+// segment. As a result any segment that has a non-zero flag has a valid
+// sequence number assigned to it.
+func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
+ return seg.flags != 0
+}
- // Inflate the congestion window if we're getting duplicate acks
- // for the packet we retransmitted.
- if ack == s.fr.first {
- // We received a dup, inflate the congestion window by 1
- // packet if we're not at the max yet.
- if s.sndCwnd < s.fr.maxCwnd {
- s.sndCwnd++
+// SetPipe implements the SetPipe() function described in RFC6675. Netstack
+// maintains the congestion window in number of packets and not bytes, so
+// SetPipe() here measures number of outstanding packets rather than actual
+// outstanding bytes in the network.
+func (s *sender) SetPipe() {
+ // If SACK isn't permitted or it is permitted but recovery is not active
+ // then ignore pipe calculations.
+ if !s.ep.sackPermitted || !s.fr.active {
+ return
+ }
+ pipe := 0
+ smss := seqnum.Size(s.ep.scoreboard.SMSS())
+ for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() {
+ // With GSO each segment can be much larger than SMSS. So check the segment
+ // in SMSS sized ranges.
+ segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size()))
+ for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) {
+ endSeq := startSeq.Add(smss)
+ if segEnd.LessThan(endSeq) {
+ endSeq = segEnd
+ }
+ sb := header.SACKBlock{startSeq, endSeq}
+ // SetPipe():
+ //
+ // After initializing pipe to zero, the following steps are
+ // taken for each octet 'S1' in the sequence space between
+ // HighACK and HighData that has not been SACKed:
+ if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ break
+ }
+ if s.ep.scoreboard.IsSACKED(sb) {
+ continue
+ }
+
+ // SetPipe():
+ //
+ // (a) If IsLost(S1) returns false, Pipe is incremened by 1.
+ //
+ // NOTE: here we mark the whole segment as lost. We do not try
+ // and test every byte in our write buffer as we maintain our
+ // pipe in terms of oustanding packets and not bytes.
+ if !s.ep.scoreboard.IsRangeLost(sb) {
+ pipe++
+ }
+ // SetPipe():
+ // (b) If S1 <= HighRxt, Pipe is incremented by 1.
+ if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ pipe++
}
- return false
}
+ }
+ s.outstanding = pipe
+}
- // A partial ack was received. Retransmit this packet and
- // remember it so that we don't retransmit it again. We don't
- // inflate the window because we're putting the same packet back
- // onto the wire.
- //
- // N.B. The retransmit timer will be reset by the caller.
- s.fr.first = ack
- s.dupAckCount = 0
- return true
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno).
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ if s.fr.active {
+ return s.handleFastRecovery(seg)
}
// We're not in fast recovery yet. A segment is considered a duplicate
// only if it doesn't carry any data and doesn't update the send window,
// because if it does, it wasn't sent in response to an out-of-order
- // segment.
+ // segment. If SACK is enabled then we have an additional check to see
+ // if the segment carries new SACK information. If it does then it is
+ // considered a duplicate ACK as per RFC6675.
if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
- s.dupAckCount = 0
- return false
+ if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
+ s.dupAckCount = 0
+ return false
+ }
}
s.dupAckCount++
- // Do not enter fast recovery until we reach nDupAckThreshold.
- if s.dupAckCount < nDupAckThreshold {
+
+ // Do not enter fast recovery until we reach nDupAckThreshold or the
+ // first unacknowledged byte is considered lost as per SACK scoreboard.
+ if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+ // RFC 6675 Step 3.
+ s.fr.highRxt = s.sndUna - 1
+ // Do run SetPipe() to calculate the outstanding segments.
+ s.SetPipe()
return false
}
@@ -696,7 +985,6 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
s.dupAckCount = 0
return false
}
-
s.cc.HandleNDupAcks()
s.enterFastRecovery()
s.dupAckCount = 0
@@ -737,6 +1025,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
seg.hasNewSACKInfo = true
}
}
+ s.SetPipe()
}
// Count the duplicates and do the fast retransmit if needed.
@@ -749,9 +1038,6 @@ func (s *sender) handleRcvdSegment(seg *segment) {
ack := seg.ackNumber
if (ack - 1).InRange(s.sndUna, s.sndNxt) {
s.dupAckCount = 0
- // When an ack is received we must reset the timer. We stop it
- // here and it will be restarted later if needed.
- s.resendTimer.disable()
// See : https://tools.ietf.org/html/rfc1323#section-3.3.
// Specifically we should only update the RTO using TSEcr if the
@@ -767,6 +1053,11 @@ func (s *sender) handleRcvdSegment(seg *segment) {
elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
+
+ // When an ack is received we must rearm the timer.
+ // RFC 6298 5.2
+ s.resendTimer.enable(s.rto)
+
// Remove all acknowledged data from the write list.
acked := s.sndUna.Size(ack)
s.sndUna = ack
@@ -792,7 +1083,13 @@ func (s *sender) handleRcvdSegment(seg *segment) {
s.writeNext = seg.Next()
}
s.writeList.Remove(seg)
- s.outstanding -= s.pCount(seg)
+
+ // if SACK is enabled then Only reduce outstanding if
+ // the segment was not previously SACKED as these have
+ // already been accounted for in SetPipe().
+ if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.outstanding -= s.pCount(seg)
+ }
seg.decRef()
ackLeft -= datalen
}
@@ -815,8 +1112,16 @@ func (s *sender) handleRcvdSegment(seg *segment) {
if s.outstanding < 0 {
s.outstanding = 0
}
- }
+ s.SetPipe()
+
+ // If all outstanding data was acknowledged the disable the timer.
+ // RFC 6298 Rule 5.3
+ if s.sndUna == s.sndNxt {
+ s.outstanding = 0
+ s.resendTimer.disable()
+ }
+ }
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
if rtx {
@@ -827,12 +1132,26 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- s.sendData()
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ s.sendData()
+ }
}
-// sendSegment sends a new segment containing the given payload, flags and
-// sequence number.
-func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+// sendSegment sends the specified segment.
+func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+ if !seg.xmitTime.IsZero() {
+ s.ep.stack.Stats().TCP.Retransmits.Increment()
+ if s.sndCwnd < s.sndSsthresh {
+ s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+ }
+ }
+ seg.xmitTime = time.Now()
+ return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+}
+
+// sendSegmentFromView sends a new segment containing the given payload, flags
+// and sequence number.
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
@@ -843,5 +1162,19 @@ func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum.
// Remember the max sent ack.
s.maxSentAck = rcvNxt
+ // Every time a packet containing data is sent (including a
+ // retransmission), if SACK is enabled then use the conservative timer
+ // described in RFC6675 Section 4.0, otherwise follow the standard time
+ // described in RFC6298 Section 5.2.
+ if data.Size() != 0 {
+ if s.ep.sackPermitted {
+ s.resendTimer.enable(s.rto)
+ } else {
+ if !s.resendTimer.enabled() {
+ s.resendTimer.enable(s.rto)
+ }
+ }
+ }
+
return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index dbfbd5c4f..025d133be 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -16,22 +16,33 @@ package tcp_test
import (
"fmt"
+ "log"
"reflect"
"testing"
+ "time"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
"gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
)
-// createConnectWithSACKPermittedOption creates and connects c.ep with the
+// createConnectedWithSACKPermittedOption creates and connects c.ep with the
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
+// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
+// option enabled if the stack in the context has SACK and TS enabled.
+func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
+ return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+}
+
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
@@ -348,3 +359,206 @@ func TestTrimSackBlockList(t *testing.T) {
}
}
}
+
+func TestSACKRecovery(t *testing.T) {
+ const maxPayload = 10
+ // See: tcp.makeOptions for why tsOptionSize is set to 12 here.
+ const tsOptionSize = 12
+ // Enabling SACK means the payload size is reduced to account
+ // for the extra space required for the TCP options.
+ //
+ // We increase the MTU by 40 bytes to account for SACK and Timestamp
+ // options.
+ const maxTCPOptionSize = 40
+
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ // We use log.Printf instead of t.Logf here because this probe
+ // can fire even when the test function has finished. This is
+ // because closing the endpoint in cleanup() does not mean the
+ // actual worker loop terminates immediately as it still has to
+ // do a full TCP shutdown. But this test can finish running
+ // before the shutdown is done. Using t.Logf in such a case
+ // causes the test to panic due to logging after test finished.
+ log.Printf("state: %+v\n", s)
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Send 3 duplicate acks. This should force an immediate retransmit of
+ // the pending packet and put the sender into fast recovery.
+ rtxOffset := bytesRead - maxPayload*expected
+ start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
+ end := start.Add(10)
+ for i := 0; i < 3; i++ {
+ c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+ end = end.Add(10)
+ }
+
+ // Receive the retransmitted packet.
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
+
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ }
+ }
+
+ // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
+ // window inflation and sending of packets is completely handled by the
+ // SACK Recovery algorithm. We should see no packets being released, as
+ // the cwnd at this point after entering recovery should be half of the
+ // outstanding number of packets in flight.
+ for i := 0; i < 7; i++ {
+ c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+ end = end.Add(10)
+ }
+
+ recover := bytesRead
+
+ // Ensure no new packets arrive.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge half of the pending data. This along with the 10 sacked
+ // segments above should reduce the outstanding below the current
+ // congestion window allowing the sender to transmit data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+
+ // Now send a partial ACK w/ a SACK block that indicates that the next 3
+ // segments are lost and we have received 6 segments after the lost
+ // segments. This should cause the sender to immediately transmit all 3
+ // segments in response to this ACK unlike in FastRecovery where only 1
+ // segment is retransmitted per ACK.
+ start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
+ end = start.Add(60)
+ c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+
+ // At this point, we acked expected/2 packets and we SACKED 6 packets and
+ // 3 segments were considered lost due to the SACK block we sent.
+ //
+ // So total packets outstanding can be calculated as follows after 7
+ // iterations of slow start -> 10/20/40/80/160/320/640. So expected
+ // should be 640 at start, then we went to recover at which point the
+ // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the
+ // network).
+ // Outstanding at this point after acking half the window
+ // (320 packets) will be:
+ // outstanding = 640-320-6(due to SACK block)-3 = 311
+ //
+ // The last 3 is due to the fact that the first 3 packets after
+ // rtxOffset will be considered lost due to the SACK blocks sent.
+ // Receive the retransmit due to partial ack.
+
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
+ // Receive the 2 extra packets that should have been retransmitted as
+ // those should be considered lost and immediately retransmitted based
+ // on the SACK information in the previous ACK sent above.
+ for i := 0; i < 2; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize)
+ }
+
+ // Now we should get 9 more new unsent packets as the cwnd is 323 and
+ // outstanding is 311.
+ for i := 0; i < 9; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ // In SACK recovery only the first segment is fast retransmitted when
+ // entering recovery.
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
+ t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
+
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
+
+ // Acknowledge all pending data to recover point.
+ c.SendAck(790, recover)
+
+ // At this point, the cwnd should reset to expected/2 and there are 9
+ // packets outstanding.
+ //
+ // Now in the first iteration since there are 9 packets outstanding.
+ // We would expect to get expected/2 - 9 packets. But subsequent
+ // iterations will send us expected/2 + 1 (per iteration).
+ expected = expected/2 - 9
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In cogestion avoidance, the packets trains increase by 1 in
+ // each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 9
+ }
+ expected++
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index a8b290dae..6e3ba5922 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1792,12 +1792,12 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// Receive SYN packet.
b := c.GetPacket()
-
+ mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize)
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -1812,7 +1812,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.TCPFlags(header.TCPFlagSyn),
checker.SrcPort(tcp.SourcePort()),
checker.SeqNum(tcp.SequenceNumber()),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -2737,7 +2737,8 @@ func TestFastRecovery(t *testing.T) {
// A partial ACK during recovery should reduce congestion window by the
// number acked. Since we had "expected" packets outstanding before sending
// partial ack and we acked expected/2 , the cwnd and outstanding should
- // be expected/2 + 7. Which means the sender should not send any more packets
+ // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered
+ // fast recovery). Which means the sender should not send any more packets
// till we ack this one.
c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
50*time.Millisecond)
@@ -2843,7 +2844,7 @@ func TestRetransmit(t *testing.T) {
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index fa721a7f8..e08eb6533 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -355,13 +355,27 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
// SendAck sends an ACK packet.
func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
+ c.SendAckWithSACK(seq, bytesReceived, nil)
+}
+
+// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
+func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
+ options := make([]byte, 40)
+ offset := 0
+ if len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(testInitialSequenceNumber).Add(1),
+ SeqNum: seq,
AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
RcvWnd: 30000,
+ TCPOpts: options[:offset],
})
}
@@ -369,9 +383,17 @@ func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
// verifies that the packet packet payload of packet matches the slice
// of data indicated by offset & size.
func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+ c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
+}
+
+// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
+// and verifies that the packet packet payload of packet matches the slice of
+// data indicated by offset & size and skips optlen bytes in addition to the IP
+// TCP headers when comparing the data.
+func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
b := c.GetPacket()
checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.PayloadLen(size+header.TCPMinimumSize+optlen),
checker.TCP(
checker.DstPort(TestPort),
checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
@@ -381,7 +403,7 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
)
pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
}
}
@@ -683,12 +705,14 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
b := c.GetPacket()
// Validate that the syn has the timestamp option and a valid
// TS value.
+ mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
+
checker.IPv4(c.t, b,
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.TCPSynOptions(header.TCPSynOptions{
- MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize),
+ MSS: mss,
TS: true,
WS: defaultWindowScale,
SACKPermitted: c.SACKEnabled(),