summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/tcp/connect.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go165
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go50
-rw-r--r--pkg/tcpip/transport/tcp/segment.go45
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go52
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go304
7 files changed, 391 insertions, 248 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 881752371..6891fd245 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -898,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -1003,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
// Bootstrap the auto tuning algorithm. Starting at zero will
// result in a really large receive window after the first auto
// tuning adjustment.
@@ -1136,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
}
cont, err := e.handleSegment(s)
+ s.decRef()
if err != nil {
- s.decRef()
return err
}
if !cont {
- s.decRef()
return nil
}
}
@@ -1221,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
+ // Increase counter if after processing the segment we would potentially
+ // advertise a zero window.
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
@@ -1233,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
- s.decRef()
return false, nil
}
@@ -1425,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.rcv.nonZeroWindow()
}
- if n&notifyReceiveWindowChanged != 0 {
- e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
- }
-
if n&notifyMTUChanged != 0 {
e.sndBufMu.Lock()
count := e.packetTooBigCount
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 120483838..87db13720 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,6 +63,17 @@ const (
StateClosing
)
+const (
+ // rcvAdvWndScale is used to split the available socket buffer into
+ // application buffer and the window to be advertised to the peer. This is
+ // currently hard coded to split the available space equally.
+ rcvAdvWndScale = 1
+
+ // SegOverheadFactor is used to multiply the value provided by the
+ // user on a SetSockOpt for setting the socket send/receive buffer sizes.
+ SegOverheadFactor = 2
+)
+
// connected returns true when s is one of the states representing an
// endpoint connected to a peer.
func (s EndpointState) connected() bool {
@@ -149,7 +160,6 @@ func (s EndpointState) String() string {
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
- notifyReceiveWindowChanged
notifyClose
notifyMTUChanged
notifyDrain
@@ -384,13 +394,26 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- rcvBufSize int
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ // rcvBufSize is the total size of the receive buffer.
+ rcvBufSize int
+ // rcvBufUsed is the actual number of payload bytes held in the receive buffer
+ // not counting any overheads of the segments themselves. NOTE: This will
+ // always be <= rcvMemUsed below.
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
+ // rcvMemUsed tracks the total amount of memory in use by received segments
+ // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // compute the window and the actual available buffer space. This is distinct
+ // from rcvBufUsed above which is the actual number of payload bytes held in
+ // the buffer not including any segment overheads.
+ //
+ // rcvMemUsed must be accessed atomically.
+ rcvMemUsed int32
+
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
mu sync.Mutex `state:"nosave"`
@@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.probe = p
}
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
@@ -1129,10 +1152,16 @@ func (e *endpoint) cleanupLocked() {
tcpip.DeleteDanglingEndpoint(e)
}
+// wndFromSpace returns the window that we can advertise based on the available
+// receive buffer space.
+func wndFromSpace(space int) int {
+ return space / (1 << rcvAdvWndScale)
+}
+
// initialReceiveWindow returns the initial receive window to advertise in the
// SYN/SYN-ACK.
func (e *endpoint) initialReceiveWindow() int {
- rcvWnd := e.receiveBufferAvailable()
+ rcvWnd := wndFromSpace(e.receiveBufferAvailable())
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
@@ -1209,14 +1238,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = rcvWnd
- availAfter := e.receiveBufferAvailableLocked()
- mask := uint32(notifyReceiveWindowChanged)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
- e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -1293,18 +1320,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
v := views[s.viewToDeliver]
s.viewToDeliver++
+ var delta int
if s.viewToDeliver >= len(views) {
e.rcvList.Remove(s)
+ // We only free up receive buffer space when the segment is released as the
+ // segment is still holding on to the views even though some views have been
+ // read out to the user.
+ delta = s.segMemSize()
s.decRef()
}
e.rcvBufUsed -= len(v)
-
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1481,11 +1512,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
// windowCrossedACKThresholdLocked checks if the receive window to be announced
-// now would be under aMSS or under half receive buffer, whichever smaller. This
-// is useful as a receive side silly window syndrome prevention mechanism. If
-// window grows to reasonable value, we should send ACK to the sender to inform
-// the rx space is now large. We also want ensure a series of small read()'s
-// won't trigger a flood of spurious tiny ACK's.
+// would be under aMSS or under the window derived from half receive buffer,
+// whichever smaller. This is useful as a receive side silly window syndrome
+// prevention mechanism. If window grows to reasonable value, we should send ACK
+ // to the sender to inform the rx space is now large. We also want to ensure a
+// series of small read()'s won't trigger a flood of spurious tiny ACK's.
//
// For large receive buffers, the threshold is aMSS - once reader reads more
// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
@@ -1496,17 +1527,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := e.receiveBufferAvailableLocked()
+ newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
}
-
threshold := int(e.amss)
- if threshold > e.rcvBufSize/2 {
- threshold = e.rcvBufSize / 2
+ // rcvBufFraction is the inverse of the fraction of receive buffer size that
+ // is used to decide if the available buffer space is now above it.
+ const rcvBufFraction = 2
+ if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ threshold = wndThreshold
}
-
switch {
case oldAvail < threshold && newAvail >= threshold:
return true, true
@@ -1636,17 +1668,23 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Make sure the receive buffer size is within the min and max
// allowed.
var rs tcpip.TCPReceiveBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
+ }
+
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < rs.Min {
v = rs.Min
}
- if v > rs.Max {
- v = rs.Max
- }
+ } else {
+ v = math.MaxInt32
}
- mask := uint32(notifyReceiveWindowChanged)
-
e.LockUser()
e.rcvListMu.Lock()
@@ -1660,14 +1698,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
v = 1 << scale
}
- // Make sure 2*size doesn't overflow.
- if v > math.MaxInt32/2 {
- v = math.MaxInt32 / 2
- }
-
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = v
- availAfter := e.receiveBufferAvailableLocked()
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvAutoParams.disabled = true
@@ -1675,24 +1708,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// syndrome prevention, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
e.rcvListMu.Unlock()
e.UnlockUser()
- e.notifyProtocolGoroutine(mask)
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
var ss tcpip.TCPSendBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
+ }
+
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < ss.Min {
v = ss.Min
}
- if v > ss.Max {
- v = ss.Max
- }
+ } else {
+ v = math.MaxInt32
}
e.sndBufMu.Lock()
@@ -2699,13 +2739,8 @@ func (e *endpoint) updateSndBufferUsage(v int) {
func (e *endpoint) readyToRead(s *segment) {
e.rcvListMu.Lock()
if s != nil {
+ e.rcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvBufUsed += s.data.Size()
- // Increase counter if the receive window falls down below MSS
- // or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -2720,15 +2755,17 @@ func (e *endpoint) readyToRead(s *segment) {
func (e *endpoint) receiveBufferAvailableLocked() int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
- if e.rcvBufUsed >= e.rcvBufSize {
+ memUsed := e.receiveMemUsed()
+ if memUsed >= e.rcvBufSize {
return 0
}
- return e.rcvBufSize - e.rcvBufUsed
+ return e.rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
-// receive buffer.
+// receive buffer based on the actual memory used by all segments held in
+// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvListMu.Lock()
available := e.receiveBufferAvailableLocked()
@@ -2736,14 +2773,35 @@ func (e *endpoint) receiveBufferAvailable() int {
return available
}
+// receiveBufferUsed returns the amount of in-use receive buffer.
+func (e *endpoint) receiveBufferUsed() int {
+ e.rcvListMu.Lock()
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+ return used
+}
+
+// receiveBufferSize returns the current size of the receive buffer.
func (e *endpoint) receiveBufferSize() int {
e.rcvListMu.Lock()
size := e.rcvBufSize
e.rcvListMu.Unlock()
-
return size
}
+// receiveMemUsed returns the total memory in use by segments held by this
+// endpoint.
+func (e *endpoint) receiveMemUsed() int {
+ return int(atomic.LoadInt32(&e.rcvMemUsed))
+}
+
+// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed.
+func (e *endpoint) updateReceiveMemUsed(delta int) {
+ atomic.AddInt32(&e.rcvMemUsed, int32(delta))
+}
+
+// maxReceiveBufferSize returns the stack wide maximum receive buffer size for
+// an endpoint.
func (e *endpoint) maxReceiveBufferSize() int {
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
@@ -2894,7 +2952,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
RcvAcc: e.rcv.rcvAcc,
RcvWndScale: e.rcv.rcvWndScale,
PendingBufUsed: e.rcv.pendingBufUsed,
- PendingBufSize: e.rcv.pendingBufSize,
}
// Copy sender state.
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 41d0050f3..b25431467 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() {
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
- e.segmentQueue.setLimit(0)
+ e.segmentQueue.freeze()
e.mu.Lock()
defer e.mu.Unlock()
@@ -178,7 +178,7 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index cfd43b5e3..4aafb4d22 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -47,22 +47,24 @@ type receiver struct {
closed bool
+ // pendingRcvdSegments is bounded by the receive buffer size of the
+ // endpoint.
pendingRcvdSegments segmentHeap
- pendingBufUsed seqnum.Size
- pendingBufSize seqnum.Size
+ // pendingBufUsed tracks the total number of bytes (including segment
+ // overhead) currently queued in pendingRcvdSegments.
+ pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
lastRcvdAckTime: time.Now(),
}
}
@@ -85,15 +87,23 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the available buffer space.
- receiveBufferAvailable := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
- if r.rcvAcc.LessThan(acc) {
- r.rcvAcc = acc
+ avail := wndFromSpace(r.ep.receiveBufferAvailable())
+ acc := r.rcvNxt.Add(seqnum.Size(avail))
+ newWnd := r.rcvNxt.Size(acc)
+ curWnd := r.rcvNxt.Size(r.rcvAcc)
+
+ // Update rcvAcc only if new window is > previously advertised window. We
+ // should never shrink the acceptable sequence space once it has been
+ // advertised to the peer. If we shrink the acceptable sequence space then we
+ // would end up dropping bytes that might already be in flight.
+ if newWnd > curWnd {
+ r.rcvAcc = r.rcvNxt.Add(newWnd)
+ } else {
+ newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
- r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ r.rcvWnd = newWnd
return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
}
@@ -195,7 +205,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
+
// Note that slice truncation does not allow garbage collection of
// truncated items, thus truncated items must be set to nil to avoid
// memory leaks.
@@ -384,10 +396,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
- // We only store the segment if it's within our buffer
- // size limit.
- if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += seqnum.Size(s.segMemSize())
+ // We only store the segment if it's within our buffer size limit.
+ //
+ // Only use up to 25% (size>>2) of the receive buffer for out-of-order
+ // segments. This ensures that we always leave some space for the inorder
+ // segments to arrive allowing pending segments to be processed and
+ // delivered to the user.
+ if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed += s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -421,7 +439,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= seqnum.Size(s.segMemSize())
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed -= s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 94307d31a..13acaf753 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"sync/atomic"
"time"
@@ -24,6 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// queueFlags are used to indicate which queue of an endpoint a particular segment
+// belongs to. This is used to track memory accounting correctly.
+type queueFlags uint8
+
+const (
+ recvQ queueFlags = 1 << iota
+ sendQ
+)
+
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
@@ -32,6 +42,8 @@ import (
type segment struct {
segmentEntry
refCnt int32
+ ep *endpoint
+ qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -100,6 +112,8 @@ func (s *segment) clone() *segment {
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
+ ep: s.ep,
+ qFlags: s.qFlags,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool {
return s.flags&flags == flags
}
+ // setOwner sets the owning endpoint for this segment. It is required
+// to be called to ensure memory accounting for receive/send buffer
+// queues is done properly.
+func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) {
+ switch qFlags {
+ case recvQ:
+ ep.updateReceiveMemUsed(s.segMemSize())
+ case sendQ:
+ // no memory accounting for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
+ }
+ s.ep = ep
+ s.qFlags = qFlags
+}
+
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ if s.ep != nil {
+ switch s.qFlags {
+ case recvQ:
+ s.ep.updateReceiveMemUsed(-s.segMemSize())
+ case sendQ:
+ // no memory accounting for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
+ }
+ }
s.route.Release()
}
}
@@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// payloadSize is the size of s.data.
+func (s *segment) payloadSize() int {
+ return s.data.Size()
+}
+
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 48a257137..54545a1b1 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -22,16 +22,16 @@ import (
//
// +stateify savable
type segmentQueue struct {
- mu sync.Mutex `state:"nosave"`
- list segmentList `state:"wait"`
- limit int
- used int
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ ep *endpoint
+ frozen bool
}
// emptyLocked determines if the queue is empty.
// Preconditions: q.mu must be held.
func (q *segmentQueue) emptyLocked() bool {
- return q.used == 0
+ return q.list.Empty()
}
// empty determines if the queue is empty.
@@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool {
return r
}
-// setLimit updates the limit. No segments are immediately dropped in case the
-// queue becomes full due to the new limit.
-func (q *segmentQueue) setLimit(limit int) {
- q.mu.Lock()
- q.limit = limit
- q.mu.Unlock()
-}
-
// enqueue adds the given segment to the queue.
//
// Returns true when the segment is successfully added to the queue, in which
@@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) {
// false if the queue is full, in which case ownership is retained by the
// caller.
func (q *segmentQueue) enqueue(s *segment) bool {
+ // q.ep.receiveBufferSize() must be called without holding q.mu to
+ // avoid lock order inversion.
+ bufSz := q.ep.receiveBufferSize()
+ used := q.ep.receiveMemUsed()
q.mu.Lock()
- r := q.used < q.limit
- if r {
+ // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
+ // is currently full).
+ allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+
+ if allow {
q.list.PushBack(s)
- q.used++
+ // Set the owner now that the endpoint owns the segment.
+ s.setOwner(q.ep, recvQ)
}
q.mu.Unlock()
- return r
+ return allow
}
// dequeue removes and returns the next segment from queue, if one exists.
@@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment {
s := q.list.Front()
if s != nil {
q.list.Remove(s)
- q.used--
}
q.mu.Unlock()
return s
}
+
+// freeze prevents any more segments from being added to the queue, i.e. all
+// future segmentQueue.enqueue will return false and not add the segment to the
+// queue till the queue is unfrozen with a corresponding segmentQueue.thaw call.
+func (q *segmentQueue) freeze() {
+ q.mu.Lock()
+ q.frozen = true
+ q.mu.Unlock()
+}
+
+// thaw unfreezes a queue previously frozen with segmentQueue.freeze() and
+// allows new segments to be queued again.
+func (q *segmentQueue) thaw() {
+ q.mu.Lock()
+ q.frozen = false
+ q.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 77e0d0e97..1da199cd6 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -169,6 +169,7 @@ func (x *endpoint) StateFields() []string {
"rcvBufSize",
"rcvBufUsed",
"rcvAutoParams",
+ "rcvMemUsed",
"ownedByUser",
"state",
"boundNICID",
@@ -231,11 +232,11 @@ func (x *endpoint) StateSave(m state.Sink) {
var lastError string = x.saveLastError()
m.SaveValue(3, lastError)
var state EndpointState = x.saveState()
- m.SaveValue(10, state)
+ m.SaveValue(11, state)
var recentTSTime unixTime = x.saveRecentTSTime()
- m.SaveValue(25, recentTSTime)
+ m.SaveValue(26, recentTSTime)
var acceptedChan []*endpoint = x.saveAcceptedChan()
- m.SaveValue(51, acceptedChan)
+ m.SaveValue(52, acceptedChan)
m.Save(0, &x.EndpointInfo)
m.Save(1, &x.waiterQueue)
m.Save(2, &x.uniqueID)
@@ -244,57 +245,58 @@ func (x *endpoint) StateSave(m state.Sink) {
m.Save(6, &x.rcvBufSize)
m.Save(7, &x.rcvBufUsed)
m.Save(8, &x.rcvAutoParams)
- m.Save(9, &x.ownedByUser)
- m.Save(11, &x.boundNICID)
- m.Save(12, &x.ttl)
- m.Save(13, &x.v6only)
- m.Save(14, &x.isConnectNotified)
- m.Save(15, &x.broadcast)
- m.Save(16, &x.portFlags)
- m.Save(17, &x.boundBindToDevice)
- m.Save(18, &x.boundPortFlags)
- m.Save(19, &x.boundDest)
- m.Save(20, &x.effectiveNetProtos)
- m.Save(21, &x.workerRunning)
- m.Save(22, &x.workerCleanup)
- m.Save(23, &x.sendTSOk)
- m.Save(24, &x.recentTS)
- m.Save(26, &x.tsOffset)
- m.Save(27, &x.shutdownFlags)
- m.Save(28, &x.sackPermitted)
- m.Save(29, &x.sack)
- m.Save(30, &x.bindToDevice)
- m.Save(31, &x.delay)
- m.Save(32, &x.cork)
- m.Save(33, &x.scoreboard)
- m.Save(34, &x.slowAck)
- m.Save(35, &x.segmentQueue)
- m.Save(36, &x.synRcvdCount)
- m.Save(37, &x.userMSS)
- m.Save(38, &x.maxSynRetries)
- m.Save(39, &x.windowClamp)
- m.Save(40, &x.sndBufSize)
- m.Save(41, &x.sndBufUsed)
- m.Save(42, &x.sndClosed)
- m.Save(43, &x.sndBufInQueue)
- m.Save(44, &x.sndQueue)
- m.Save(45, &x.cc)
- m.Save(46, &x.packetTooBigCount)
- m.Save(47, &x.sndMTU)
- m.Save(48, &x.keepalive)
- m.Save(49, &x.userTimeout)
- m.Save(50, &x.deferAccept)
- m.Save(52, &x.rcv)
- m.Save(53, &x.snd)
- m.Save(54, &x.connectingAddress)
- m.Save(55, &x.amss)
- m.Save(56, &x.sendTOS)
- m.Save(57, &x.gso)
- m.Save(58, &x.tcpLingerTimeout)
- m.Save(59, &x.closed)
- m.Save(60, &x.txHash)
- m.Save(61, &x.owner)
- m.Save(62, &x.linger)
+ m.Save(9, &x.rcvMemUsed)
+ m.Save(10, &x.ownedByUser)
+ m.Save(12, &x.boundNICID)
+ m.Save(13, &x.ttl)
+ m.Save(14, &x.v6only)
+ m.Save(15, &x.isConnectNotified)
+ m.Save(16, &x.broadcast)
+ m.Save(17, &x.portFlags)
+ m.Save(18, &x.boundBindToDevice)
+ m.Save(19, &x.boundPortFlags)
+ m.Save(20, &x.boundDest)
+ m.Save(21, &x.effectiveNetProtos)
+ m.Save(22, &x.workerRunning)
+ m.Save(23, &x.workerCleanup)
+ m.Save(24, &x.sendTSOk)
+ m.Save(25, &x.recentTS)
+ m.Save(27, &x.tsOffset)
+ m.Save(28, &x.shutdownFlags)
+ m.Save(29, &x.sackPermitted)
+ m.Save(30, &x.sack)
+ m.Save(31, &x.bindToDevice)
+ m.Save(32, &x.delay)
+ m.Save(33, &x.cork)
+ m.Save(34, &x.scoreboard)
+ m.Save(35, &x.slowAck)
+ m.Save(36, &x.segmentQueue)
+ m.Save(37, &x.synRcvdCount)
+ m.Save(38, &x.userMSS)
+ m.Save(39, &x.maxSynRetries)
+ m.Save(40, &x.windowClamp)
+ m.Save(41, &x.sndBufSize)
+ m.Save(42, &x.sndBufUsed)
+ m.Save(43, &x.sndClosed)
+ m.Save(44, &x.sndBufInQueue)
+ m.Save(45, &x.sndQueue)
+ m.Save(46, &x.cc)
+ m.Save(47, &x.packetTooBigCount)
+ m.Save(48, &x.sndMTU)
+ m.Save(49, &x.keepalive)
+ m.Save(50, &x.userTimeout)
+ m.Save(51, &x.deferAccept)
+ m.Save(53, &x.rcv)
+ m.Save(54, &x.snd)
+ m.Save(55, &x.connectingAddress)
+ m.Save(56, &x.amss)
+ m.Save(57, &x.sendTOS)
+ m.Save(58, &x.gso)
+ m.Save(59, &x.tcpLingerTimeout)
+ m.Save(60, &x.closed)
+ m.Save(61, &x.txHash)
+ m.Save(62, &x.owner)
+ m.Save(63, &x.linger)
}
func (x *endpoint) StateLoad(m state.Source) {
@@ -306,61 +308,62 @@ func (x *endpoint) StateLoad(m state.Source) {
m.Load(6, &x.rcvBufSize)
m.Load(7, &x.rcvBufUsed)
m.Load(8, &x.rcvAutoParams)
- m.Load(9, &x.ownedByUser)
- m.Load(11, &x.boundNICID)
- m.Load(12, &x.ttl)
- m.Load(13, &x.v6only)
- m.Load(14, &x.isConnectNotified)
- m.Load(15, &x.broadcast)
- m.Load(16, &x.portFlags)
- m.Load(17, &x.boundBindToDevice)
- m.Load(18, &x.boundPortFlags)
- m.Load(19, &x.boundDest)
- m.Load(20, &x.effectiveNetProtos)
- m.Load(21, &x.workerRunning)
- m.Load(22, &x.workerCleanup)
- m.Load(23, &x.sendTSOk)
- m.Load(24, &x.recentTS)
- m.Load(26, &x.tsOffset)
- m.Load(27, &x.shutdownFlags)
- m.Load(28, &x.sackPermitted)
- m.Load(29, &x.sack)
- m.Load(30, &x.bindToDevice)
- m.Load(31, &x.delay)
- m.Load(32, &x.cork)
- m.Load(33, &x.scoreboard)
- m.Load(34, &x.slowAck)
- m.LoadWait(35, &x.segmentQueue)
- m.Load(36, &x.synRcvdCount)
- m.Load(37, &x.userMSS)
- m.Load(38, &x.maxSynRetries)
- m.Load(39, &x.windowClamp)
- m.Load(40, &x.sndBufSize)
- m.Load(41, &x.sndBufUsed)
- m.Load(42, &x.sndClosed)
- m.Load(43, &x.sndBufInQueue)
- m.LoadWait(44, &x.sndQueue)
- m.Load(45, &x.cc)
- m.Load(46, &x.packetTooBigCount)
- m.Load(47, &x.sndMTU)
- m.Load(48, &x.keepalive)
- m.Load(49, &x.userTimeout)
- m.Load(50, &x.deferAccept)
- m.LoadWait(52, &x.rcv)
- m.LoadWait(53, &x.snd)
- m.Load(54, &x.connectingAddress)
- m.Load(55, &x.amss)
- m.Load(56, &x.sendTOS)
- m.Load(57, &x.gso)
- m.Load(58, &x.tcpLingerTimeout)
- m.Load(59, &x.closed)
- m.Load(60, &x.txHash)
- m.Load(61, &x.owner)
- m.Load(62, &x.linger)
+ m.Load(9, &x.rcvMemUsed)
+ m.Load(10, &x.ownedByUser)
+ m.Load(12, &x.boundNICID)
+ m.Load(13, &x.ttl)
+ m.Load(14, &x.v6only)
+ m.Load(15, &x.isConnectNotified)
+ m.Load(16, &x.broadcast)
+ m.Load(17, &x.portFlags)
+ m.Load(18, &x.boundBindToDevice)
+ m.Load(19, &x.boundPortFlags)
+ m.Load(20, &x.boundDest)
+ m.Load(21, &x.effectiveNetProtos)
+ m.Load(22, &x.workerRunning)
+ m.Load(23, &x.workerCleanup)
+ m.Load(24, &x.sendTSOk)
+ m.Load(25, &x.recentTS)
+ m.Load(27, &x.tsOffset)
+ m.Load(28, &x.shutdownFlags)
+ m.Load(29, &x.sackPermitted)
+ m.Load(30, &x.sack)
+ m.Load(31, &x.bindToDevice)
+ m.Load(32, &x.delay)
+ m.Load(33, &x.cork)
+ m.Load(34, &x.scoreboard)
+ m.Load(35, &x.slowAck)
+ m.LoadWait(36, &x.segmentQueue)
+ m.Load(37, &x.synRcvdCount)
+ m.Load(38, &x.userMSS)
+ m.Load(39, &x.maxSynRetries)
+ m.Load(40, &x.windowClamp)
+ m.Load(41, &x.sndBufSize)
+ m.Load(42, &x.sndBufUsed)
+ m.Load(43, &x.sndClosed)
+ m.Load(44, &x.sndBufInQueue)
+ m.LoadWait(45, &x.sndQueue)
+ m.Load(46, &x.cc)
+ m.Load(47, &x.packetTooBigCount)
+ m.Load(48, &x.sndMTU)
+ m.Load(49, &x.keepalive)
+ m.Load(50, &x.userTimeout)
+ m.Load(51, &x.deferAccept)
+ m.LoadWait(53, &x.rcv)
+ m.LoadWait(54, &x.snd)
+ m.Load(55, &x.connectingAddress)
+ m.Load(56, &x.amss)
+ m.Load(57, &x.sendTOS)
+ m.Load(58, &x.gso)
+ m.Load(59, &x.tcpLingerTimeout)
+ m.Load(60, &x.closed)
+ m.Load(61, &x.txHash)
+ m.Load(62, &x.owner)
+ m.Load(63, &x.linger)
m.LoadValue(3, new(string), func(y interface{}) { x.loadLastError(y.(string)) })
- m.LoadValue(10, new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) })
- m.LoadValue(25, new(unixTime), func(y interface{}) { x.loadRecentTSTime(y.(unixTime)) })
- m.LoadValue(51, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
+ m.LoadValue(11, new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) })
+ m.LoadValue(26, new(unixTime), func(y interface{}) { x.loadRecentTSTime(y.(unixTime)) })
+ m.LoadValue(52, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
m.AfterLoad(x.afterLoad)
}
@@ -446,7 +449,6 @@ func (x *receiver) StateFields() []string {
"closed",
"pendingRcvdSegments",
"pendingBufUsed",
- "pendingBufSize",
"lastRcvdAckTime",
}
}
@@ -456,7 +458,7 @@ func (x *receiver) beforeSave() {}
func (x *receiver) StateSave(m state.Sink) {
x.beforeSave()
var lastRcvdAckTime unixTime = x.saveLastRcvdAckTime()
- m.SaveValue(9, lastRcvdAckTime)
+ m.SaveValue(8, lastRcvdAckTime)
m.Save(0, &x.ep)
m.Save(1, &x.rcvNxt)
m.Save(2, &x.rcvAcc)
@@ -465,7 +467,6 @@ func (x *receiver) StateSave(m state.Sink) {
m.Save(5, &x.closed)
m.Save(6, &x.pendingRcvdSegments)
m.Save(7, &x.pendingBufUsed)
- m.Save(8, &x.pendingBufSize)
}
func (x *receiver) afterLoad() {}
@@ -479,8 +480,7 @@ func (x *receiver) StateLoad(m state.Source) {
m.Load(5, &x.closed)
m.Load(6, &x.pendingRcvdSegments)
m.Load(7, &x.pendingBufUsed)
- m.Load(8, &x.pendingBufSize)
- m.LoadValue(9, new(unixTime), func(y interface{}) { x.loadLastRcvdAckTime(y.(unixTime)) })
+ m.LoadValue(8, new(unixTime), func(y interface{}) { x.loadLastRcvdAckTime(y.(unixTime)) })
}
func (x *renoState) StateTypeName() string {
@@ -540,6 +540,8 @@ func (x *segment) StateFields() []string {
return []string{
"segmentEntry",
"refCnt",
+ "ep",
+ "qFlags",
"data",
"hdr",
"viewToDeliver",
@@ -563,26 +565,28 @@ func (x *segment) beforeSave() {}
func (x *segment) StateSave(m state.Sink) {
x.beforeSave()
var data buffer.VectorisedView = x.saveData()
- m.SaveValue(2, data)
+ m.SaveValue(4, data)
var options []byte = x.saveOptions()
- m.SaveValue(12, options)
+ m.SaveValue(14, options)
var rcvdTime unixTime = x.saveRcvdTime()
- m.SaveValue(14, rcvdTime)
+ m.SaveValue(16, rcvdTime)
var xmitTime unixTime = x.saveXmitTime()
- m.SaveValue(15, xmitTime)
+ m.SaveValue(17, xmitTime)
m.Save(0, &x.segmentEntry)
m.Save(1, &x.refCnt)
- m.Save(3, &x.hdr)
- m.Save(4, &x.viewToDeliver)
- m.Save(5, &x.sequenceNumber)
- m.Save(6, &x.ackNumber)
- m.Save(7, &x.flags)
- m.Save(8, &x.window)
- m.Save(9, &x.csum)
- m.Save(10, &x.csumValid)
- m.Save(11, &x.parsedOptions)
- m.Save(13, &x.hasNewSACKInfo)
- m.Save(16, &x.xmitCount)
+ m.Save(2, &x.ep)
+ m.Save(3, &x.qFlags)
+ m.Save(5, &x.hdr)
+ m.Save(6, &x.viewToDeliver)
+ m.Save(7, &x.sequenceNumber)
+ m.Save(8, &x.ackNumber)
+ m.Save(9, &x.flags)
+ m.Save(10, &x.window)
+ m.Save(11, &x.csum)
+ m.Save(12, &x.csumValid)
+ m.Save(13, &x.parsedOptions)
+ m.Save(15, &x.hasNewSACKInfo)
+ m.Save(18, &x.xmitCount)
}
func (x *segment) afterLoad() {}
@@ -590,21 +594,23 @@ func (x *segment) afterLoad() {}
func (x *segment) StateLoad(m state.Source) {
m.Load(0, &x.segmentEntry)
m.Load(1, &x.refCnt)
- m.Load(3, &x.hdr)
- m.Load(4, &x.viewToDeliver)
- m.Load(5, &x.sequenceNumber)
- m.Load(6, &x.ackNumber)
- m.Load(7, &x.flags)
- m.Load(8, &x.window)
- m.Load(9, &x.csum)
- m.Load(10, &x.csumValid)
- m.Load(11, &x.parsedOptions)
- m.Load(13, &x.hasNewSACKInfo)
- m.Load(16, &x.xmitCount)
- m.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
- m.LoadValue(12, new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) })
- m.LoadValue(14, new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) })
- m.LoadValue(15, new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) })
+ m.Load(2, &x.ep)
+ m.Load(3, &x.qFlags)
+ m.Load(5, &x.hdr)
+ m.Load(6, &x.viewToDeliver)
+ m.Load(7, &x.sequenceNumber)
+ m.Load(8, &x.ackNumber)
+ m.Load(9, &x.flags)
+ m.Load(10, &x.window)
+ m.Load(11, &x.csum)
+ m.Load(12, &x.csumValid)
+ m.Load(13, &x.parsedOptions)
+ m.Load(15, &x.hasNewSACKInfo)
+ m.Load(18, &x.xmitCount)
+ m.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+ m.LoadValue(14, new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) })
+ m.LoadValue(16, new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) })
+ m.LoadValue(17, new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) })
}
func (x *segmentQueue) StateTypeName() string {
@@ -614,8 +620,8 @@ func (x *segmentQueue) StateTypeName() string {
func (x *segmentQueue) StateFields() []string {
return []string{
"list",
- "limit",
- "used",
+ "ep",
+ "frozen",
}
}
@@ -624,16 +630,16 @@ func (x *segmentQueue) beforeSave() {}
func (x *segmentQueue) StateSave(m state.Sink) {
x.beforeSave()
m.Save(0, &x.list)
- m.Save(1, &x.limit)
- m.Save(2, &x.used)
+ m.Save(1, &x.ep)
+ m.Save(2, &x.frozen)
}
func (x *segmentQueue) afterLoad() {}
func (x *segmentQueue) StateLoad(m state.Source) {
m.LoadWait(0, &x.list)
- m.Load(1, &x.limit)
- m.Load(2, &x.used)
+ m.Load(1, &x.ep)
+ m.Load(2, &x.frozen)
}
func (x *sender) StateTypeName() string {