From da48c04d0df4bb044624cc3e7003ab3e973336de Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Mon, 23 Jul 2018 15:14:19 -0700 Subject: Refactor new reno congestion control logic out of sender. This CL also puts the congestion control logic behind an interface so that we can easily swap it out for say CUBIC in the future. PiperOrigin-RevId: 205732848 Change-Id: I891cdfd17d4d126b658b5faa0c6bd6083187944b --- pkg/tcpip/transport/tcp/BUILD | 2 + pkg/tcpip/transport/tcp/reno.go | 96 +++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/transport/tcp/snd.go | 75 +++++++++++++------------------- 3 files changed, 129 insertions(+), 44 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/reno.go diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 6a2f42a12..53623787d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -10,6 +10,7 @@ go_stateify( "endpoint.go", "endpoint_state.go", "rcv.go", + "reno.go", "segment.go", "segment_heap.go", "segment_queue.go", @@ -44,6 +45,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "reno.go", "sack.go", "segment.go", "segment_heap.go", diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go new file mode 100644 index 000000000..60f170a27 --- /dev/null +++ b/pkg/tcpip/transport/tcp/reno.go @@ -0,0 +1,96 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tcp + +// renoState stores the variables related to TCP New Reno congestion +// control algorithm. +type renoState struct { + s *sender +} + +// newRenoCC initializes the state for the NewReno congestion control algorithm. +func newRenoCC(s *sender) *renoState { + return &renoState{s: s} +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window +// we cross the SSthreshold then it will return the number of packets that +// must be consumed in congestion avoidance mode. +func (r *renoState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := r.s.sndCwnd + packetsAcked + if newcwnd >= r.s.sndSsthresh { + newcwnd = r.s.sndSsthresh + r.s.sndCAAckCount = 0 + } + + packetsAcked -= newcwnd - r.s.sndCwnd + r.s.sndCwnd = newcwnd + return packetsAcked +} + +// updateCongestionAvoidance will update congestion window in congestion +// avoidance mode as described in RFC5681 section 3.1 +func (r *renoState) updateCongestionAvoidance(packetsAcked int) { + // Consume the packets in congestion avoidance mode. + r.s.sndCAAckCount += packetsAcked + if r.s.sndCAAckCount >= r.s.sndCwnd { + r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd + r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + } +} + +// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, +// page 6, eq. 4. It is called when we detect congestion in the network. +func (r *renoState) reduceSlowStartThreshold() { + r.s.sndSsthresh = r.s.outstanding / 2 + if r.s.sndSsthresh < 2 { + r.s.sndSsthresh = 2 + } + +} + +// Update updates the congestion state based on the number of packets that +// were acknowledged. +// Update implements congestionControl.Update. 
+func (r *renoState) Update(packetsAcked int) { + if r.s.sndCwnd < r.s.sndSsthresh { + packetsAcked = r.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } + r.updateCongestionAvoidance(packetsAcked) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. +func (r *renoState) HandleNDupAcks() { + // A retransmit was triggered due to nDupAckThreshold + // being hit. Reduce our slow start threshold. + r.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionControl.HandleRTOExpired. +func (r *renoState) HandleRTOExpired() { + // We lost a packet, so reduce ssthresh. + r.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + r.s.sndCwnd = 1 +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 7dfbf6384..e38686e1b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -31,8 +31,28 @@ const ( // InitialCwnd is the initial congestion window. InitialCwnd = 10 + + // nDupAckThreshold is the number of duplicate ACKs required + // before fast-retransmit is entered. + nDupAckThreshold = 3 ) +// congestionControl is an interface that must be implemented by any supported +// congestion control algorithm. +type congestionControl interface { + // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold + // just before entering fast retransmit. + HandleNDupAcks() + + // HandleRTOExpired is invoked when the retransmit timer expires. + HandleRTOExpired() + + // Update is invoked when processing inbound acks. It's passed the + // number of packets that were acked by the most recent cumulative + // acknowledgement. + Update(packetsAcked int) +} + // sender holds the state necessary to send TCP segments. 
type sender struct { ep *endpoint @@ -107,6 +127,9 @@ type sender struct { // maxSentAck is the maxium acknowledgement actually sent. maxSentAck seqnum.Value + + // cc is the congestion control algorithm in use for this sender. + cc congestionControl } // fastRecovery holds information related to fast recovery from a packet loss. @@ -147,6 +170,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint }, } + s.cc = newRenoCC(s) + // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { @@ -251,15 +276,6 @@ func (s *sender) resendSegment() { } } -// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, -// page 6, eq. 4. It is called when we detect congestion in the network. -func (s *sender) reduceSlowStartThreshold() { - s.sndSsthresh = s.outstanding / 2 - if s.sndSsthresh < 2 { - s.sndSsthresh = 2 - } -} - // retransmitTimerExpired is called when the retransmit timer expires, and // unacknowledged segments are assumed lost, and thus need to be resent. // Returns true if the connection is still usable, or false if the connection @@ -292,13 +308,7 @@ func (s *sender) retransmitTimerExpired() bool { // we were not in fast recovery. s.fr.last = s.sndNxt - 1 - // We lost a packet, so reduce ssthresh. - s.reduceSlowStartThreshold() - - // Reduce the congestion window to 1, i.e., enter slow-start. Per - // RFC 5681, page 7, we must use 1 regardless of the value of the - // initial congestion window. - s.sndCwnd = 1 + s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and // start sending again. Set the number of outstanding packets to 0 so @@ -395,8 +405,6 @@ func (s *sender) sendData() { } func (s *sender) enterFastRecovery() { - // Save state to reflect we're now in fast recovery. - s.reduceSlowStartThreshold() // Save state to reflect we're now in fast recovery. 
// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. // We inflat the cwnd by 3 to account for the 3 packets which triggered @@ -474,9 +482,9 @@ func (s *sender) checkDuplicateAck(seg *segment) bool { return false } - // Enter fast recovery when we reach 3 dups. s.dupAckCount++ - if s.dupAckCount != 3 { + // Do not enter fast recovery until we reach nDupAckThreshold. + if s.dupAckCount < nDupAckThreshold { return false } @@ -489,6 +497,8 @@ func (s *sender) checkDuplicateAck(seg *segment) bool { s.dupAckCount = 0 return false } + + s.cc.HandleNDupAcks() s.enterFastRecovery() s.dupAckCount = 0 return true @@ -497,29 +507,6 @@ func (s *sender) checkDuplicateAck(seg *segment) bool { // updateCwnd updates the congestion window based on the number of packets that // were acknowledged. func (s *sender) updateCwnd(packetsAcked int) { - if s.sndCwnd < s.sndSsthresh { - // Don't let the congestion window cross into the congestion - // avoidance range. - newcwnd := s.sndCwnd + packetsAcked - if newcwnd >= s.sndSsthresh { - newcwnd = s.sndSsthresh - s.sndCAAckCount = 0 - } - - packetsAcked -= newcwnd - s.sndCwnd - s.sndCwnd = newcwnd - if packetsAcked == 0 { - // We've consumed all ack'd packets. - return - } - } - - // Consume the packets in congestion avoidance mode. - s.sndCAAckCount += packetsAcked - if s.sndCAAckCount >= s.sndCwnd { - s.sndCwnd += s.sndCAAckCount / s.sndCwnd - s.sndCAAckCount = s.sndCAAckCount % s.sndCwnd - } } // handleRcvdSegment is called when a segment is received; it is responsible for @@ -580,7 +567,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. if !s.fr.active { - s.updateCwnd(originalOutstanding - s.outstanding) + s.cc.Update(originalOutstanding - s.outstanding) } // It is possible for s.outstanding to drop below zero if we get -- cgit v1.2.3