diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rack.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rack_state.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 134 |
7 files changed, 166 insertions, 69 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 0dc710276..a00ef97c6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1357,6 +1357,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // e.mu is expected to be hold upon entering this section. if e.snd != nil { e.snd.resendTimer.cleanup() + e.snd.rc.probeTimer.cleanup() } if closeTimer != nil { @@ -1437,6 +1438,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, }, { + w: &e.snd.rc.probeWaker, + f: e.snd.probeTimerExpired, + }, + { w: &e.newSegmentWaker, f: func() *tcpip.Error { return e.handleSegments(false /* fastPath */) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8f3981075..281f4cd58 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -508,6 +508,9 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags + // tcpRecovery is the loss detection algorithm used by TCP. + tcpRecovery tcpip.TCPRecovery + // sackPermitted is set to true if the peer sends the TCPSACKPermitted // option in the SYN/SYN-ACK. 
sackPermitted bool @@ -918,6 +921,8 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.maxSynRetries = uint8(synRetries) } + s.TransportProtocolOption(ProtocolNumber, &e.tcpRecovery) + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -3072,7 +3077,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { } } - rc := e.snd.rc + rc := &e.snd.rc s.Sender.RACKState = stack.TCPRACKState{ XmitTime: rc.xmitTime, EndSequence: rc.endSequence, diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 672159eed..c9e194f82 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -405,7 +405,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E case *tcpip.TCPRecovery: p.mu.RLock() - *v = tcpip.TCPRecovery(p.recovery) + *v = p.recovery p.mu.RUnlock() return nil @@ -543,7 +543,8 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: tcpip.TCPRACKLossDetection, + // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. + recovery: 0, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index e0a50a919..b71e6b992 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -17,9 +17,18 @@ package tcp import ( "time" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) +// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as +// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. +// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is +// 1, PTO is inflated by WCDelAckT time to compensate for a potential long +// delayed ACK timer at the receiver. 
+const wcDelayedACKTimeout = 200 * time.Millisecond + // RACK is a loss detection algorithm used in TCP to detect packet loss and // reordering using transmission timestamp of the packets instead of packet or // sequence counts. To use RACK, SACK should be enabled on the connection. @@ -54,6 +63,15 @@ type rackControl struct { // xmitTime is the latest transmission timestamp of rackControl.seg. xmitTime time.Time `state:".(unixTime)"` + + // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. + probeTimer timer `state:"nosave"` + probeWaker sleep.Waker `state:"nosave"` +} + +// init initializes RACK specific fields. +func (rc *rackControl) init() { + rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. @@ -127,3 +145,61 @@ func (rc *rackControl) detectReorder(seg *segment) { func (rc *rackControl) setDSACKSeen() { rc.dsackSeen = true } + +// shouldSchedulePTO dictates whether we should schedule a PTO or not. +// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1. +func (s *sender) shouldSchedulePTO() bool { + // Schedule PTO only if RACK loss detection is enabled. + return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 && + // The connection supports SACK. + s.ep.sackPermitted && + // The connection is not in loss recovery. + (s.state != RTORecovery && s.state != SACKRecovery) && + // The connection has no SACKed sequences in the SACK scoreboard. + s.ep.scoreboard.Sacked() == 0 +} + +// schedulePTO schedules the probe timeout as defined in +// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1. 
+func (s *sender) schedulePTO() { + pto := time.Second + s.rtt.Lock() + if s.rtt.srttInited && s.rtt.srtt > 0 { + pto = s.rtt.srtt * 2 + if s.outstanding == 1 { + pto += wcDelayedACKTimeout + } + } + s.rtt.Unlock() + + now := time.Now() + if s.resendTimer.enabled() { + if now.Add(pto).After(s.resendTimer.target) { + pto = s.resendTimer.target.Sub(now) + } + s.resendTimer.disable() + } + + s.rc.probeTimer.enable(pto) +} + +// probeTimerExpired is the same as TLP_send_probe() as defined in +// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. +func (s *sender) probeTimerExpired() *tcpip.Error { + if !s.rc.probeTimer.checkExpiration() { + return nil + } + // TODO(gvisor.dev/issue/5084): Implement this pseudo algorithm. + // If an unsent segment exists AND + // the receive window allows new data to be sent: + // Transmit the lowest-sequence unsent segment of up to SMSS + // Increment FlightSize by the size of the newly-sent segment + // Else if TLPRxtOut is not set: + // Retransmit the highest-sequence segment sent so far + // TLPRxtOut = true + // TLPHighRxt = SND.NXT + // The cwnd remains unchanged + // If FlightSize != 0: + // Arm RTO timer only. + return nil +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go index c9dc7e773..76cad0831 100644 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -27,3 +27,8 @@ func (rc *rackControl) saveXmitTime() unixTime { func (rc *rackControl) loadXmitTime(unix unixTime) { rc.xmitTime = time.Unix(unix.second, unix.nano) } + +// afterLoad is invoked by stateify. 
+func (rc *rackControl) afterLoad() { + rc.probeTimer.init(&rc.probeWaker) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index cc991aba6..c0e9d98e3 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -286,6 +286,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint gso: ep.gso != nil, } + s.rc.init() + if s.gso { s.ep.gso.MSS = uint16(maxPayloadSize) } @@ -1455,6 +1457,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Reset firstRetransmittedSegXmitTime to the zero value. s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() + s.rc.probeTimer.disable() } } diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index aab92b94f..272ad67bd 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -185,6 +185,7 @@ func (e *endpoint) StateFields() []string { "recentTSTime", "tsOffset", "shutdownFlags", + "tcpRecovery", "sackPermitted", "sack", "delay", @@ -231,7 +232,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { var recentTSTimeValue unixTime = e.saveRecentTSTime() stateSinkObject.SaveValue(26, recentTSTimeValue) var acceptedChanValue []*endpoint = e.saveAcceptedChan() - stateSinkObject.SaveValue(49, acceptedChanValue) + stateSinkObject.SaveValue(50, acceptedChanValue) stateSinkObject.Save(0, &e.EndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) @@ -257,37 +258,38 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(25, &e.recentTS) stateSinkObject.Save(27, &e.tsOffset) stateSinkObject.Save(28, &e.shutdownFlags) - stateSinkObject.Save(29, &e.sackPermitted) - stateSinkObject.Save(30, &e.sack) - stateSinkObject.Save(31, &e.delay) - stateSinkObject.Save(32, &e.scoreboard) - stateSinkObject.Save(33, &e.segmentQueue) - stateSinkObject.Save(34, 
&e.synRcvdCount) - stateSinkObject.Save(35, &e.userMSS) - stateSinkObject.Save(36, &e.maxSynRetries) - stateSinkObject.Save(37, &e.windowClamp) - stateSinkObject.Save(38, &e.sndBufSize) - stateSinkObject.Save(39, &e.sndBufUsed) - stateSinkObject.Save(40, &e.sndClosed) - stateSinkObject.Save(41, &e.sndBufInQueue) - stateSinkObject.Save(42, &e.sndQueue) - stateSinkObject.Save(43, &e.cc) - stateSinkObject.Save(44, &e.packetTooBigCount) - stateSinkObject.Save(45, &e.sndMTU) - stateSinkObject.Save(46, &e.keepalive) - stateSinkObject.Save(47, &e.userTimeout) - stateSinkObject.Save(48, &e.deferAccept) - stateSinkObject.Save(50, &e.rcv) - stateSinkObject.Save(51, &e.snd) - stateSinkObject.Save(52, &e.connectingAddress) - stateSinkObject.Save(53, &e.amss) - stateSinkObject.Save(54, &e.sendTOS) - stateSinkObject.Save(55, &e.gso) - stateSinkObject.Save(56, &e.tcpLingerTimeout) - stateSinkObject.Save(57, &e.closed) - stateSinkObject.Save(58, &e.txHash) - stateSinkObject.Save(59, &e.owner) - stateSinkObject.Save(60, &e.ops) + stateSinkObject.Save(29, &e.tcpRecovery) + stateSinkObject.Save(30, &e.sackPermitted) + stateSinkObject.Save(31, &e.sack) + stateSinkObject.Save(32, &e.delay) + stateSinkObject.Save(33, &e.scoreboard) + stateSinkObject.Save(34, &e.segmentQueue) + stateSinkObject.Save(35, &e.synRcvdCount) + stateSinkObject.Save(36, &e.userMSS) + stateSinkObject.Save(37, &e.maxSynRetries) + stateSinkObject.Save(38, &e.windowClamp) + stateSinkObject.Save(39, &e.sndBufSize) + stateSinkObject.Save(40, &e.sndBufUsed) + stateSinkObject.Save(41, &e.sndClosed) + stateSinkObject.Save(42, &e.sndBufInQueue) + stateSinkObject.Save(43, &e.sndQueue) + stateSinkObject.Save(44, &e.cc) + stateSinkObject.Save(45, &e.packetTooBigCount) + stateSinkObject.Save(46, &e.sndMTU) + stateSinkObject.Save(47, &e.keepalive) + stateSinkObject.Save(48, &e.userTimeout) + stateSinkObject.Save(49, &e.deferAccept) + stateSinkObject.Save(51, &e.rcv) + stateSinkObject.Save(52, &e.snd) + stateSinkObject.Save(53, 
&e.connectingAddress) + stateSinkObject.Save(54, &e.amss) + stateSinkObject.Save(55, &e.sendTOS) + stateSinkObject.Save(56, &e.gso) + stateSinkObject.Save(57, &e.tcpLingerTimeout) + stateSinkObject.Save(58, &e.closed) + stateSinkObject.Save(59, &e.txHash) + stateSinkObject.Save(60, &e.owner) + stateSinkObject.Save(61, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { @@ -316,42 +318,43 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(25, &e.recentTS) stateSourceObject.Load(27, &e.tsOffset) stateSourceObject.Load(28, &e.shutdownFlags) - stateSourceObject.Load(29, &e.sackPermitted) - stateSourceObject.Load(30, &e.sack) - stateSourceObject.Load(31, &e.delay) - stateSourceObject.Load(32, &e.scoreboard) - stateSourceObject.LoadWait(33, &e.segmentQueue) - stateSourceObject.Load(34, &e.synRcvdCount) - stateSourceObject.Load(35, &e.userMSS) - stateSourceObject.Load(36, &e.maxSynRetries) - stateSourceObject.Load(37, &e.windowClamp) - stateSourceObject.Load(38, &e.sndBufSize) - stateSourceObject.Load(39, &e.sndBufUsed) - stateSourceObject.Load(40, &e.sndClosed) - stateSourceObject.Load(41, &e.sndBufInQueue) - stateSourceObject.LoadWait(42, &e.sndQueue) - stateSourceObject.Load(43, &e.cc) - stateSourceObject.Load(44, &e.packetTooBigCount) - stateSourceObject.Load(45, &e.sndMTU) - stateSourceObject.Load(46, &e.keepalive) - stateSourceObject.Load(47, &e.userTimeout) - stateSourceObject.Load(48, &e.deferAccept) - stateSourceObject.LoadWait(50, &e.rcv) - stateSourceObject.LoadWait(51, &e.snd) - stateSourceObject.Load(52, &e.connectingAddress) - stateSourceObject.Load(53, &e.amss) - stateSourceObject.Load(54, &e.sendTOS) - stateSourceObject.Load(55, &e.gso) - stateSourceObject.Load(56, &e.tcpLingerTimeout) - stateSourceObject.Load(57, &e.closed) - stateSourceObject.Load(58, &e.txHash) - stateSourceObject.Load(59, &e.owner) - stateSourceObject.Load(60, &e.ops) + stateSourceObject.Load(29, &e.tcpRecovery) + 
stateSourceObject.Load(30, &e.sackPermitted) + stateSourceObject.Load(31, &e.sack) + stateSourceObject.Load(32, &e.delay) + stateSourceObject.Load(33, &e.scoreboard) + stateSourceObject.LoadWait(34, &e.segmentQueue) + stateSourceObject.Load(35, &e.synRcvdCount) + stateSourceObject.Load(36, &e.userMSS) + stateSourceObject.Load(37, &e.maxSynRetries) + stateSourceObject.Load(38, &e.windowClamp) + stateSourceObject.Load(39, &e.sndBufSize) + stateSourceObject.Load(40, &e.sndBufUsed) + stateSourceObject.Load(41, &e.sndClosed) + stateSourceObject.Load(42, &e.sndBufInQueue) + stateSourceObject.LoadWait(43, &e.sndQueue) + stateSourceObject.Load(44, &e.cc) + stateSourceObject.Load(45, &e.packetTooBigCount) + stateSourceObject.Load(46, &e.sndMTU) + stateSourceObject.Load(47, &e.keepalive) + stateSourceObject.Load(48, &e.userTimeout) + stateSourceObject.Load(49, &e.deferAccept) + stateSourceObject.LoadWait(51, &e.rcv) + stateSourceObject.LoadWait(52, &e.snd) + stateSourceObject.Load(53, &e.connectingAddress) + stateSourceObject.Load(54, &e.amss) + stateSourceObject.Load(55, &e.sendTOS) + stateSourceObject.Load(56, &e.gso) + stateSourceObject.Load(57, &e.tcpLingerTimeout) + stateSourceObject.Load(58, &e.closed) + stateSourceObject.Load(59, &e.txHash) + stateSourceObject.Load(60, &e.owner) + stateSourceObject.Load(61, &e.ops) stateSourceObject.LoadValue(4, new(string), func(y interface{}) { e.loadHardError(y.(string)) }) stateSourceObject.LoadValue(5, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) - stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) + stateSourceObject.LoadValue(50, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) 
stateSourceObject.AfterLoad(e.afterLoad) } @@ -417,8 +420,6 @@ func (rc *rackControl) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(5, &rc.reorderSeen) } -func (rc *rackControl) afterLoad() {} - func (rc *rackControl) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &rc.dsackSeen) stateSourceObject.Load(1, &rc.endSequence) @@ -427,6 +428,7 @@ func (rc *rackControl) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(4, &rc.rtt) stateSourceObject.Load(5, &rc.reorderSeen) stateSourceObject.LoadValue(6, new(unixTime), func(y interface{}) { rc.loadXmitTime(y.(unixTime)) }) + stateSourceObject.AfterLoad(rc.afterLoad) } func (r *receiver) StateTypeName() string { |