diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 48 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 178 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 78 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/sack.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 14 | ||||
-rwxr-xr-x | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 36 |
10 files changed, 387 insertions, 51 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 9b1ad6a28..52fd1bfa3 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -213,6 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) + n.amss = mssForRoute(&n.route) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) @@ -232,7 +233,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // The receiver at least temporarily has a zero receive window scale, // but the caller may change it (before starting the protocol loop). n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) - n.rcv = newReceiver(n, irs, l.rcvWnd, 0) + n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize())) + // Bootstrap the auto tuning algorithm. Starting at zero will result in + // a large step function on the first window adjustment causing the + // window to grow to a really large value. + n.rcvAutoParams.prevCopied = n.initialReceiveWindow() return n, nil } @@ -249,7 +254,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head } // Perform the 3-way handshake. - h := newHandshake(ep, l.rcvWnd) + h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) h.resetToSynRcvd(cookie, irs, opts) if err := h.execute(); err != nil { @@ -359,16 +364,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) - // Send SYN with window scaling because we currently + + // Send SYN without window scaling because we currently // dont't encode this information in the cookie. // // Enable Timestamp option if the original syn did have // the timestamp option specified. + mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, + MSS: uint16(mss), } sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 84e3dd26c..00d2ae524 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -78,6 +78,9 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 + // amss is the maximum segment size advertised by us to the peer. + amss uint16 + // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -87,11 +90,24 @@ type handshake struct { } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { + rcvWndScale := ep.rcvWndScaleForHandshake() + + // Round-down the rcvWnd to a multiple of wndScale. This ensures that the + // window offered in SYN won't be reduced due to the loss of precision if + // window scaling is enabled after the handshake. + rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) + + // Ensure we can always accept at least 1 byte if the scale specified + // was too high for the provided rcvWnd. + if rcvWnd == 0 { + rcvWnd = 1 + } + h := handshake{ ep: ep, active: true, rcvWnd: rcvWnd, - rcvWndScale: FindWndScale(rcvWnd), + rcvWndScale: int(rcvWndScale), } h.resetState() return h @@ -224,7 +240,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.ep.state = StateSynRecv h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ - WS: h.rcvWndScale, + WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTS, @@ -233,6 +249,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // permits SACK. This is not explicitly defined in the RFC but // this is the behaviour implemented by Linux. SACKPermitted: rcvSynOpts.SACKPermitted, + MSS: h.ep.amss, } sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) @@ -277,6 +294,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTS, SACKPermitted: h.ep.sackPermitted, + MSS: h.ep.amss, } sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -419,12 +437,15 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. + h.ep.amss = mssForRoute(&h.ep.route) + synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTS, SACKPermitted: bool(sackEnabled), + MSS: h.ep.amss, } // Execute is also called in a listen context so we want to make sure we @@ -433,6 +454,11 @@ func (h *handshake) execute() *tcpip.Error { if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + if h.sndWndScale < 0 { + // Disable window scaling if the peer did not send us + // the window scaling option. + synOpts.WS = -1 + } } sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) for h.state != handshakeCompleted { @@ -554,13 +580,6 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { } func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { - // The MSS in opts is automatically calculated as this function is - // called from many places and we don't want every call point being - // embedded with the MSS calculation. - if opts.MSS == 0 { - opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) - } - options := makeSynOptions(opts) err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) putOptions(options) @@ -861,7 +880,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // This is an active connection, so we must initiate the 3-way // handshake, and then inform potential waiters about its // completion. - h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + initialRcvWnd := e.initialReceiveWindow() + h := newHandshake(e, seqnum.Size(initialRcvWnd)) e.mu.Lock() h.ep.state = StateSynSent e.mu.Unlock() @@ -886,8 +906,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // boot strap the auto tuning algorithm. Starting at zero will + // result in a large step function on the first proper causing + // the window to just go to a really large value after the first + // RTT itself. + e.rcvAutoParams.prevCopied = initialRcvWnd e.rcvListMu.Unlock() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9614b2958..1aa1f12b4 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -132,6 +132,42 @@ type SACKInfo struct { NumBlocks int } +// rcvBufAutoTuneParams are used to hold state variables to compute +// the auto tuned recv buffer size. +// +// +stateify savable +type rcvBufAutoTuneParams struct { + // measureTime is the time at which the current measurement + // was started. + measureTime time.Time `state:".(unixTime)"` + + // copied is the number of bytes copied out of the receive + // buffers since this measure began. + copied int + + // prevCopied is the number of bytes copied out of the receive + // buffers in the previous RTT period. + prevCopied int + + // rtt is the non-smoothed minimum RTT as measured by observing the time + // between when a byte is first acknowledged and the receipt of data + // that is at least one window beyond the sequence number that was + // acknowledged. + rtt time.Duration + + // rttMeasureSeqNumber is the highest acceptable sequence number at the + // time this RTT measurement period began. + rttMeasureSeqNumber seqnum.Value + + // rttMeasureTime is the absolute time at which the current rtt + // measurement period began. + rttMeasureTime time.Time `state:".(unixTime)"` + + // disabled is true if an explicit receive buffer is set for the + // endpoint. + disabled bool +} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -165,11 +201,12 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int - rcvBufUsed int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + rcvBufSize int + rcvBufUsed int + rcvAutoParams rcvBufAutoTuneParams // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` @@ -339,6 +376,9 @@ type endpoint struct { bindAddress tcpip.Address connectingAddress tcpip.Address + // amss is the advertised MSS to the peer by this endpoint. + amss uint16 + gso *stack.GSO } @@ -373,8 +413,8 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite netProto: netProto, waiterQueue: waiterQueue, state: StateInitial, - rcvBufSize: DefaultBufferSize, - sndBufSize: DefaultBufferSize, + rcvBufSize: DefaultReceiveBufferSize, + sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), reuseAddr: true, keepalive: keepalive{ @@ -400,6 +440,11 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.cc = cs } + var mrb tcpip.ModerateReceiveBufferOption + if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + e.rcvAutoParams.disabled = !bool(mrb) + } + if p := stack.GetTCPProbe(); p != nil { e.probe = p } @@ -408,6 +453,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.workMu.Init() e.workMu.Lock() e.tsOffset = timeStampOffset() + return e } @@ -551,6 +597,83 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// initialReceiveWindow returns the initial receive window to advertise in the +// SYN/SYN-ACK. +func (e *endpoint) initialReceiveWindow() int { + rcvWnd := e.receiveBufferAvailable() + if rcvWnd > math.MaxUint16 { + rcvWnd = math.MaxUint16 + } + routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + if rcvWnd > routeWnd { + rcvWnd = routeWnd + } + return rcvWnd +} + +// ModerateRecvBuf adjusts the receive buffer and the advertised window +// based on the number of bytes copied to user space. +func (e *endpoint) ModerateRecvBuf(copied int) { + e.rcvListMu.Lock() + if e.rcvAutoParams.disabled { + e.rcvListMu.Unlock() + return + } + now := time.Now() + if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { + e.rcvAutoParams.copied += copied + e.rcvListMu.Unlock() + return + } + prevRTTCopied := e.rcvAutoParams.copied + copied + prevCopied := e.rcvAutoParams.prevCopied + rcvWnd := 0 + if prevRTTCopied > prevCopied { + // The minimal receive window based on what was copied by the app + // in the immediate preceding RTT and some extra buffer for 16 + // segments to account for variations. + // We multiply by 2 to account for packet losses. + rcvWnd = prevRTTCopied*2 + 16*int(e.amss) + + // Scale for slow start based on bytes copied in this RTT vs previous. + grow := (rcvWnd * (prevRTTCopied - prevCopied)) / prevCopied + + // Multiply growth factor by 2 again to account for sender being + // in slow-start where the sender grows it's congestion window + // by 100% per RTT. + rcvWnd += grow * 2 + + // Make sure auto tuned buffer size can always receive upto 2x + // the initial window of 10 segments. + if minRcvWnd := int(e.amss) * InitialCwnd * 2; rcvWnd < minRcvWnd { + rcvWnd = minRcvWnd + } + + // Cap the auto tuned buffer size by the maximum permissible + // receive buffer size. + if max := e.maxReceiveBufferSize(); rcvWnd > max { + rcvWnd = max + } + + // We do not adjust downwards as that can cause the receiver to + // reject valid data that might already be in flight as the + // acceptable window will shrink. + if rcvWnd > e.rcvBufSize { + e.rcvBufSize = rcvWnd + e.notifyProtocolGoroutine(notifyReceiveWindowChanged) + } + + // We only update prevCopied when we grow the buffer because in cases + // where prevCopied > prevRTTCopied the existing buffer is already big + // enough to handle the current rate and we don't need to do any + // adjustments. + e.rcvAutoParams.prevCopied = prevRTTCopied + } + e.rcvAutoParams.measureTime = now + e.rcvAutoParams.copied = 0 + e.rcvListMu.Unlock() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -821,6 +944,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { wasZero := e.zeroReceiveWindow(scale) e.rcvBufSize = size + e.rcvAutoParams.disabled = true if wasZero && !e.zeroReceiveWindow(scale) { mask |= notifyNonZeroReceiveWindow } @@ -1657,6 +1781,33 @@ func (e *endpoint) receiveBufferSize() int { return size } +func (e *endpoint) maxReceiveBufferSize() int { + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + // As a fallback return the hardcoded max buffer size. + return MaxBufferSize + } + return rs.Max +} + +// rcvWndScaleForHandshake computes the receive window scale to offer to the +// peer when window scaling is enabled (true by default). If auto-tuning is +// disabled then the window scaling factor is based on the size of the +// receiveBuffer otherwise we use the max permissible receive buffer size to +// compute the scale. +func (e *endpoint) rcvWndScaleForHandshake() int { + bufSizeForScale := e.receiveBufferSize() + + e.rcvListMu.Lock() + autoTuningDisabled := e.rcvAutoParams.disabled + e.rcvListMu.Unlock() + if autoTuningDisabled { + return FindWndScale(seqnum.Size(bufSizeForScale)) + } + + return FindWndScale(seqnum.Size(e.maxReceiveBufferSize())) +} + // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { @@ -1749,6 +1900,13 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.RcvBufSize = e.rcvBufSize s.RcvBufUsed = e.rcvBufUsed s.RcvClosed = e.rcvClosed + s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime + s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied + s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied + s.RcvAutoParams.RTT = e.rcvAutoParams.rtt + s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber + s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime + s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled e.rcvListMu.Unlock() // Endpoint TCP Option state. @@ -1802,13 +1960,13 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RTTMeasureTime: e.snd.rttMeasureTime, Closed: e.snd.closed, RTO: e.snd.rto, - SRTTInited: e.snd.srttInited, MaxPayloadSize: e.snd.maxPayloadSize, SndWndScale: e.snd.sndWndScale, MaxSentAck: e.snd.maxSentAck, } e.snd.rtt.Lock() s.Sender.SRTT = e.snd.rtt.srtt + s.Sender.SRTTInited = e.snd.rtt.srttInited e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { @@ -1856,3 +2014,7 @@ func (e *endpoint) State() uint32 { defer e.mu.Unlock() return uint32(e.state) } + +func mssForRoute(r *stack.Route) uint16 { + return uint16(r.MTU() - header.TCPMinimumSize) +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 58be61927..ec61a3886 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -360,3 +360,23 @@ func loadError(s string) *tcpip.Error { return e } + +// saveMeasureTime is invoked by stateify. +func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { + return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} +} + +// loadMeasureTime is invoked by stateify. +func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { + r.measureTime = time.Unix(unix.second, unix.nano) +} + +// saveRttMeasureTime is invoked by stateify. +func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime { + return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()} +} + +// loadRttMeasureTime is invoked by stateify. +func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) { + r.rttMeasureTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2ce94aeb9..63666f0b3 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -47,7 +47,7 @@ type Forwarder struct { // If rcvWnd is set to zero, the default buffer size is used instead. func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder { if rcvWnd == 0 { - rcvWnd = DefaultBufferSize + rcvWnd = DefaultReceiveBufferSize } return &Forwarder{ maxInFlight: maxInFlight, diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 919e4ce24..ee04dcfcc 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -41,13 +41,18 @@ const ( ProtocolNumber = header.TCPProtocolNumber // MinBufferSize is the smallest size of a receive or send buffer. - minBufferSize = 4 << 10 // 4096 bytes. + MinBufferSize = 4 << 10 // 4096 bytes. - // DefaultBufferSize is the default size of the receive and send buffers. - DefaultBufferSize = 1 << 20 // 1MB + // DefaultSendBufferSize is the default size of the send buffer for + // an endpoint. + DefaultSendBufferSize = 1 << 20 // 1MB - // MaxBufferSize is the largest size a receive and send buffer can grow to. - maxBufferSize = 4 << 20 // 4MB + // DefaultReceiveBufferSize is the default size of the receive buffer + // for an endpoint. + DefaultReceiveBufferSize = 1 << 20 // 1MB + + // MaxBufferSize is the largest size a receive/send buffer can grow to. + MaxBufferSize = 4 << 20 // 4MB // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. @@ -86,6 +91,7 @@ type protocol struct { recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string + moderateReceiveBuffer bool } // Number returns the tcp protocol number. @@ -192,6 +198,13 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // linux returns ENOENT when an invalid congestion control // is specified. return tcpip.ErrNoSuchFile + + case tcpip.ModerateReceiveBufferOption: + p.mu.Lock() + p.moderateReceiveBuffer = bool(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -217,16 +230,25 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { *v = p.recvBufferSize p.mu.Unlock() return nil + case *tcpip.CongestionControlOption: p.mu.Lock() *v = tcpip.CongestionControlOption(p.congestionControl) p.mu.Unlock() return nil + case *tcpip.AvailableCongestionControlOption: p.mu.Lock() *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.Unlock() return nil + + case *tcpip.ModerateReceiveBufferOption: + p.mu.Lock() + *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -235,8 +257,8 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { func init() { stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { return &protocol{ - sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 8d9de9bf9..e90f9a7d9 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -16,6 +16,7 @@ package tcp import ( "container/heap" + "time" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -38,6 +39,9 @@ type receiver struct { // shrinking it. rcvAcc seqnum.Value + // rcvWnd is the non-scaled receive window last advertised to the peer. + rcvWnd seqnum.Size + rcvWndScale uint8 closed bool @@ -47,13 +51,14 @@ type receiver struct { pendingBufSize seqnum.Size } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: rcvWnd, + pendingBufSize: pendingBufSize, } } @@ -72,14 +77,16 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the current buffer size. - n := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(n)) + // Calculate the window size based on the available buffer space. + receiveBufferAvailable := r.ep.receiveBufferAvailable() + acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) if r.rcvAcc.LessThan(acc) { r.rcvAcc = acc } - - return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale + // Stash away the non-scaled receive window as we use it for measuring + // receiver's estimated RTT. + r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } // nonZeroWindow is called when the receive window grows from zero to nonzero; @@ -130,6 +137,21 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Update the segment that we're expecting to consume. r.rcvNxt = segSeq.Add(segLen) + // In cases of a misbehaving sender which could send more than the + // advertised window, we could end up in a situation where we get a + // segment that exceeds the window advertised. Instead of partially + // accepting the segment and discarding bytes beyond the advertised + // window, we accept the whole segment and make sure r.rcvAcc is moved + // forward to match r.rcvNxt to indicate that the window is now closed. + // + // In absence of this check the r.acceptable() check fails and accepts + // segments that should be dropped because rcvWnd is calculated as + // the size of the interval (rcvNxt, rcvAcc] which becomes extremely + // large if rcvAcc is ever less than rcvNxt. + if r.rcvAcc.LessThan(r.rcvNxt) { + r.rcvAcc = r.rcvNxt + } + // Trim SACK Blocks to remove any SACK information that covers // sequence numbers that have been consumed. TrimSACKBlockList(&r.ep.sack, r.rcvNxt) @@ -198,6 +220,39 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum return true } +// updateRTT updates the receiver RTT measurement based on the sequence number +// of the received segment. +func (r *receiver) updateRTT() { + // From: https://public.lanl.gov/radiant/pubs/drs/sc2001-poster.pdf + // + // A system that is only transmitting acknowledgements can still + // estimate the round-trip time by observing the time between when a byte + // is first acknowledged and the receipt of data that is at least one + // window beyond the sequence number that was acknowledged. + r.ep.rcvListMu.Lock() + if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + // New measurement. + r.ep.rcvAutoParams.rttMeasureTime = time.Now() + r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) + r.ep.rcvListMu.Unlock() + return + } + if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { + r.ep.rcvListMu.Unlock() + return + } + rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + // We only store the minimum observed RTT here as this is only used in + // absence of a SRTT available from either timestamps or a sender + // measurement of RTT. + if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { + r.ep.rcvAutoParams.rtt = rtt + } + r.ep.rcvAutoParams.rttMeasureTime = time.Now() + r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) + r.ep.rcvListMu.Unlock() +} + // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) { @@ -226,10 +281,9 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed += s.logicalLen() s.incRef() heap.Push(&r.pendingRcvdSegments, s) + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) } - UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) - // Immediately send an ack so that the peer knows it may // have to retransmit. r.ep.snd.sendAck() @@ -237,6 +291,12 @@ func (r *receiver) handleRcvdSegment(s *segment) { return } + // Since we consumed a segment update the receiver's RTT estimate + // if required. + if segLen > 0 { + r.updateRTT() + } + // By consuming the current segment, we may have filled a gap in the // sequence number domain that allows pending segments to be consumed // now. So try to do it. diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go index 52c5d9867..7be86d68e 100644 --- a/pkg/tcpip/transport/tcp/sack.go +++ b/pkg/tcpip/transport/tcp/sack.go @@ -31,6 +31,13 @@ const ( // segment identified by segStart->segEnd. func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) { newSB := header.SACKBlock{Start: segStart, End: segEnd} + + // Ignore any invalid SACK blocks or blocks that are before rcvNxt as + // those bytes have already been acked. + if newSB.End.LessThanEq(newSB.Start) || newSB.End.LessThan(rcvNxt) { + return + } + if sack.NumBlocks == 0 { sack.Blocks[0] = newSB sack.NumBlocks = 1 @@ -39,9 +46,8 @@ func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value var n = 0 for i := 0; i < sack.NumBlocks; i++ { start, end := sack.Blocks[i].Start, sack.Blocks[i].End - if end.LessThanEq(start) || start.LessThanEq(rcvNxt) { - // Discard any invalid blocks where end is before start - // and discard any sack blocks that are before rcvNxt as + if end.LessThanEq(rcvNxt) { + // Discard any sack blocks that are before rcvNxt as // those have already been acked. continue } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 297861462..0fee7ab72 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -121,9 +121,8 @@ type sender struct { // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", // "round-trip time variation" and "retransmit timeout", as defined in // section 2 of RFC 6298. - rtt rtt - rto time.Duration - srttInited bool + rtt rtt + rto time.Duration // maxPayloadSize is the maximum size of the payload of a given segment. // It is initialized on demand. @@ -150,8 +149,9 @@ type sender struct { type rtt struct { sync.Mutex `state:"nosave"` - srtt time.Duration - rttvar time.Duration + srtt time.Duration + rttvar time.Duration + srttInited bool } // fastRecovery holds information related to fast recovery from a packet loss. @@ -323,10 +323,10 @@ func (s *sender) sendAck() { // available. This is done in accordance with section 2 of RFC 6298. func (s *sender) updateRTO(rtt time.Duration) { s.rtt.Lock() - if !s.srttInited { + if !s.rtt.srttInited { s.rtt.rttvar = rtt / 2 s.rtt.srtt = rtt - s.srttInited = true + s.rtt.srttInited = true } else { diff := s.rtt.srtt - rtt if diff < 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 1674c6a08..601ff758d 100755 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -50,6 +50,31 @@ func (x *SACKInfo) load(m state.Map) { m.Load("NumBlocks", &x.NumBlocks) } +func (x *rcvBufAutoTuneParams) beforeSave() {} +func (x *rcvBufAutoTuneParams) save(m state.Map) { + x.beforeSave() + var measureTime unixTime = x.saveMeasureTime() + m.SaveValue("measureTime", measureTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + m.Save("copied", &x.copied) + m.Save("prevCopied", &x.prevCopied) + m.Save("rtt", &x.rtt) + m.Save("rttMeasureSeqNumber", &x.rttMeasureSeqNumber) + m.Save("disabled", &x.disabled) +} + +func (x *rcvBufAutoTuneParams) afterLoad() {} +func (x *rcvBufAutoTuneParams) load(m state.Map) { + m.Load("copied", &x.copied) + m.Load("prevCopied", &x.prevCopied) + m.Load("rtt", &x.rtt) + m.Load("rttMeasureSeqNumber", &x.rttMeasureSeqNumber) + m.Load("disabled", &x.disabled) + m.LoadValue("measureTime", new(unixTime), func(y interface{}) { x.loadMeasureTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) +} + func (x *endpoint) save(m state.Map) { x.beforeSave() var lastError string = x.saveLastError() @@ -66,6 +91,7 @@ func (x *endpoint) save(m state.Map) { m.Save("rcvClosed", &x.rcvClosed) m.Save("rcvBufSize", &x.rcvBufSize) m.Save("rcvBufUsed", &x.rcvBufUsed) + m.Save("rcvAutoParams", &x.rcvAutoParams) m.Save("id", &x.id) m.Save("isRegistered", &x.isRegistered) m.Save("v6only", &x.v6only) @@ -100,6 +126,7 @@ func (x *endpoint) save(m state.Map) { m.Save("snd", &x.snd) m.Save("bindAddress", &x.bindAddress) m.Save("connectingAddress", &x.connectingAddress) + m.Save("amss", &x.amss) m.Save("gso", &x.gso) } @@ -110,6 +137,7 @@ func (x *endpoint) load(m state.Map) { m.Load("rcvClosed", &x.rcvClosed) m.Load("rcvBufSize", &x.rcvBufSize) m.Load("rcvBufUsed", &x.rcvBufUsed) + m.Load("rcvAutoParams", &x.rcvAutoParams) m.Load("id", &x.id) m.Load("isRegistered", &x.isRegistered) m.Load("v6only", &x.v6only) @@ -144,6 +172,7 @@ func (x *endpoint) load(m state.Map) { m.LoadWait("snd", &x.snd) m.Load("bindAddress", &x.bindAddress) m.Load("connectingAddress", &x.connectingAddress) + m.Load("amss", &x.amss) m.Load("gso", &x.gso) m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) }) m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) @@ -177,6 +206,7 @@ func (x *receiver) save(m state.Map) { m.Save("ep", &x.ep) m.Save("rcvNxt", &x.rcvNxt) m.Save("rcvAcc", &x.rcvAcc) + m.Save("rcvWnd", &x.rcvWnd) m.Save("rcvWndScale", &x.rcvWndScale) m.Save("closed", &x.closed) m.Save("pendingRcvdSegments", &x.pendingRcvdSegments) @@ -189,6 +219,7 @@ func (x *receiver) load(m state.Map) { m.Load("ep", &x.ep) m.Load("rcvNxt", &x.rcvNxt) m.Load("rcvAcc", &x.rcvAcc) + m.Load("rcvWnd", &x.rcvWnd) m.Load("rcvWndScale", &x.rcvWndScale) m.Load("closed", &x.closed) m.Load("pendingRcvdSegments", &x.pendingRcvdSegments) @@ -302,7 +333,6 @@ func (x *sender) save(m state.Map) { m.Save("writeList", &x.writeList) m.Save("rtt", &x.rtt) m.Save("rto", &x.rto) - m.Save("srttInited", &x.srttInited) m.Save("maxPayloadSize", &x.maxPayloadSize) m.Save("gso", &x.gso) m.Save("sndWndScale", &x.sndWndScale) @@ -328,7 +358,6 @@ func (x *sender) load(m state.Map) { m.Load("writeList", &x.writeList) m.Load("rtt", &x.rtt) m.Load("rto", &x.rto) - m.Load("srttInited", &x.srttInited) m.Load("maxPayloadSize", &x.maxPayloadSize) m.Load("gso", &x.gso) m.Load("sndWndScale", &x.sndWndScale) @@ -344,12 +373,14 @@ func (x *rtt) save(m state.Map) { x.beforeSave() m.Save("srtt", &x.srtt) m.Save("rttvar", &x.rttvar) + m.Save("srttInited", &x.srttInited) } func (x *rtt) afterLoad() {} func (x *rtt) load(m state.Map) { m.Load("srtt", &x.srtt) m.Load("rttvar", &x.rttvar) + m.Load("srttInited", &x.srttInited) } func (x *fastRecovery) beforeSave() {} @@ -415,6 +446,7 @@ func (x *segmentEntry) load(m state.Map) { func init() { state.Register("tcp.cubicState", (*cubicState)(nil), state.Fns{Save: (*cubicState).save, Load: (*cubicState).load}) state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load}) + state.Register("tcp.rcvBufAutoTuneParams", (*rcvBufAutoTuneParams)(nil), state.Fns{Save: (*rcvBufAutoTuneParams).save, Load: (*rcvBufAutoTuneParams).load}) state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load}) state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load}) |