summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAyush Ranjan <ayushranjan@google.com>2021-02-11 17:17:00 -0800
committergVisor bot <gvisor-bot@google.com>2021-02-11 17:19:47 -0800
commit91cf7b3ca48b9388c9058962ffe798c37e951990 (patch)
treec96f1153e83a6849ede78c8477f95241f1d26da2
parent4314bb0b2b96cc3a84e8dead29812ccb1bfcebe2 (diff)
[netstack] Fix recovery entry and exit checks.
Entry check: - Earlier implementation was preventing us from entering recovery even if SND.UNA is lost but dupAckCount is still below threshold. Fixed that. - We should only enter recovery when at least one more byte of data beyond the highest byte that was outstanding when fast retransmit was last entered is acked. Added that check. Exit check: - Earlier we were checking if SEG.ACK is in range [SND.UNA, SND.NXT]. The intention was to check if any unacknowledged data was ACKed. Note that (SEG.ACK - 1) is actually the sequence number which was ACKed. So we were incorrectly including (SND.UNA - 1) in the range. Fixed the check to now be (SEG.ACK - 1) in range [SND.UNA, SND.NXT). Additionally, moved a RACK specific test to the rack tests file. Added tests for the changes I made. PiperOrigin-RevId: 357091322
-rw-r--r--pkg/tcpip/transport/tcp/snd.go18
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go64
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go128
3 files changed, 166 insertions, 44 deletions
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 37f064aaf..7911e6b85 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1131,6 +1131,14 @@ func (s *sender) SetPipe() {
s.outstanding = pipe
}
+// shouldEnterRecovery returns true if the sender should enter fast recovery
+// based on dupAck count and sack scoreboard.
+// See RFC 6675 section 5.
+func (s *sender) shouldEnterRecovery() bool {
+ return s.dupAckCount >= nDupAckThreshold ||
+ (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna))
+}
+
// detectLoss is called when an ack is received and returns whether a loss is
// detected. It manages the state related to duplicate acks and determines if
// a retransmit is needed according to the rules in RFC 6582 (NewReno).
@@ -1155,7 +1163,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// Do not enter fast recovery until we reach nDupAckThreshold or the
// first unacknowledged byte is considered lost as per SACK scoreboard.
- if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+ if !s.shouldEnterRecovery() {
// RFC 6675 Step 3.
s.fr.highRxt = s.sndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
@@ -1169,7 +1177,11 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// We only do the check here, the incrementing of last to the highest
// sequence number transmitted till now is done when enterRecovery
// is invoked.
- if !s.fr.last.LessThan(seg.ackNumber) {
+ //
+ // Note that we only enter recovery when at least one more byte of data
+ // beyond s.fr.last (the highest byte that was outstanding when fast
+ // retransmit was last entered) is acked.
+ if !s.fr.last.LessThan(seg.ackNumber - 1) {
s.dupAckCount = 0
return false
}
@@ -1351,7 +1363,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
if s.fr.active {
// Leave fast recovery if it acknowledges all the data covered by
// this fast recovery session.
- if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) {
+ if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) {
s.leaveRecovery()
}
} else {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index b397bb7ff..92f0f6cee 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -159,9 +159,11 @@ func TestRACKDetectReorder(t *testing.T) {
<-probeDone
}
-func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte {
+func sendAndReceive(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte {
setStackSACKPermitted(t, c, true)
- setStackRACKPermitted(t, c)
+ if enableRACK {
+ setStackRACKPermitted(t, c)
+ }
createConnectedWithSACKAndTS(c)
data := make([]byte, numPackets*maxPayload)
@@ -222,7 +224,7 @@ func TestRACKDetectDSACK(t *testing.T) {
addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
numPackets := 8
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Cumulative ACK for [1-5] packets.
seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
@@ -263,7 +265,7 @@ func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) {
addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
numPackets := 10
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Cumulative ACK for [1-5] packets.
seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
@@ -308,7 +310,7 @@ func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) {
addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
numPackets := 10
- sendAndReceive(t, c, numPackets)
+ sendAndReceive(t, c, numPackets, true /* enableRACK */)
// ACK [1-5] packets.
seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
@@ -352,7 +354,7 @@ func TestRACKDetectDSACKSingleDup(t *testing.T) {
addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
numPackets := 4
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Send ACK for #1 packet.
bytesRead := maxPayload
@@ -401,7 +403,7 @@ func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) {
addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
numPackets := 6
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Send ACK for #1 packet.
bytesRead := maxPayload
@@ -455,7 +457,7 @@ func TestRACKDetectDSACKDup(t *testing.T) {
addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
numPackets := 7
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Send ACK for #1 packet.
bytesRead := maxPayload
@@ -523,7 +525,7 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) {
})
numPackets := 10
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Cumulative ACK for [1-5] packets.
seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
@@ -586,7 +588,7 @@ func TestRACKCheckReorderWindow(t *testing.T) {
addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone)
const numPackets = 7
- sendAndReceive(t, c, numPackets)
+ sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Send ACK for #1 packet.
bytesRead := maxPayload
@@ -615,7 +617,7 @@ func TestRACKWithDuplicateACK(t *testing.T) {
defer c.Cleanup()
const numPackets = 4
- data := sendAndReceive(t, c, numPackets)
+ data := sendAndReceive(t, c, numPackets, true /* enableRACK */)
// Send three duplicate ACKs to trigger fast recovery. The first
// segment is considered as lost and will be retransmitted after
@@ -654,3 +656,43 @@ func TestRACKWithDuplicateACK(t *testing.T) {
t.Error(err)
}
}
+
+// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK
+// is received.
+func TestRACKUpdateSackedOut(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ ackNum := 0
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.SackedOut is what we expect.
+ if state.Sender.SackedOut != 2 && ackNum == 0 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
+ }
+
+ if state.Sender.SackedOut != 0 && ackNum == 1 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ }
+ if ackNum > 0 {
+ close(probeDone)
+ }
+ ackNum++
+ })
+
+ sendAndReceive(t, c, 8, true /* enableRACK */)
+
+ // ACK for [3-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
+ bytesRead := 2 * maxPayload
+ end := start.Add(seqnum.Size(bytesRead))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ bytesRead += 3 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 5024bc925..f15d0a2d1 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -593,44 +593,112 @@ func TestSACKRecovery(t *testing.T) {
}
}
-// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK
-// is received.
-func TestSACKUpdateSackedOut(t *testing.T) {
+// TestRecoveryEntry tests the following two properties of entering recovery:
+// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK
+// scoreboard but dupack count is still below threshold.
+// - Only enter recovery when at least one more byte of data beyond the highest
+// byte that was outstanding when fast retransmit was last entered is acked.
+func TestRecoveryEntry(t *testing.T) {
c := context.New(t, uint32(mtu))
defer c.Cleanup()
- probeDone := make(chan struct{})
- ackNum := 0
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint Sender.SackedOut is what we expect.
- if state.Sender.SackedOut != 2 && ackNum == 0 {
- t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
- }
+ numPackets := 5
+ data := sendAndReceive(t, c, numPackets, false /* enableRACK */)
+
+ // Ack #1 packet.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, maxPayload)
+
+ // Now SACK #3, #4 and #5 packets. This will simulate a situation where
+ // SND.UNA should be considered lost and the sender should enter fast recovery
+ // (even though dupack count is still below threshold).
+ p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
+ p3End := p3Start.Add(maxPayload)
+ p4Start := p3End
+ p4End := p4Start.Add(maxPayload)
+ p5Start := p4End
+ p5End := p5Start.Add(maxPayload)
+ c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}})
+
+ // Expect #2 to be retransmitted.
+ c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize)
- if state.Sender.SackedOut != 0 && ackNum == 1 {
- t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ // SACK recovery must have happened.
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ // #2 was retransmitted.
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
+ // No RTOs should have fired yet.
+ {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
}
- if ackNum > 0 {
- close(probeDone)
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
}
- ackNum++
- })
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
- sendAndReceive(t, c, 8)
+ // Send 4 more packets.
+ var r bytes.Reader
+ data = append(data, data...)
+ r.Reset(data[5*maxPayload : 9*maxPayload])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
- // ACK for [3-5] packets.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
- bytesRead := 2 * maxPayload
- end := start.Add(seqnum.Size(bytesRead))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+ var sackBlocks []header.SACKBlock
+ bytesRead := numPackets * maxPayload
+ for i := 0; i < 4; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ if i > 0 {
+ pStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
+ sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)})
+ c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks)
+ }
+ bytesRead += maxPayload
+ }
- bytesRead += 3 * maxPayload
- c.SendAck(seq, bytesRead)
+ // #6 should be retransmitted after RTO. The sender should NOT enter fast
+ // recovery because the highest byte that was outstanding when fast recovery
+ // was last entered is #5 packet's end. And the sender requires at least one
+ // more byte beyond that (#6 packet start) to be acked to enter recovery.
+ c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize)
+ c.SendAck(seq, 9*maxPayload)
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
+ metricPollFn = func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ // Only 1 SACK recovery must have happened.
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ // #2 and #6 were retransmitted.
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 2},
+ // RTO should have fired once.
+ {tcpStats.Timeouts, "stats.TCP.Timeouts", 1},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
}