summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go230
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go9
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/reno.go5
-rw-r--r--pkg/tcpip/transport/tcp/snd.go46
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go150
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go47
8 files changed, 478 insertions, 19 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 9ebae6cc7..8b911c295 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -19,6 +19,7 @@ go_library(
srcs = [
"accept.go",
"connect.go",
+ "cubic.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
new file mode 100644
index 000000000..cdb85598d
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,230 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
+ // per RFC.
+ c float64
+
+ // k is the time period that the above function takes to increase the
+ // current window size to W_max if there are no further congestion
+ // events and is calculated using the following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplication decrease factor. that is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0)=W_max*beta_cubic.
+ beta float64
+
+ // wC is window computed by CUBIC at time t. It's calculated using the
+ // formula:
+ //
+ // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ wC float64
+
+ // wEst is the window computed by CUBIC at time t+RTT i.e
+ // W_cubic(t+RTT).
+ wEst float64
+
+ s *sender
+}
+
+// newCubicCC returns a partially initialized cubic state with the constants
+// beta and c set and t set to current time.
+func newCubicCC(s *sender) *cubicState {
+ return &cubicState{
+ t: time.Now(),
+ beta: 0.7,
+ c: 0.4,
+ s: s,
+ }
+}
+
+// enterCongestionAvoidance is used to initialize cubic in cases where we exit
+// SlowStart without a real congestion event taking place. This can happen when
+// a connection goes back to slow start due to a retransmit and we exceed the
+// previously lowered ssThresh without experiencing packet loss.
+//
+// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
+func (c *cubicState) enterCongestionAvoidance() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.7 &
+ // https://tools.ietf.org/html/rfc8312#section-4.8
+ if c.numCongestionEvents == 0 {
+ c.k = 0
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+ }
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If after adjusting the congestion window we cross
+// the ssThresh then it will return the number of packets that must be consumed
+// in congestion avoidance mode.
+func (c *cubicState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := c.s.sndCwnd + packetsAcked
+ enterCA := false
+ if newcwnd >= c.s.sndSsthresh {
+ newcwnd = c.s.sndSsthresh
+ c.s.sndCAAckCount = 0
+ enterCA = true
+ }
+
+ packetsAcked -= newcwnd - c.s.sndCwnd
+ c.s.sndCwnd = newcwnd
+ if enterCA {
+ c.enterCongestionAvoidance()
+ }
+ return packetsAcked
+}
+
+// Update updates cubic's internal state variables. It must be called on every
+// ACK received.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) Update(packetsAcked int) {
+ if c.s.sndCwnd < c.s.sndSsthresh {
+ packetsAcked = c.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ } else {
+ c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, c.s.srtt)
+ }
+}
+
+// cubicCwnd computes the CUBIC congestion window after t seconds from last
+// congestion event.
+func (c *cubicState) cubicCwnd(t float64) float64 {
+ return c.c*math.Pow(t, 3.0) + c.wMax
+}
+
+// getCwnd returns the current congestion window as computed by CUBIC.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
+ elapsed := time.Since(c.t).Seconds()
+
+ // Compute the window as per Cubic after 'elapsed' time
+ // since last congestion event.
+ c.wC = c.cubicCwnd(elapsed - c.k)
+
+ // Compute the TCP friendly estimate of the congestion window.
+ c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+
+ // Make sure in the TCP friendly region CUBIC performs at least
+ // as well as Reno.
+ if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ // TCP Friendly region of cubic.
+ return int(c.wEst)
+ }
+
+ // In Concave/Convex region of CUBIC, calculate what CUBIC window
+ // will be after 1 RTT and use that to grow congestion window
+ // for every ack.
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per 4.3 for each received ACK cwnd must be incremented
+ // by (w_cubic(t+RTT) - cwnd/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionContrl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
+
+// PostRecovery implemements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold returns new SsThresh as described in
+// https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index de1883d84..8bfb68f91 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -187,6 +187,10 @@ type endpoint struct {
sndWaker sleep.Waker `state:"manual"`
sndCloseWaker sleep.Waker `state:"manual"`
+ // cc stores the name of the Congestion Control algorithm to use for
+ // this endpoint.
+ cc CongestionControlOption
+
// The following are used when a "packet too big" control packet is
// received. They are protected by sndBufMu. They are used to
// communicate to the main protocol goroutine how many such control
@@ -254,6 +258,11 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.rcvBufSize = rs.Default
}
+ var cs CongestionControlOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ e.cc = cs
+ }
+
if p := stack.GetTCPProbe(); p != nil {
e.probe = p
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index cbe0e564e..194d3f41d 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -69,6 +69,11 @@ type ReceiveBufferSizeOption struct {
Max int
}
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
// CongestionControlOption sets the current congestion control algorithm.
type CongestionControlOption string
@@ -227,8 +232,8 @@ func init() {
return &protocol{
sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
- congestionControl: "reno",
- availableCongestionControl: []string{"reno"},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
}
})
}
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
index 03ae8d747..feb593234 100644
--- a/pkg/tcpip/transport/tcp/reno.go
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -96,3 +96,8 @@ func (r *renoState) HandleRTOExpired() {
// initial congestion window.
r.s.sndCwnd = 1
}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (r *renoState) PostRecovery() {
+ // noop.
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 376e81846..568bd7024 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -51,6 +51,11 @@ type congestionControl interface {
// number of packet's that were acked by the most recent cumulative
// acknowledgement.
Update(packetsAcked int)
+
+ // PostRecovery is invoked when the sender is exiting a fast retransmit/
+ // recovery phase. This provides congestion control algorithms a way
+ // to adjust their state when exiting recovery.
+ PostRecovery()
}
// sender holds the state necessary to send TCP segments.
@@ -174,7 +179,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
},
}
- s.cc = newRenoCC(s)
+ s.cc = s.initCongestionControl(ep.cc)
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
@@ -189,6 +194,17 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
return s
}
+func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+ switch congestionControlName {
+ case ccCubic:
+ return newCubicCC(s)
+ case ccReno:
+ fallthrough
+ default:
+ return newRenoCC(s)
+ }
+}
+
// updateMaxPayloadSize updates the maximum payload size based on the given
// MTU. If this is in response to "packet too big" control packets (indicated
// by the count argument), it also reduces the number of outstanding packets and
@@ -409,6 +425,7 @@ func (s *sender) sendData() {
}
func (s *sender) enterFastRecovery() {
+ s.fr.active = true
// Save state to reflect we're now in fast recovery.
// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
// We inflat the cwnd by 3 to account for the 3 packets which triggered
@@ -417,7 +434,6 @@ func (s *sender) enterFastRecovery() {
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
- s.fr.active = true
}
func (s *sender) leaveFastRecovery() {
@@ -429,12 +445,13 @@ func (s *sender) leaveFastRecovery() {
// Deflate cwnd. It had been artificially inflated when new dups arrived.
s.sndCwnd = s.sndSsthresh
+ s.cc.PostRecovery()
}
// checkDuplicateAck is called when an ack is received. It manages the state
// related to duplicate acks and determines if a retransmit is needed according
// to the rules in RFC 6582 (NewReno).
-func (s *sender) checkDuplicateAck(seg *segment) bool {
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
ack := seg.ackNumber
if s.fr.active {
// We are in fast recovery mode. Ignore the ack if it's out of
@@ -474,6 +491,7 @@ func (s *sender) checkDuplicateAck(seg *segment) bool {
//
// N.B. The retransmit timer will be reset by the caller.
s.fr.first = ack
+ s.dupAckCount = 0
return true
}
@@ -508,16 +526,11 @@ func (s *sender) checkDuplicateAck(seg *segment) bool {
return true
}
-// updateCwnd updates the congestion window based on the number of packets that
-// were acknowledged.
-func (s *sender) updateCwnd(packetsAcked int) {
-}
-
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(seg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !s.ep.sendTSOk && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
@@ -534,10 +547,25 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// Ignore ack if it doesn't acknowledge any new data.
ack := seg.ackNumber
if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+ s.dupAckCount = 0
// When an ack is received we must reset the timer. We stop it
// here and it will be restarted later if needed.
s.resendTimer.disable()
+ // See : https://tools.ietf.org/html/rfc1323#section-3.3.
+ // Specifically we should only update the RTO using TSEcr if the
+ // following condition holds:
+ //
+ // A TSecr value received in a segment is used to update the
+ // averaged RTT measurement only if the segment acknowledges
+ // some new data, i.e., only if it advances the left edge of
+ // the send window.
+ if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ // TSVal/Ecr values sent by Netstack are at a millisecond
+ // granularity.
+ elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ s.updateRTO(elapsed)
+ }
// Remove all acknowledged data from the write list.
acked := s.sndUna.Size(ack)
s.sndUna = ack
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 45ebca5b1..11410b050 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -17,6 +17,7 @@ package tcp_test
import (
"bytes"
"fmt"
+ "math"
"testing"
"time"
@@ -2005,7 +2006,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Check we don't receive any more packets on this iteration.
// The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
}
// Don't acknowledge the first packet of the last packet train. Let's
@@ -2043,7 +2044,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Check we don't receive any more packets on this iteration.
// The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
// Acknowledge all the data received so far.
c.SendAck(790, bytesRead)
@@ -2054,6 +2055,130 @@ func TestCongestionAvoidance(t *testing.T) {
}
}
+// cubicCwnd returns an estimate of a cubic window given the
+// originalCwnd, wMax, last congestion event time and sRTT.
+func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
+ cwnd := float64(origCwnd)
+ // We wait 50ms between each iteration so sRTT as computed by cubic
+ // should be close to 50ms.
+ elapsed := (time.Since(congEventTime) + sRTT).Seconds()
+ k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
+ wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
+ cwnd += (wtRTT - cwnd) / cwnd
+ return int(cwnd)
+}
+
+func TestCubicCongestionAvoidance(t *testing.T) {
+ maxPayload := 10
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ enableCUBIC(t, c)
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Unexpected error from Write: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd * 0.7.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all pending data.
+ c.SendAck(790, bytesRead)
+
+ // Store away the time we sent the ACK and assuming a 200ms RTO
+ // we estimate that the sender will have an RTO 200ms from now
+ // and go back into slow start.
+ packetDropTime := time.Now().Add(200 * time.Millisecond)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2 essentially putting the connection
+ // straight into congestion avoidance.
+ wMax := expected
+ // Lower expected as per cubic spec after a congestion event.
+ expected = int(float64(expected) * 0.7)
+ cwnd := expected
+ for i := 0; i < iterations; i++ {
+ // Cubic grows window independent of ACKs. Cubic Window growth
+ // is a function of time elapsed since last congestion event.
+ // As a result the congestion window does not grow
+ // deterministically in response to ACKs.
+ //
+ // We need to roughly estimate what the cwnd of the sender is
+ // based on when we sent the dupacks.
+ cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
+
+ packetsExpected := cwnd
+ for j := 0; j < packetsExpected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+ t.Logf("expected packets received, next trying to receive any extra packets that may come")
+
+ // If our estimate was correct there should be no more pending packets.
+ // We attempt to read a packet a few times with a short sleep in between
+ // to ensure that we don't see the sender send any unexpected packets.
+ packetsUnexpected := 0
+ for {
+ gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
+ if !gotPacket {
+ break
+ }
+ bytesRead += maxPayload
+ packetsUnexpected++
+ time.Sleep(1 * time.Millisecond)
+ }
+ if packetsUnexpected != 0 {
+ t.Fatalf("received %d unexpected packets for iteration %d", packetsUnexpected, i)
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+ }
+}
+
func TestFastRecovery(t *testing.T) {
maxPayload := 10
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
@@ -2864,8 +2989,9 @@ func TestSetCongestionControl(t *testing.T) {
mustPass bool
}{
{"reno", true},
- {"cubic", false},
+ {"cubic", true},
}
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
c := context.New(t, 1500)
@@ -2881,7 +3007,7 @@ func TestSetCongestionControl(t *testing.T) {
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.CongestionControlOption("reno"); got != want {
+ if got, want := cc, tc.cc; got != want {
t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want)
}
})
@@ -2899,7 +3025,7 @@ func TestAvailableCongestionControl(t *testing.T) {
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcp.AvailableCongestionControlOption("reno"); got != want {
+ if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
t.Fatalf("unexpected value for AvailableCongestionControlOption: got: %v, want: %v", got, want)
}
}
@@ -2917,11 +3043,19 @@ func TestSetAvailableCongestionControl(t *testing.T) {
}
// Verify that we still get the expected list of congestion control options.
- var cc tcp.CongestionControlOption
+ var cc tcp.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.CongestionControlOption("reno"); got != want {
- t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want)
+ if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("unexpected value for available congestion control got: %v, want: %v", got, want)
+ }
+}
+
+func enableCUBIC(t *testing.T, c *context.Context) {
+ t.Helper()
+ opt := tcp.CongestionControlOption("cubic")
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e44979527..6b5786140 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -242,6 +242,27 @@ func (c *Context) GetPacket() []byte {
return nil
}
+// GetPacketNonBlocking reads a packet from the link layer endpoint
+// and verifies that it is an IPv4 packet with the expected source
+// and destination address. If no packet is available it will return
+// nil immediately.
+func (c *Context) GetPacketNonBlocking() []byte {
+ select {
+ case p := <-c.linkEP.C:
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+ b := make([]byte, len(p.Header)+len(p.Payload))
+ copy(b, p.Header)
+ copy(b[len(p.Header):], p.Payload)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
+ default:
+ return nil
+ }
+}
+
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
@@ -355,6 +376,32 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
}
}
+// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
+// and verifies that the packet packet payload of packet matches the slice of
+// data indicated by offset & size. It returns true if a packet was received and
+// processed.
+func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+ b := c.GetPacketNonBlocking()
+ if b == nil {
+ return false
+ }
+ checker.IPv4(c.t, b,
+ checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ pdata := data[offset:][:size]
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
+ c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+ }
+ return true
+}
+
// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
// only endpoint instead of a default dual stack socket.