diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 322 |
1 files changed, 286 insertions, 36 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 23422ca5e..beb90afb5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,19 +17,20 @@ package tcp import ( "fmt" "math" + "strings" "sync" "sync/atomic" "time" - "gvisor.googlesource.com/gvisor/pkg/rand" - "gvisor.googlesource.com/gvisor/pkg/sleep" - "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" - "gvisor.googlesource.com/gvisor/pkg/tcpip/header" - "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" - "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" - "gvisor.googlesource.com/gvisor/pkg/tmutex" - "gvisor.googlesource.com/gvisor/pkg/waiter" + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tmutex" + "gvisor.dev/gvisor/pkg/waiter" ) // EndpointState represents the state of a TCP endpoint. @@ -116,6 +117,7 @@ const ( notifyDrain notifyReset notifyKeepaliveChanged + notifyMSSChanged ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -131,6 +133,42 @@ type SACKInfo struct { NumBlocks int } +// rcvBufAutoTuneParams are used to hold state variables to compute +// the auto tuned recv buffer size. +// +// +stateify savable +type rcvBufAutoTuneParams struct { + // measureTime is the time at which the current measurement + // was started. + measureTime time.Time `state:".(unixTime)"` + + // copied is the number of bytes copied out of the receive + // buffers since this measure began. + copied int + + // prevCopied is the number of bytes copied out of the receive + // buffers in the previous RTT period. 
+ prevCopied int + + // rtt is the non-smoothed minimum RTT as measured by observing the time + // between when a byte is first acknowledged and the receipt of data + // that is at least one window beyond the sequence number that was + // acknowledged. + rtt time.Duration + + // rttMeasureSeqNumber is the highest acceptable sequence number at the + // time this RTT measurement period began. + rttMeasureSeqNumber seqnum.Value + + // rttMeasureTime is the absolute time at which the current rtt + // measurement period began. + rttMeasureTime time.Time `state:".(unixTime)"` + + // disabled is true if an explicit receive buffer is set for the + // endpoint. + disabled bool +} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -164,18 +202,23 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int - rcvBufUsed int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + rcvBufSize int + rcvBufUsed int + rcvAutoParams rcvBufAutoTuneParams + // zeroWindow indicates that the window was closed due to receive buffer + // space being filled up. This is set by the worker goroutine before + // moving a segment to the rcvList. This setting is cleared by the + // endpoint when a Read() call reads enough data for the new window to + // be non-zero. + zeroWindow bool // The following fields are protected by the mutex. 
mu sync.RWMutex `state:"nosave"` id stack.TransportEndpointID - // state endpointState `state:".(endpointState)"` - // pState ProtocolState state EndpointState `state:".(EndpointState)"` isPortReserved bool `state:"manual"` @@ -269,6 +312,10 @@ type endpoint struct { // in SYN-RCVD state. synRcvdCount int + // userMSS if non-zero is the MSS value explicitly set by the user + // for this endpoint using the TCP_MAXSEG setsockopt. + userMSS int + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -286,7 +333,7 @@ type endpoint struct { // cc stores the name of the Congestion Control algorithm to use for // this endpoint. - cc CongestionControlOption + cc tcpip.CongestionControlOption // The following are used when a "packet too big" control packet is // received. They are protected by sndBufMu. They are used to @@ -338,6 +385,9 @@ type endpoint struct { bindAddress tcpip.Address connectingAddress tcpip.Address + // amss is the advertised MSS to the peer by this endpoint. 
+ amss uint16 + gso *stack.GSO } @@ -372,8 +422,8 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite netProto: netProto, waiterQueue: waiterQueue, state: StateInitial, - rcvBufSize: DefaultBufferSize, - sndBufSize: DefaultBufferSize, + rcvBufSize: DefaultReceiveBufferSize, + sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), reuseAddr: true, keepalive: keepalive{ @@ -394,11 +444,16 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.rcvBufSize = rs.Default } - var cs CongestionControlOption + var cs tcpip.CongestionControlOption if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } + var mrb tcpip.ModerateReceiveBufferOption + if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + e.rcvAutoParams.disabled = !bool(mrb) + } + if p := stack.GetTCPProbe(); p != nil { e.probe = p } @@ -407,6 +462,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.workMu.Init() e.workMu.Lock() e.tsOffset = timeStampOffset() + return e } @@ -550,6 +606,83 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// initialReceiveWindow returns the initial receive window to advertise in the +// SYN/SYN-ACK. +func (e *endpoint) initialReceiveWindow() int { + rcvWnd := e.receiveBufferAvailable() + if rcvWnd > math.MaxUint16 { + rcvWnd = math.MaxUint16 + } + routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + if rcvWnd > routeWnd { + rcvWnd = routeWnd + } + return rcvWnd +} + +// ModerateRecvBuf adjusts the receive buffer and the advertised window +// based on the number of bytes copied to user space. 
+func (e *endpoint) ModerateRecvBuf(copied int) { + e.rcvListMu.Lock() + if e.rcvAutoParams.disabled { + e.rcvListMu.Unlock() + return + } + now := time.Now() + if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { + e.rcvAutoParams.copied += copied + e.rcvListMu.Unlock() + return + } + prevRTTCopied := e.rcvAutoParams.copied + copied + prevCopied := e.rcvAutoParams.prevCopied + rcvWnd := 0 + if prevRTTCopied > prevCopied { + // The minimal receive window based on what was copied by the app + // in the immediately preceding RTT and some extra buffer for 16 + // segments to account for variations. + // We multiply by 2 to account for packet losses. + rcvWnd = prevRTTCopied*2 + 16*int(e.amss) + + // Scale for slow start based on bytes copied in this RTT vs previous. + grow := (rcvWnd * (prevRTTCopied - prevCopied)) / prevCopied + + // Multiply growth factor by 2 again to account for sender being + // in slow-start where the sender grows its congestion window + // by 100% per RTT. + rcvWnd += grow * 2 + + // Make sure auto tuned buffer size can always receive up to 2x + // the initial window of 10 segments. + if minRcvWnd := int(e.amss) * InitialCwnd * 2; rcvWnd < minRcvWnd { + rcvWnd = minRcvWnd + } + + // Cap the auto tuned buffer size by the maximum permissible + // receive buffer size. + if max := e.maxReceiveBufferSize(); rcvWnd > max { + rcvWnd = max + } + + // We do not adjust downwards as that can cause the receiver to + // reject valid data that might already be in flight as the + // acceptable window will shrink. + if rcvWnd > e.rcvBufSize { + e.rcvBufSize = rcvWnd + e.notifyProtocolGoroutine(notifyReceiveWindowChanged) + } + + // We only update prevCopied when we grow the buffer because in cases + // where prevCopied > prevRTTCopied the existing buffer is already big + // enough to handle the current rate and we don't need to do any + // adjustments. 
+ e.rcvAutoParams.prevCopied = prevRTTCopied + } + e.rcvAutoParams.measureTime = now + e.rcvAutoParams.copied = 0 + e.rcvListMu.Unlock() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -595,10 +728,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { s.decRef() } - scale := e.rcv.rcvWndScale - wasZero := e.zeroReceiveWindow(scale) e.rcvBufUsed -= len(v) - if wasZero && !e.zeroReceiveWindow(scale) { + // If the window was zero before this read and if the read freed up + // enough buffer space for the scaled window to be non-zero then notify + // the protocol goroutine to send a window update. + if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) { + e.zeroWindow = false e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -785,6 +920,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. 
@@ -818,9 +964,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { size = math.MaxInt32 / 2 } - wasZero := e.zeroReceiveWindow(scale) e.rcvBufSize = size - if wasZero && !e.zeroReceiveWindow(scale) { + e.rcvAutoParams.disabled = true + if e.zeroWindow && !e.zeroReceiveWindow(scale) { + e.zeroWindow = false mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() @@ -898,6 +1045,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.CongestionControlOption: + // Query the available cc algorithms in the stack and + // validate that the specified algorithm is actually + // supported in the stack. + var avail tcpip.AvailableCongestionControlOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { + return err + } + availCC := strings.Split(string(avail), " ") + for _, cc := range availCC { + if v == tcpip.CongestionControlOption(cc) { + // Acquire the work mutex as we may need to + // reinitialize the congestion control state. + e.mu.Lock() + state := e.state + e.cc = v + e.mu.Unlock() + switch state { + case StateEstablished: + e.workMu.Lock() + e.mu.Lock() + if e.state == state { + e.snd.cc = e.snd.initCongestionControl(e.cc) + } + e.mu.Unlock() + e.workMu.Unlock() + } + return nil + } + } + + // Linux returns ENOENT when an invalid congestion + // control algorithm is specified. + return tcpip.ErrNoSuchFile default: return nil } @@ -929,6 +1110,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err + case *tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. 
+ *o = header.TCPDefaultMSS + return nil + case *tcpip.SendBufferSizeOption: e.sndBufMu.Lock() *o = tcpip.SendBufferSizeOption(e.sndBufSize) @@ -1067,6 +1256,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.CongestionControlOption: + e.mu.Lock() + *o = e.cc + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1098,6 +1293,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Addr == "" && addr.Port == 0 { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + return e.connect(addr, true, true) } @@ -1582,6 +1782,13 @@ func (e *endpoint) readyToRead(s *segment) { if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() + // Check if the receive window is now closed. If so make sure + // we set the zero window before we deliver the segment to ensure + // that a subsequent read of the segment will correctly trigger + // a non-zero notification. + if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + e.zeroWindow = true + } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -1591,21 +1798,26 @@ func (e *endpoint) readyToRead(s *segment) { e.waiterQueue.Notify(waiter.EventIn) } -// receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. -func (e *endpoint) receiveBufferAvailable() int { - e.rcvListMu.Lock() - size := e.rcvBufSize - used := e.rcvBufUsed - e.rcvListMu.Unlock() - +// receiveBufferAvailableLocked calculates how many bytes are still available +// in the receive buffer. +// rcvListMu must be held when this function is called. +func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. 
- if used >= size { + if e.rcvBufUsed >= e.rcvBufSize { return 0 } - return size - used + return e.rcvBufSize - e.rcvBufUsed +} + +// receiveBufferAvailable calculates how many bytes are still available in the +// receive buffer. +func (e *endpoint) receiveBufferAvailable() int { + e.rcvListMu.Lock() + available := e.receiveBufferAvailableLocked() + e.rcvListMu.Unlock() + return available } func (e *endpoint) receiveBufferSize() int { @@ -1616,6 +1828,33 @@ func (e *endpoint) receiveBufferSize() int { return size } +func (e *endpoint) maxReceiveBufferSize() int { + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + // As a fallback return the hardcoded max buffer size. + return MaxBufferSize + } + return rs.Max +} + +// rcvWndScaleForHandshake computes the receive window scale to offer to the +// peer when window scaling is enabled (true by default). If auto-tuning is +// disabled then the window scaling factor is based on the size of the +// receiveBuffer otherwise we use the max permissible receive buffer size to +// compute the scale. 
+func (e *endpoint) rcvWndScaleForHandshake() int { + bufSizeForScale := e.receiveBufferSize() + + e.rcvListMu.Lock() + autoTuningDisabled := e.rcvAutoParams.disabled + e.rcvListMu.Unlock() + if autoTuningDisabled { + return FindWndScale(seqnum.Size(bufSizeForScale)) + } + + return FindWndScale(seqnum.Size(e.maxReceiveBufferSize())) +} + // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { @@ -1708,6 +1947,13 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.RcvBufSize = e.rcvBufSize s.RcvBufUsed = e.rcvBufUsed s.RcvClosed = e.rcvClosed + s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime + s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied + s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied + s.RcvAutoParams.RTT = e.rcvAutoParams.rtt + s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber + s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime + s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled e.rcvListMu.Unlock() // Endpoint TCP Option state. @@ -1761,13 +2007,13 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RTTMeasureTime: e.snd.rttMeasureTime, Closed: e.snd.closed, RTO: e.snd.rto, - SRTTInited: e.snd.srttInited, MaxPayloadSize: e.snd.maxPayloadSize, SndWndScale: e.snd.sndWndScale, MaxSentAck: e.snd.maxSentAck, } e.snd.rtt.Lock() s.Sender.SRTT = e.snd.rtt.srtt + s.Sender.SRTTInited = e.snd.rtt.srttInited e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { @@ -1815,3 +2061,7 @@ func (e *endpoint) State() uint32 { defer e.mu.Unlock() return uint32(e.state) } + +func mssForRoute(r *stack.Route) uint16 { + return uint16(r.MTU() - header.TCPMinimumSize) +} |