diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_rack_test.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_sack_test.go | 128 |
3 files changed, 166 insertions, 44 deletions
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 37f064aaf..7911e6b85 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1131,6 +1131,14 @@ func (s *sender) SetPipe() { s.outstanding = pipe } +// shouldEnterRecovery returns true if the sender should enter fast recovery +// based on dupAck count and sack scoreboard. +// See RFC 6675 section 5. +func (s *sender) shouldEnterRecovery() bool { + return s.dupAckCount >= nDupAckThreshold || + (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna)) +} + // detectLoss is called when an ack is received and returns whether a loss is // detected. It manages the state related to duplicate acks and determines if // a retransmit is needed according to the rules in RFC 6582 (NewReno). @@ -1155,7 +1163,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // Do not enter fast recovery until we reach nDupAckThreshold or the // first unacknowledged byte is considered lost as per SACK scoreboard. - if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) { + if !s.shouldEnterRecovery() { // RFC 6675 Step 3. s.fr.highRxt = s.sndUna - 1 // Do run SetPipe() to calculate the outstanding segments. @@ -1169,7 +1177,11 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // We only do the check here, the incrementing of last to the highest // sequence number transmitted till now is done when enterRecovery // is invoked. - if !s.fr.last.LessThan(seg.ackNumber) { + // + // Note that we only enter recovery when at least one more byte of data + // beyond s.fr.last (the highest byte that was outstanding when fast + // retransmit was last entered) is acked. + if !s.fr.last.LessThan(seg.ackNumber - 1) { s.dupAckCount = 0 return false } @@ -1351,7 +1363,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if s.fr.active { // Leave fast recovery if it acknowledges all the data covered by // this fast recovery session. - if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) { + if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) { s.leaveRecovery() } } else { diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index b397bb7ff..92f0f6cee 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -159,9 +159,11 @@ func TestRACKDetectReorder(t *testing.T) { <-probeDone } -func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { +func sendAndReceive(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte { setStackSACKPermitted(t, c, true) - setStackRACKPermitted(t, c) + if enableRACK { + setStackRACKPermitted(t, c) + } createConnectedWithSACKAndTS(c) data := make([]byte, numPackets*maxPayload) @@ -222,7 +224,7 @@ func TestRACKDetectDSACK(t *testing.T) { addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) numPackets := 8 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Cumulative ACK for [1-5] packets. seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) @@ -263,7 +265,7 @@ func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) numPackets := 10 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Cumulative ACK for [1-5] packets. seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) @@ -308,7 +310,7 @@ func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) numPackets := 10 - sendAndReceive(t, c, numPackets) + sendAndReceive(t, c, numPackets, true /* enableRACK */) // ACK [1-5] packets. seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) @@ -352,7 +354,7 @@ func TestRACKDetectDSACKSingleDup(t *testing.T) { addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) numPackets := 4 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Send ACK for #1 packet. bytesRead := maxPayload @@ -401,7 +403,7 @@ func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) numPackets := 6 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Send ACK for #1 packet. bytesRead := maxPayload @@ -455,7 +457,7 @@ func TestRACKDetectDSACKDup(t *testing.T) { addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) numPackets := 7 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Send ACK for #1 packet. bytesRead := maxPayload @@ -523,7 +525,7 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) { }) numPackets := 10 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Cumulative ACK for [1-5] packets. seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) @@ -586,7 +588,7 @@ func TestRACKCheckReorderWindow(t *testing.T) { addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) const numPackets = 7 - sendAndReceive(t, c, numPackets) + sendAndReceive(t, c, numPackets, true /* enableRACK */) // Send ACK for #1 packet. bytesRead := maxPayload @@ -615,7 +617,7 @@ func TestRACKWithDuplicateACK(t *testing.T) { defer c.Cleanup() const numPackets = 4 - data := sendAndReceive(t, c, numPackets) + data := sendAndReceive(t, c, numPackets, true /* enableRACK */) // Send three duplicate ACKs to trigger fast recovery. The first // segment is considered as lost and will be retransmitted after @@ -654,3 +656,43 @@ func TestRACKWithDuplicateACK(t *testing.T) { t.Error(err) } } + +// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK +// is received. +func TestRACKUpdateSackedOut(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + ackNum := 0 + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.SackedOut is what we expect. + if state.Sender.SackedOut != 2 && ackNum == 0 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) + } + + if state.Sender.SackedOut != 0 && ackNum == 1 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + } + if ackNum > 0 { + close(probeDone) + } + ackNum++ + }) + + sendAndReceive(t, c, 8, true /* enableRACK */) + + // ACK for [3-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) + bytesRead := 2 * maxPayload + end := start.Add(seqnum.Size(bytesRead)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + bytesRead += 3 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 5024bc925..f15d0a2d1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -593,44 +593,112 @@ func TestSACKRecovery(t *testing.T) { } } -// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK -// is received. -func TestSACKUpdateSackedOut(t *testing.T) { +// TestRecoveryEntry tests the following two properties of entering recovery: +// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK +// scoreboard but dupack count is still below threshold. +// - Only enter recovery when at least one more byte of data beyond the highest +// byte that was outstanding when fast retransmit was last entered is acked. +func TestRecoveryEntry(t *testing.T) { c := context.New(t, uint32(mtu)) defer c.Cleanup() - probeDone := make(chan struct{}) - ackNum := 0 - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.SackedOut is what we expect. - if state.Sender.SackedOut != 2 && ackNum == 0 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) - } + numPackets := 5 + data := sendAndReceive(t, c, numPackets, false /* enableRACK */) + + // Ack #1 packet. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, maxPayload) + + // Now SACK #3, #4 and #5 packets. This will simulate a situation where + // SND.UNA should be considered lost and the sender should enter fast recovery + // (even though dupack count is still below threshold). + p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + p3End := p3Start.Add(maxPayload) + p4Start := p3End + p4End := p4Start.Add(maxPayload) + p5Start := p4End + p5End := p5Start.Add(maxPayload) + c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}}) + + // Expect #2 to be retransmitted. + c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) - if state.Sender.SackedOut != 0 && ackNum == 1 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // SACK recovery must have happened. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + // #2 was retransmitted. + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + // No RTOs should have fired yet. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, } - if ackNum > 0 { - close(probeDone) + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } } - ackNum++ - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } - sendAndReceive(t, c, 8) + // Send 4 more packets. + var r bytes.Reader + data = append(data, data...) + r.Reset(data[5*maxPayload : 9*maxPayload]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // ACK for [3-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) - bytesRead := 2 * maxPayload - end := start.Add(seqnum.Size(bytesRead)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + var sackBlocks []header.SACKBlock + bytesRead := numPackets * maxPayload + for i := 0; i < 4; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + if i > 0 { + pStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) + sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)}) + c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks) + } + bytesRead += maxPayload + } - bytesRead += 3 * maxPayload - c.SendAck(seq, bytesRead) + // #6 should be retransmitted after RTO. The sender should NOT enter fast + // recovery because the highest byte that was outstanding when fast recovery + // was last entered is #5 packet's end. And the sender requires at least one + // more byte beyond that (#6 packet start) to be acked to enter recovery. + c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) + c.SendAck(seq, 9*maxPayload) - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone + metricPollFn = func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // Only 1 SACK recovery must have happened. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + // #2 and #6 were retransmitted. + {tcpStats.Retransmits, "stats.TCP.Retransmits", 2}, + // RTO should have fired once. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } } |