diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic.go | 230 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/reno.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 46 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 150 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 47 |
8 files changed, 478 insertions, 19 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 9ebae6cc7..8b911c295 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -19,6 +19,7 @@ go_library( srcs = [ "accept.go", "connect.go", + "cubic.go", "endpoint.go", "endpoint_state.go", "forwarder.go", diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go new file mode 100644 index 000000000..cdb85598d --- /dev/null +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -0,0 +1,230 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "math" + "time" +) + +// cubicState stores the variables related to TCP CUBIC congestion +// control algorithm state. +// +// See: https://tools.ietf.org/html/rfc8312. +type cubicState struct { + // wLastMax is the previous wMax value. + wLastMax float64 + + // wMax is the value of the congestion window at the + // time of last congestion event. + wMax float64 + + // t denotes the time when the current congestion avoidance + // was entered. + t time.Time + + // numCongestionEvents tracks the number of congestion events since last + // RTO. + numCongestionEvents int + + // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as + // per RFC. + c float64 + + // k is the time period that the above function takes to increase the + // current window size to W_max if there are no further congestion + // events and is calculated using the following equation: + // + // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2) + k float64 + + // beta is the CUBIC multiplication decrease factor. that is, when a + // congestion event is detected, CUBIC reduces its cwnd to + // W_cubic(0)=W_max*beta_cubic. + beta float64 + + // wC is window computed by CUBIC at time t. It's calculated using the + // formula: + // + // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) + wC float64 + + // wEst is the window computed by CUBIC at time t+RTT i.e + // W_cubic(t+RTT). + wEst float64 + + s *sender +} + +// newCubicCC returns a partially initialized cubic state with the constants +// beta and c set and t set to current time. +func newCubicCC(s *sender) *cubicState { + return &cubicState{ + t: time.Now(), + beta: 0.7, + c: 0.4, + s: s, + } +} + +// enterCongestionAvoidance is used to initialize cubic in cases where we exit +// SlowStart without a real congestion event taking place. This can happen when +// a connection goes back to slow start due to a retransmit and we exceed the +// previously lowered ssThresh without experiencing packet loss. +// +// Refer: https://tools.ietf.org/html/rfc8312#section-4.8 +func (c *cubicState) enterCongestionAvoidance() { + // See: https://tools.ietf.org/html/rfc8312#section-4.7 & + // https://tools.ietf.org/html/rfc8312#section-4.8 + if c.numCongestionEvents == 0 { + c.k = 0 + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + } +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window we cross +// the ssThresh then it will return the number of packets that must be consumed +// in congestion avoidance mode. +func (c *cubicState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := c.s.sndCwnd + packetsAcked + enterCA := false + if newcwnd >= c.s.sndSsthresh { + newcwnd = c.s.sndSsthresh + c.s.sndCAAckCount = 0 + enterCA = true + } + + packetsAcked -= newcwnd - c.s.sndCwnd + c.s.sndCwnd = newcwnd + if enterCA { + c.enterCongestionAvoidance() + } + return packetsAcked +} + +// Update updates cubic's internal state variables. It must be called on every +// ACK received. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +func (c *cubicState) Update(packetsAcked int) { + if c.s.sndCwnd < c.s.sndSsthresh { + packetsAcked = c.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } else { + c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, c.s.srtt) + } +} + +// cubicCwnd computes the CUBIC congestion window after t seconds from last +// congestion event. +func (c *cubicState) cubicCwnd(t float64) float64 { + return c.c*math.Pow(t, 3.0) + c.wMax +} + +// getCwnd returns the current congestion window as computed by CUBIC. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { + elapsed := time.Since(c.t).Seconds() + + // Compute the window as per Cubic after 'elapsed' time + // since last congestion event. + c.wC = c.cubicCwnd(elapsed - c.k) + + // Compute the TCP friendly estimate of the congestion window. + c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + + // Make sure in the TCP friendly region CUBIC performs at least + // as well as Reno. + if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + // TCP Friendly region of cubic. + return int(c.wEst) + } + + // In Concave/Convex region of CUBIC, calculate what CUBIC window + // will be after 1 RTT and use that to grow congestion window + // for every ack. + tEst := (time.Since(c.t) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.k) + // As per 4.3 for each received ACK cwnd must be incremented + // by (w_cubic(t+RTT) - cwnd/cwnd. + cwnd := float64(sndCwnd) + for i := 0; i < packetsAcked; i++ { + // Concave/Convex regions of cubic have the same formulas. + // See: https://tools.ietf.org/html/rfc8312#section-4.3 + cwnd += (wtRtt - cwnd) / cwnd + } + return int(cwnd) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. +func (c *cubicState) HandleNDupAcks() { + // See: https://tools.ietf.org/html/rfc8312#section-4.5 + c.numCongestionEvents++ + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + + c.fastConvergence() + c.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionContrl.HandleRTOExpired. +func (c *cubicState) HandleRTOExpired() { + // See: https://tools.ietf.org/html/rfc8312#section-4.6 + c.t = time.Now() + c.numCongestionEvents = 0 + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + + c.fastConvergence() + + // We lost a packet, so reduce ssthresh. + c.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + c.s.sndCwnd = 1 +} + +// fastConvergence implements the logic for Fast Convergence algorithm as +// described in https://tools.ietf.org/html/rfc8312#section-4.6. +func (c *cubicState) fastConvergence() { + if c.wMax < c.wLastMax { + c.wLastMax = c.wMax + c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + } else { + c.wLastMax = c.wMax + } + // Recompute k as wMax may have changed. + c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) +} + +// PostRecovery implemements congestionControl.PostRecovery. +func (c *cubicState) PostRecovery() { + c.t = time.Now() +} + +// reduceSlowStartThreshold returns new SsThresh as described in +// https://tools.ietf.org/html/rfc8312#section-4.7. +func (c *cubicState) reduceSlowStartThreshold() { + c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index de1883d84..8bfb68f91 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -187,6 +187,10 @@ type endpoint struct { sndWaker sleep.Waker `state:"manual"` sndCloseWaker sleep.Waker `state:"manual"` + // cc stores the name of the Congestion Control algorithm to use for + // this endpoint. + cc CongestionControlOption + // The following are used when a "packet too big" control packet is // received. They are protected by sndBufMu. They are used to // communicate to the main protocol goroutine how many such control @@ -254,6 +258,11 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.rcvBufSize = rs.Default } + var cs CongestionControlOption + if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + e.cc = cs + } + if p := stack.GetTCPProbe(); p != nil { e.probe = p } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index cbe0e564e..194d3f41d 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -69,6 +69,11 @@ type ReceiveBufferSizeOption struct { Max int } +const ( + ccReno = "reno" + ccCubic = "cubic" +) + // CongestionControlOption sets the current congestion control algorithm. type CongestionControlOption string @@ -227,8 +232,8 @@ func init() { return &protocol{ sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, - congestionControl: "reno", - availableCongestionControl: []string{"reno"}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, } }) } diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index 03ae8d747..feb593234 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -96,3 +96,8 @@ func (r *renoState) HandleRTOExpired() { // initial congestion window. r.s.sndCwnd = 1 } + +// PostRecovery implements congestionControl.PostRecovery. +func (r *renoState) PostRecovery() { + // noop. +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 376e81846..568bd7024 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -51,6 +51,11 @@ type congestionControl interface { // number of packet's that were acked by the most recent cumulative // acknowledgement. Update(packetsAcked int) + + // PostRecovery is invoked when the sender is exiting a fast retransmit/ + // recovery phase. This provides congestion control algorithms a way + // to adjust their state when exiting recovery. + PostRecovery() } // sender holds the state necessary to send TCP segments. @@ -174,7 +179,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint }, } - s.cc = newRenoCC(s) + s.cc = s.initCongestionControl(ep.cc) // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. @@ -189,6 +194,17 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint return s } +func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl { + switch congestionControlName { + case ccCubic: + return newCubicCC(s) + case ccReno: + fallthrough + default: + return newRenoCC(s) + } +} + // updateMaxPayloadSize updates the maximum payload size based on the given // MTU. If this is in response to "packet too big" control packets (indicated // by the count argument), it also reduces the number of outstanding packets and @@ -409,6 +425,7 @@ func (s *sender) sendData() { } func (s *sender) enterFastRecovery() { + s.fr.active = true // Save state to reflect we're now in fast recovery. // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. // We inflat the cwnd by 3 to account for the 3 packets which triggered @@ -417,7 +434,6 @@ func (s *sender) enterFastRecovery() { s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding - s.fr.active = true } func (s *sender) leaveFastRecovery() { @@ -429,12 +445,13 @@ func (s *sender) leaveFastRecovery() { // Deflate cwnd. It had been artificially inflated when new dups arrived. s.sndCwnd = s.sndSsthresh + s.cc.PostRecovery() } // checkDuplicateAck is called when an ack is received. It manages the state // related to duplicate acks and determines if a retransmit is needed according // to the rules in RFC 6582 (NewReno). -func (s *sender) checkDuplicateAck(seg *segment) bool { +func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { ack := seg.ackNumber if s.fr.active { // We are in fast recovery mode. Ignore the ack if it's out of @@ -474,6 +491,7 @@ func (s *sender) checkDuplicateAck(seg *segment) bool { // // N.B. The retransmit timer will be reset by the caller. s.fr.first = ack + s.dupAckCount = 0 return true } @@ -508,16 +526,11 @@ func (s *sender) checkDuplicateAck(seg *segment) bool { return true } -// updateCwnd updates the congestion window based on the number of packets that -// were acknowledged. -func (s *sender) updateCwnd(packetsAcked int) { -} - // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(seg *segment) { // Check if we can extract an RTT measurement from this ack. - if s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + if !s.ep.sendTSOk && s.rttMeasureSeqNum.LessThan(seg.ackNumber) { s.updateRTO(time.Now().Sub(s.rttMeasureTime)) s.rttMeasureSeqNum = s.sndNxt } @@ -534,10 +547,25 @@ func (s *sender) handleRcvdSegment(seg *segment) { // Ignore ack if it doesn't acknowledge any new data. ack := seg.ackNumber if (ack - 1).InRange(s.sndUna, s.sndNxt) { + s.dupAckCount = 0 // When an ack is received we must reset the timer. We stop it // here and it will be restarted later if needed. s.resendTimer.disable() + // See : https://tools.ietf.org/html/rfc1323#section-3.3. + // Specifically we should only update the RTO using TSEcr if the + // following condition holds: + // + // A TSecr value received in a segment is used to update the + // averaged RTT measurement only if the segment acknowledges + // some new data, i.e., only if it advances the left edge of + // the send window. + if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + // TSVal/Ecr values sent by Netstack are at a millisecond + // granularity. + elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + s.updateRTO(elapsed) + } // Remove all acknowledged data from the write list. acked := s.sndUna.Size(ack) s.sndUna = ack diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 45ebca5b1..11410b050 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -17,6 +17,7 @@ package tcp_test import ( "bytes" "fmt" + "math" "testing" "time" @@ -2005,7 +2006,7 @@ func TestCongestionAvoidance(t *testing.T) { // Check we don't receive any more packets on this iteration. // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) } // Don't acknowledge the first packet of the last packet train. Let's @@ -2043,7 +2044,7 @@ func TestCongestionAvoidance(t *testing.T) { // Check we don't receive any more packets on this iteration. // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) // Acknowledge all the data received so far. c.SendAck(790, bytesRead) @@ -2054,6 +2055,130 @@ func TestCongestionAvoidance(t *testing.T) { } } +// cubicCwnd returns an estimate of a cubic window given the +// originalCwnd, wMax, last congestion event time and sRTT. +func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { + cwnd := float64(origCwnd) + // We wait 50ms between each iteration so sRTT as computed by cubic + // should be close to 50ms. + elapsed := (time.Since(congEventTime) + sRTT).Seconds() + k := math.Cbrt(float64(wMax) * 0.3 / 0.7) + wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) + cwnd += (wtRTT - cwnd) / cwnd + return int(cwnd) +} + +func TestCubicCongestionAvoidance(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + enableCUBIC(t, c) + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) + } + + // Don't acknowledge the first packet of the last packet train. Let's + // wait for them to time out, which will trigger a restart of slow + // start, and initialization of ssthresh to cwnd * 0.7. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge all pending data. + c.SendAck(790, bytesRead) + + // Store away the time we sent the ACK and assuming a 200ms RTO + // we estimate that the sender will have an RTO 200ms from now + // and go back into slow start. + packetDropTime := time.Now().Add(200 * time.Millisecond) + + // This part is tricky: when the timeout happened, we had "expected" + // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. + // By acknowledging "expected" packets, the slow-start part will + // increase cwnd to expected/2 essentially putting the connection + // straight into congestion avoidance. + wMax := expected + // Lower expected as per cubic spec after a congestion event. + expected = int(float64(expected) * 0.7) + cwnd := expected + for i := 0; i < iterations; i++ { + // Cubic grows window independent of ACKs. Cubic Window growth + // is a function of time elapsed since last congestion event. + // As a result the congestion window does not grow + // deterministically in response to ACKs. + // + // We need to roughly estimate what the cwnd of the sender is + // based on when we sent the dupacks. + cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) + + packetsExpected := cwnd + for j := 0; j < packetsExpected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + t.Logf("expected packets received, next trying to receive any extra packets that may come") + + // If our estimate was correct there should be no more pending packets. + // We attempt to read a packet a few times with a short sleep in between + // to ensure that we don't see the sender send any unexpected packets. + packetsUnexpected := 0 + for { + gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) + if !gotPacket { + break + } + bytesRead += maxPayload + packetsUnexpected++ + time.Sleep(1 * time.Millisecond) + } + if packetsUnexpected != 0 { + t.Fatalf("received %d unexpected packets for iteration %d", packetsUnexpected, i) + } + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + } +} + func TestFastRecovery(t *testing.T) { maxPayload := 10 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) @@ -2864,8 +2989,9 @@ func TestSetCongestionControl(t *testing.T) { mustPass bool }{ {"reno", true}, - {"cubic", false}, + {"cubic", true}, } + for _, tc := range testCases { t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { c := context.New(t, 1500) @@ -2881,7 +3007,7 @@ func TestSetCongestionControl(t *testing.T) { if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tcp.CongestionControlOption("reno"); got != want { + if got, want := cc, tc.cc; got != want { t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want) } }) @@ -2899,7 +3025,7 @@ func TestAvailableCongestionControl(t *testing.T) { if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcp.AvailableCongestionControlOption("reno"); got != want { + if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want { t.Fatalf("unexpected value for AvailableCongestionControlOption: got: %v, want: %v", got, want) } } @@ -2917,11 +3043,19 @@ func TestSetAvailableCongestionControl(t *testing.T) { } // Verify that we still get the expected list of congestion control options. - var cc tcp.CongestionControlOption + var cc tcp.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tcp.CongestionControlOption("reno"); got != want { - t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want) + if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("unexpected value for available congestion control got: %v, want: %v", got, want) + } +} + +func enableCUBIC(t *testing.T, c *context.Context) { + t.Helper() + opt := tcp.CongestionControlOption("cubic") + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e44979527..6b5786140 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -242,6 +242,27 @@ func (c *Context) GetPacket() []byte { return nil } +// GetPacketNonBlocking reads a packet from the link layer endpoint +// and verifies that it is an IPv4 packet with the expected source +// and destination address. If no packet is available it will return +// nil immediately. +func (c *Context) GetPacketNonBlocking() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b + default: + return nil + } +} + // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. @@ -355,6 +376,32 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { } } +// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint +// and verifies that the packet packet payload of packet matches the slice of +// data indicated by offset & size. It returns true if a packet was received and +// processed. +func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { + b := c.GetPacketNonBlocking() + if b == nil { + return false + } + checker.IPv4(c.t, b, + checker.PayloadLen(size+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(TestPort), + checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + pdata := data[offset:][:size] + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { + c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) + } + return true +} + // CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only // is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 // only endpoint instead of a default dual stack socket. |