diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 165 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 50 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 45 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_queue.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 304 |
7 files changed, 391 insertions, 248 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 881752371..6891fd245 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -898,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -1003,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. @@ -1136,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { } cont, err := e.handleSegment(s) + s.decRef() if err != nil { - s.decRef() return err } if !cont { - s.decRef() return nil } } @@ -1221,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } + // Increase counter if after processing the segment we would potentially + // advertise a zero window. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. @@ -1233,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // or a notification from the protocolMainLoop (caller goroutine). // This means that with this return, the segment dequeue below can // never occur on a closed endpoint. - s.decRef() return false, nil } @@ -1425,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.rcv.nonZeroWindow() } - if n¬ifyReceiveWindowChanged != 0 { - e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) - } - if n¬ifyMTUChanged != 0 { e.sndBufMu.Lock() count := e.packetTooBigCount diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 120483838..87db13720 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,6 +63,17 @@ const ( StateClosing ) +const ( + // rcvAdvWndScale is used to split the available socket buffer into + // application buffer and the window to be advertised to the peer. This is + // currently hard coded to split the available space equally. + rcvAdvWndScale = 1 + + // SegOverheadFactor is used to multiply the value provided by the + // user on a SetSockOpt for setting the socket send/receive buffer sizes. + SegOverheadFactor = 2 +) + // connected returns true when s is one of the states representing an // endpoint connected to a peer. func (s EndpointState) connected() bool { @@ -149,7 +160,6 @@ func (s EndpointState) String() string { // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota - notifyReceiveWindowChanged notifyClose notifyMTUChanged notifyDrain @@ -384,13 +394,26 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + // rcvBufSize is the total size of the receive buffer. + rcvBufSize int + // rcvBufUsed is the actual number of payload bytes held in the receive buffer + // not counting any overheads of the segments itself. NOTE: This will always + // be strictly <= rcvMemUsed below. rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams + // rcvMemUsed tracks the total amount of memory in use by received segments + // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // compute the window and the actual available buffer space. This is distinct + // from rcvBufUsed above which is the actual number of payload bytes held in + // the buffer not including any segment overheads. + // + // rcvMemUsed must be accessed atomically. + rcvMemUsed int32 + // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. mu sync.Mutex `state:"nosave"` @@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.probe = p } - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) @@ -1129,10 +1152,16 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// wndFromSpace returns the window that we can advertise based on the available +// receive buffer space. +func wndFromSpace(space int) int { + return space / (1 << rcvAdvWndScale) +} + // initialReceiveWindow returns the initial receive window to advertise in the // SYN/SYN-ACK. func (e *endpoint) initialReceiveWindow() int { - rcvWnd := e.receiveBufferAvailable() + rcvWnd := wndFromSpace(e.receiveBufferAvailable()) if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } @@ -1209,14 +1238,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = rcvWnd - availAfter := e.receiveBufferAvailableLocked() - mask := uint32(notifyReceiveWindowChanged) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -1293,18 +1320,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { v := views[s.viewToDeliver] s.viewToDeliver++ + var delta int if s.viewToDeliver >= len(views) { e.rcvList.Remove(s) + // We only free up receive buffer space when the segment is released as the + // segment is still holding on to the views even though some views have been + // read out to the user. + delta = s.segMemSize() s.decRef() } e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1481,11 +1512,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } // windowCrossedACKThresholdLocked checks if the receive window to be announced -// now would be under aMSS or under half receive buffer, whichever smaller. This -// is useful as a receive side silly window syndrome prevention mechanism. If -// window grows to reasonable value, we should send ACK to the sender to inform -// the rx space is now large. We also want ensure a series of small read()'s -// won't trigger a flood of spurious tiny ACK's. +// would be under aMSS or under the window derived from half receive buffer, +// whichever smaller. This is useful as a receive side silly window syndrome +// prevention mechanism. If window grows to reasonable value, we should send ACK +// to the sender to inform the rx space is now large. We also want ensure a +// series of small read()'s won't trigger a flood of spurious tiny ACK's. // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of @@ -1496,17 +1527,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := e.receiveBufferAvailableLocked() + newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 } - threshold := int(e.amss) - if threshold > e.rcvBufSize/2 { - threshold = e.rcvBufSize / 2 + // rcvBufFraction is the inverse of the fraction of receive buffer size that + // is used to decide if the available buffer space is now above it. + const rcvBufFraction = 2 + if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + threshold = wndThreshold } - switch { case oldAvail < threshold && newAvail >= threshold: return true, true @@ -1636,17 +1668,23 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Make sure the receive buffer size is within the min and max // allowed. var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) + } + + if v > rs.Max { + v = rs.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < rs.Min { v = rs.Min } - if v > rs.Max { - v = rs.Max - } + } else { + v = math.MaxInt32 } - mask := uint32(notifyReceiveWindowChanged) - e.LockUser() e.rcvListMu.Lock() @@ -1660,14 +1698,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { v = 1 << scale } - // Make sure 2*size doesn't overflow. - if v > math.MaxInt32/2 { - v = math.MaxInt32 / 2 - } - - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = v - availAfter := e.receiveBufferAvailableLocked() + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvAutoParams.disabled = true @@ -1675,24 +1708,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } e.rcvListMu.Unlock() e.UnlockUser() - e.notifyProtocolGoroutine(mask) case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) + } + + if v > ss.Max { + v = ss.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < ss.Min { v = ss.Min } - if v > ss.Max { - v = ss.Max - } + } else { + v = math.MaxInt32 } e.sndBufMu.Lock() @@ -2699,13 +2739,8 @@ func (e *endpoint) updateSndBufferUsage(v int) { func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { + e.rcvBufUsed += s.payloadSize() s.incRef() - e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down below MSS - // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -2720,15 +2755,17 @@ func (e *endpoint) readyToRead(s *segment) { func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. - if e.rcvBufUsed >= e.rcvBufSize { + memUsed := e.receiveMemUsed() + if memUsed >= e.rcvBufSize { return 0 } - return e.rcvBufSize - e.rcvBufUsed + return e.rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. +// receive buffer based on the actual memory used by all segments held in +// receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvListMu.Lock() available := e.receiveBufferAvailableLocked() @@ -2736,14 +2773,35 @@ func (e *endpoint) receiveBufferAvailable() int { return available } +// receiveBufferUsed returns the amount of in-use receive buffer. +func (e *endpoint) receiveBufferUsed() int { + e.rcvListMu.Lock() + used := e.rcvBufUsed + e.rcvListMu.Unlock() + return used +} + +// receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { e.rcvListMu.Lock() size := e.rcvBufSize e.rcvListMu.Unlock() - return size } +// receiveMemUsed returns the total memory in use by segments held by this +// endpoint. +func (e *endpoint) receiveMemUsed() int { + return int(atomic.LoadInt32(&e.rcvMemUsed)) +} + +// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed. +func (e *endpoint) updateReceiveMemUsed(delta int) { + atomic.AddInt32(&e.rcvMemUsed, int32(delta)) +} + +// maxReceiveBufferSize returns the stack wide maximum receive buffer size for +// an endpoint. func (e *endpoint) maxReceiveBufferSize() int { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { @@ -2894,7 +2952,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RcvAcc: e.rcv.rcvAcc, RcvWndScale: e.rcv.rcvWndScale, PendingBufUsed: e.rcv.pendingBufUsed, - PendingBufSize: e.rcv.pendingBufSize, } // Copy sender state. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 41d0050f3..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() { // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. - e.segmentQueue.setLimit(0) + e.segmentQueue.freeze() e.mu.Lock() defer e.mu.Unlock() @@ -178,7 +178,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index cfd43b5e3..4aafb4d22 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -47,22 +47,24 @@ type receiver struct { closed bool + // pendingRcvdSegments is bounded by the receive buffer size of the + // endpoint. pendingRcvdSegments segmentHeap - pendingBufUsed seqnum.Size - pendingBufSize seqnum.Size + // pendingBufUsed tracks the total number of bytes (including segment + // overhead) currently queued in pendingRcvdSegments. + pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, lastRcvdAckTime: time.Now(), } } @@ -85,15 +87,23 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the available buffer space. - receiveBufferAvailable := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) - if r.rcvAcc.LessThan(acc) { - r.rcvAcc = acc + avail := wndFromSpace(r.ep.receiveBufferAvailable()) + acc := r.rcvNxt.Add(seqnum.Size(avail)) + newWnd := r.rcvNxt.Size(acc) + curWnd := r.rcvNxt.Size(r.rcvAcc) + + // Update rcvAcc only if new window is > previously advertised window. We + // should never shrink the acceptable sequence space once it has been + // advertised the peer. If we shrink the acceptable sequence space then we + // would end up dropping bytes that might already be in flight. + if newWnd > curWnd { + r.rcvAcc = r.rcvNxt.Add(newWnd) + } else { + newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. - r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + r.rcvWnd = newWnd return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } @@ -195,7 +205,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of // truncated items, thus truncated items must be set to nil to avoid // memory leaks. @@ -384,10 +396,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { - // We only store the segment if it's within our buffer - // size limit. - if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += seqnum.Size(s.segMemSize()) + // We only store the segment if it's within our buffer size limit. + // + // Only use 75% of the receive buffer queue for out-of-order + // segments. This ensures that we always leave some space for the inorder + // segments to arrive allowing pending segments to be processed and + // delivered to the user. + if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvListMu.Lock() + r.pendingBufUsed += s.segMemSize() + r.ep.rcvListMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -421,7 +439,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= seqnum.Size(s.segMemSize()) + r.ep.rcvListMu.Lock() + r.pendingBufUsed -= s.segMemSize() + r.ep.rcvListMu.Unlock() s.decRef() } return false, nil diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 94307d31a..13acaf753 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "sync/atomic" "time" @@ -24,6 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// queueFlags are used to indicate which queue of an endpoint a particular segment +// belongs to. This is used to track memory accounting correctly. +type queueFlags uint8 + +const ( + recvQ queueFlags = 1 << iota + sendQ +) + // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. // segment is mostly immutable, the only field allowed to change is viewToDeliver. @@ -32,6 +42,8 @@ import ( type segment struct { segmentEntry refCnt int32 + ep *endpoint + qFlags queueFlags id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -100,6 +112,8 @@ func (s *segment) clone() *segment { rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, + ep: s.ep, + qFlags: s.qFlags, } t.data = s.data.Clone(t.views[:]) return t @@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool { return s.flags&flags == flags } +// setOwner sets the owning endpoint for this segment. Its required +// to be called to ensure memory accounting for receive/send buffer +// queues is done properly. +func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) { + switch qFlags { + case recvQ: + ep.updateReceiveMemUsed(s.segMemSize()) + case sendQ: + // no memory account for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b", qFlags)) + } + s.ep = ep + s.qFlags = qFlags +} + func (s *segment) decRef() { if atomic.AddInt32(&s.refCnt, -1) == 0 { + if s.ep != nil { + switch s.qFlags { + case recvQ: + s.ep.updateReceiveMemUsed(-s.segMemSize()) + case sendQ: + // no memory accounting for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) + } + } s.route.Release() } } @@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// payloadSize is the size of s.data. +func (s *segment) payloadSize() int { + return s.data.Size() +} + // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 48a257137..54545a1b1 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -22,16 +22,16 @@ import ( // // +stateify savable type segmentQueue struct { - mu sync.Mutex `state:"nosave"` - list segmentList `state:"wait"` - limit int - used int + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + ep *endpoint + frozen bool } // emptyLocked determines if the queue is empty. // Preconditions: q.mu must be held. func (q *segmentQueue) emptyLocked() bool { - return q.used == 0 + return q.list.Empty() } // empty determines if the queue is empty. @@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool { return r } -// setLimit updates the limit. No segments are immediately dropped in case the -// queue becomes full due to the new limit. -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} - // enqueue adds the given segment to the queue. // // Returns true when the segment is successfully added to the queue, in which @@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) { // false if the queue is full, in which case ownership is retained by the // caller. func (q *segmentQueue) enqueue(s *segment) bool { + // q.ep.receiveBufferParams() must be called without holding q.mu to + // avoid lock order inversion. + bufSz := q.ep.receiveBufferSize() + used := q.ep.receiveMemUsed() q.mu.Lock() - r := q.used < q.limit - if r { + // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue + // is currently full). + allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + + if allow { q.list.PushBack(s) - q.used++ + // Set the owner now that the endpoint owns the segment. + s.setOwner(q.ep, recvQ) } q.mu.Unlock() - return r + return allow } // dequeue removes and returns the next segment from queue, if one exists. @@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment { s := q.list.Front() if s != nil { q.list.Remove(s) - q.used-- } q.mu.Unlock() return s } + +// freeze prevents any more segments from being added to the queue. i.e all +// future segmentQueue.enqueue will return false and not add the segment to the +// queue till the queue is unfroze with a corresponding segmentQueue.thaw call. +func (q *segmentQueue) freeze() { + q.mu.Lock() + q.frozen = true + q.mu.Unlock() +} + +// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and +// allows new segments to be queued again. +func (q *segmentQueue) thaw() { + q.mu.Lock() + q.frozen = false + q.mu.Unlock() +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 77e0d0e97..1da199cd6 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -169,6 +169,7 @@ func (x *endpoint) StateFields() []string { "rcvBufSize", "rcvBufUsed", "rcvAutoParams", + "rcvMemUsed", "ownedByUser", "state", "boundNICID", @@ -231,11 +232,11 @@ func (x *endpoint) StateSave(m state.Sink) { var lastError string = x.saveLastError() m.SaveValue(3, lastError) var state EndpointState = x.saveState() - m.SaveValue(10, state) + m.SaveValue(11, state) var recentTSTime unixTime = x.saveRecentTSTime() - m.SaveValue(25, recentTSTime) + m.SaveValue(26, recentTSTime) var acceptedChan []*endpoint = x.saveAcceptedChan() - m.SaveValue(51, acceptedChan) + m.SaveValue(52, acceptedChan) m.Save(0, &x.EndpointInfo) m.Save(1, &x.waiterQueue) m.Save(2, &x.uniqueID) @@ -244,57 +245,58 @@ func (x *endpoint) StateSave(m state.Sink) { m.Save(6, &x.rcvBufSize) m.Save(7, &x.rcvBufUsed) m.Save(8, &x.rcvAutoParams) - m.Save(9, &x.ownedByUser) - m.Save(11, &x.boundNICID) - m.Save(12, &x.ttl) - m.Save(13, &x.v6only) - m.Save(14, &x.isConnectNotified) - m.Save(15, &x.broadcast) - m.Save(16, &x.portFlags) - m.Save(17, &x.boundBindToDevice) - m.Save(18, &x.boundPortFlags) - m.Save(19, &x.boundDest) - m.Save(20, &x.effectiveNetProtos) - m.Save(21, &x.workerRunning) - m.Save(22, &x.workerCleanup) - m.Save(23, &x.sendTSOk) - m.Save(24, &x.recentTS) - m.Save(26, &x.tsOffset) - m.Save(27, &x.shutdownFlags) - m.Save(28, &x.sackPermitted) - m.Save(29, &x.sack) - m.Save(30, &x.bindToDevice) - m.Save(31, &x.delay) - m.Save(32, &x.cork) - m.Save(33, &x.scoreboard) - m.Save(34, &x.slowAck) - m.Save(35, &x.segmentQueue) - m.Save(36, &x.synRcvdCount) - m.Save(37, &x.userMSS) - m.Save(38, &x.maxSynRetries) - m.Save(39, &x.windowClamp) - m.Save(40, &x.sndBufSize) - m.Save(41, &x.sndBufUsed) - m.Save(42, &x.sndClosed) - m.Save(43, &x.sndBufInQueue) - m.Save(44, &x.sndQueue) - m.Save(45, &x.cc) - m.Save(46, &x.packetTooBigCount) - m.Save(47, &x.sndMTU) - m.Save(48, &x.keepalive) - m.Save(49, &x.userTimeout) - m.Save(50, &x.deferAccept) - m.Save(52, &x.rcv) - m.Save(53, &x.snd) - m.Save(54, &x.connectingAddress) - m.Save(55, &x.amss) - m.Save(56, &x.sendTOS) - m.Save(57, &x.gso) - m.Save(58, &x.tcpLingerTimeout) - m.Save(59, &x.closed) - m.Save(60, &x.txHash) - m.Save(61, &x.owner) - m.Save(62, &x.linger) + m.Save(9, &x.rcvMemUsed) + m.Save(10, &x.ownedByUser) + m.Save(12, &x.boundNICID) + m.Save(13, &x.ttl) + m.Save(14, &x.v6only) + m.Save(15, &x.isConnectNotified) + m.Save(16, &x.broadcast) + m.Save(17, &x.portFlags) + m.Save(18, &x.boundBindToDevice) + m.Save(19, &x.boundPortFlags) + m.Save(20, &x.boundDest) + m.Save(21, &x.effectiveNetProtos) + m.Save(22, &x.workerRunning) + m.Save(23, &x.workerCleanup) + m.Save(24, &x.sendTSOk) + m.Save(25, &x.recentTS) + m.Save(27, &x.tsOffset) + m.Save(28, &x.shutdownFlags) + m.Save(29, &x.sackPermitted) + m.Save(30, &x.sack) + m.Save(31, &x.bindToDevice) + m.Save(32, &x.delay) + m.Save(33, &x.cork) + m.Save(34, &x.scoreboard) + m.Save(35, &x.slowAck) + m.Save(36, &x.segmentQueue) + m.Save(37, &x.synRcvdCount) + m.Save(38, &x.userMSS) + m.Save(39, &x.maxSynRetries) + m.Save(40, &x.windowClamp) + m.Save(41, &x.sndBufSize) + m.Save(42, &x.sndBufUsed) + m.Save(43, &x.sndClosed) + m.Save(44, &x.sndBufInQueue) + m.Save(45, &x.sndQueue) + m.Save(46, &x.cc) + m.Save(47, &x.packetTooBigCount) + m.Save(48, &x.sndMTU) + m.Save(49, &x.keepalive) + m.Save(50, &x.userTimeout) + m.Save(51, &x.deferAccept) + m.Save(53, &x.rcv) + m.Save(54, &x.snd) + m.Save(55, &x.connectingAddress) + m.Save(56, &x.amss) + m.Save(57, &x.sendTOS) + m.Save(58, &x.gso) + m.Save(59, &x.tcpLingerTimeout) + m.Save(60, &x.closed) + m.Save(61, &x.txHash) + m.Save(62, &x.owner) + m.Save(63, &x.linger) } func (x *endpoint) StateLoad(m state.Source) { @@ -306,61 +308,62 @@ func (x *endpoint) StateLoad(m state.Source) { m.Load(6, &x.rcvBufSize) m.Load(7, &x.rcvBufUsed) m.Load(8, &x.rcvAutoParams) - m.Load(9, &x.ownedByUser) - m.Load(11, &x.boundNICID) - m.Load(12, &x.ttl) - m.Load(13, &x.v6only) - m.Load(14, &x.isConnectNotified) - m.Load(15, &x.broadcast) - m.Load(16, &x.portFlags) - m.Load(17, &x.boundBindToDevice) - m.Load(18, &x.boundPortFlags) - m.Load(19, &x.boundDest) - m.Load(20, &x.effectiveNetProtos) - m.Load(21, &x.workerRunning) - m.Load(22, &x.workerCleanup) - m.Load(23, &x.sendTSOk) - m.Load(24, &x.recentTS) - m.Load(26, &x.tsOffset) - m.Load(27, &x.shutdownFlags) - m.Load(28, &x.sackPermitted) - m.Load(29, &x.sack) - m.Load(30, &x.bindToDevice) - m.Load(31, &x.delay) - m.Load(32, &x.cork) - m.Load(33, &x.scoreboard) - m.Load(34, &x.slowAck) - m.LoadWait(35, &x.segmentQueue) - m.Load(36, &x.synRcvdCount) - m.Load(37, &x.userMSS) - m.Load(38, &x.maxSynRetries) - m.Load(39, &x.windowClamp) - m.Load(40, &x.sndBufSize) - m.Load(41, &x.sndBufUsed) - m.Load(42, &x.sndClosed) - m.Load(43, &x.sndBufInQueue) - m.LoadWait(44, &x.sndQueue) - m.Load(45, &x.cc) - m.Load(46, &x.packetTooBigCount) - m.Load(47, &x.sndMTU) - m.Load(48, &x.keepalive) - m.Load(49, &x.userTimeout) - m.Load(50, &x.deferAccept) - m.LoadWait(52, &x.rcv) - m.LoadWait(53, &x.snd) - m.Load(54, &x.connectingAddress) - m.Load(55, &x.amss) - m.Load(56, &x.sendTOS) - m.Load(57, &x.gso) - m.Load(58, &x.tcpLingerTimeout) - m.Load(59, &x.closed) - m.Load(60, &x.txHash) - m.Load(61, &x.owner) - m.Load(62, &x.linger) + m.Load(9, &x.rcvMemUsed) + m.Load(10, &x.ownedByUser) + m.Load(12, &x.boundNICID) + m.Load(13, &x.ttl) + m.Load(14, &x.v6only) + m.Load(15, &x.isConnectNotified) + m.Load(16, &x.broadcast) + m.Load(17, &x.portFlags) + m.Load(18, &x.boundBindToDevice) + m.Load(19, &x.boundPortFlags) + m.Load(20, &x.boundDest) + m.Load(21, &x.effectiveNetProtos) + m.Load(22, &x.workerRunning) + m.Load(23, &x.workerCleanup) + m.Load(24, &x.sendTSOk) + m.Load(25, &x.recentTS) + m.Load(27, &x.tsOffset) + m.Load(28, &x.shutdownFlags) + m.Load(29, &x.sackPermitted) + m.Load(30, &x.sack) + m.Load(31, &x.bindToDevice) + m.Load(32, &x.delay) + m.Load(33, &x.cork) + m.Load(34, &x.scoreboard) + m.Load(35, &x.slowAck) + m.LoadWait(36, &x.segmentQueue) + m.Load(37, &x.synRcvdCount) + m.Load(38, &x.userMSS) + m.Load(39, &x.maxSynRetries) + m.Load(40, &x.windowClamp) + m.Load(41, &x.sndBufSize) + m.Load(42, &x.sndBufUsed) + m.Load(43, &x.sndClosed) + m.Load(44, &x.sndBufInQueue) + m.LoadWait(45, &x.sndQueue) + m.Load(46, &x.cc) + m.Load(47, &x.packetTooBigCount) + m.Load(48, &x.sndMTU) + m.Load(49, &x.keepalive) + m.Load(50, &x.userTimeout) + m.Load(51, &x.deferAccept) + m.LoadWait(53, &x.rcv) + m.LoadWait(54, &x.snd) + m.Load(55, &x.connectingAddress) + m.Load(56, &x.amss) + m.Load(57, &x.sendTOS) + m.Load(58, &x.gso) + m.Load(59, &x.tcpLingerTimeout) + m.Load(60, &x.closed) + m.Load(61, &x.txHash) + m.Load(62, &x.owner) + m.Load(63, &x.linger) m.LoadValue(3, new(string), func(y interface{}) { x.loadLastError(y.(string)) }) - m.LoadValue(10, new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) - m.LoadValue(25, new(unixTime), func(y interface{}) { x.loadRecentTSTime(y.(unixTime)) }) - m.LoadValue(51, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) + m.LoadValue(11, new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) + m.LoadValue(26, new(unixTime), func(y interface{}) { x.loadRecentTSTime(y.(unixTime)) }) + m.LoadValue(52, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) m.AfterLoad(x.afterLoad) } @@ -446,7 +449,6 @@ func (x *receiver) StateFields() []string { "closed", "pendingRcvdSegments", "pendingBufUsed", - "pendingBufSize", "lastRcvdAckTime", } } @@ -456,7 +458,7 @@ func (x *receiver) beforeSave() {} func (x *receiver) StateSave(m state.Sink) { x.beforeSave() var lastRcvdAckTime unixTime = x.saveLastRcvdAckTime() - m.SaveValue(9, lastRcvdAckTime) + m.SaveValue(8, lastRcvdAckTime) m.Save(0, &x.ep) m.Save(1, &x.rcvNxt) m.Save(2, &x.rcvAcc) @@ -465,7 +467,6 @@ func (x *receiver) StateSave(m state.Sink) { m.Save(5, &x.closed) m.Save(6, &x.pendingRcvdSegments) m.Save(7, &x.pendingBufUsed) - m.Save(8, &x.pendingBufSize) } func (x *receiver) afterLoad() {} @@ -479,8 +480,7 @@ func (x *receiver) StateLoad(m state.Source) { m.Load(5, &x.closed) m.Load(6, &x.pendingRcvdSegments) m.Load(7, &x.pendingBufUsed) - m.Load(8, &x.pendingBufSize) - m.LoadValue(9, new(unixTime), func(y interface{}) { x.loadLastRcvdAckTime(y.(unixTime)) }) + m.LoadValue(8, new(unixTime), func(y interface{}) { x.loadLastRcvdAckTime(y.(unixTime)) }) } func (x *renoState) StateTypeName() string { @@ -540,6 +540,8 @@ func (x *segment) StateFields() []string { return []string{ "segmentEntry", "refCnt", + "ep", + "qFlags", "data", "hdr", "viewToDeliver", @@ -563,26 +565,28 @@ func (x *segment) beforeSave() {} func (x *segment) StateSave(m state.Sink) { x.beforeSave() var data buffer.VectorisedView = x.saveData() - m.SaveValue(2, data) + m.SaveValue(4, data) var options []byte = x.saveOptions() - m.SaveValue(12, options) + m.SaveValue(14, options) var rcvdTime unixTime = x.saveRcvdTime() - m.SaveValue(14, rcvdTime) + m.SaveValue(16, rcvdTime) var xmitTime unixTime = x.saveXmitTime() - m.SaveValue(15, xmitTime) + m.SaveValue(17, xmitTime) m.Save(0, &x.segmentEntry) m.Save(1, &x.refCnt) - m.Save(3, &x.hdr) - m.Save(4, &x.viewToDeliver) - m.Save(5, &x.sequenceNumber) - m.Save(6, &x.ackNumber) - m.Save(7, &x.flags) - m.Save(8, &x.window) - m.Save(9, &x.csum) - m.Save(10, &x.csumValid) - m.Save(11, &x.parsedOptions) - m.Save(13, &x.hasNewSACKInfo) - m.Save(16, &x.xmitCount) + m.Save(2, &x.ep) + m.Save(3, &x.qFlags) + m.Save(5, &x.hdr) + m.Save(6, &x.viewToDeliver) + m.Save(7, &x.sequenceNumber) + m.Save(8, &x.ackNumber) + m.Save(9, &x.flags) + m.Save(10, &x.window) + m.Save(11, &x.csum) + m.Save(12, &x.csumValid) + m.Save(13, &x.parsedOptions) + m.Save(15, &x.hasNewSACKInfo) + m.Save(18, &x.xmitCount) } func (x *segment) afterLoad() {} @@ -590,21 +594,23 @@ func (x *segment) afterLoad() {} func (x *segment) StateLoad(m state.Source) { m.Load(0, &x.segmentEntry) m.Load(1, &x.refCnt) - m.Load(3, &x.hdr) - m.Load(4, &x.viewToDeliver) - m.Load(5, &x.sequenceNumber) - m.Load(6, &x.ackNumber) - m.Load(7, &x.flags) - m.Load(8, &x.window) - m.Load(9, &x.csum) - m.Load(10, &x.csumValid) - m.Load(11, &x.parsedOptions) - m.Load(13, &x.hasNewSACKInfo) - m.Load(16, &x.xmitCount) - m.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) - m.LoadValue(12, new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) }) - m.LoadValue(14, new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) }) - m.LoadValue(15, new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) }) + m.Load(2, &x.ep) + m.Load(3, &x.qFlags) + m.Load(5, &x.hdr) + m.Load(6, &x.viewToDeliver) + m.Load(7, &x.sequenceNumber) + m.Load(8, &x.ackNumber) + m.Load(9, &x.flags) + m.Load(10, &x.window) + m.Load(11, &x.csum) + m.Load(12, &x.csumValid) + m.Load(13, &x.parsedOptions) + m.Load(15, &x.hasNewSACKInfo) + m.Load(18, &x.xmitCount) + m.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) + m.LoadValue(14, new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) }) + m.LoadValue(16, new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) }) + m.LoadValue(17, new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) }) } func (x *segmentQueue) StateTypeName() string { @@ -614,8 +620,8 @@ func (x *segmentQueue) StateTypeName() string { func (x *segmentQueue) StateFields() []string { return []string{ "list", - "limit", - "used", + "ep", + "frozen", } } @@ -624,16 +630,16 @@ func (x *segmentQueue) beforeSave() {} func (x *segmentQueue) StateSave(m state.Sink) { x.beforeSave() m.Save(0, &x.list) - m.Save(1, &x.limit) - m.Save(2, &x.used) + m.Save(1, &x.ep) + m.Save(2, &x.frozen) } func (x *segmentQueue) afterLoad() {} func (x *segmentQueue) StateLoad(m state.Source) { m.LoadWait(0, &x.list) - m.Load(1, &x.limit) - m.Load(2, &x.used) + m.Load(1, &x.ep) + m.Load(2, &x.frozen) } func (x *sender) StateTypeName() string { |