From 5d50c91c4da820220adcfe9ce0741ed1e5e9f4b7 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 24 Sep 2020 07:13:23 -0700 Subject: Change segment/pending queue to use receive buffer limits. segment_queue today has its own standalone limit of MaxUnprocessedSegments but this can be a problem in UnlockUser() we do not release the lock till there are segments to be processed. What can happen is as handleSegments dequeues packets more keep getting queued and we will never release the lock. This can keep happening even if the receive buffer is full because nothing can read() till we release the lock. Further having a separate limit for pending segments makes it harder to track memory usage etc. Unifying the limits makes it easier to reason about memory in use and makes the overall buffer behaviour more consistent. PiperOrigin-RevId: 333508122 --- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/connect.go | 19 +- pkg/tcpip/transport/tcp/dual_stack_test.go | 16 +- pkg/tcpip/transport/tcp/endpoint.go | 165 ++-- pkg/tcpip/transport/tcp/endpoint_state.go | 4 +- pkg/tcpip/transport/tcp/rcv.go | 50 +- pkg/tcpip/transport/tcp/segment.go | 45 ++ pkg/tcpip/transport/tcp/segment_queue.go | 52 +- pkg/tcpip/transport/tcp/tcp_test.go | 861 +++++++++++---------- pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 19 +- pkg/tcpip/transport/tcp/testing/context/context.go | 85 +- 11 files changed, 776 insertions(+), 541 deletions(-) (limited to 'pkg/tcpip/transport/tcp') diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4778e7b1c..518449602 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -94,6 +94,7 @@ go_test( shard_count = 10, deps = [ ":tcp", + "//pkg/rand", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 881752371..6891fd245 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -898,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -1003,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. @@ -1136,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { } cont, err := e.handleSegment(s) + s.decRef() if err != nil { - s.decRef() return err } if !cont { - s.decRef() return nil } } @@ -1221,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } + // Increase counter if after processing the segment we would potentially + // advertise a zero window. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. @@ -1233,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // or a notification from the protocolMainLoop (caller goroutine). // This means that with this return, the segment dequeue below can // never occur on a closed endpoint. - s.decRef() return false, nil } @@ -1425,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.rcv.nonZeroWindow() } - if n¬ifyReceiveWindowChanged != 0 { - e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) - } - if n¬ifyMTUChanged != 0 { e.sndBufMu.Lock() count := e.packetTooBigCount diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 94207c141..560b4904c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -78,8 +78,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv4(t, c.GetPacket(), ackCheckers...) @@ -185,8 +185,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv6(t, c.GetV6Packet(), ackCheckers...) @@ -283,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -319,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -352,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -492,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 120483838..87db13720 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,6 +63,17 @@ const ( StateClosing ) +const ( + // rcvAdvWndScale is used to split the available socket buffer into + // application buffer and the window to be advertised to the peer. This is + // currently hard coded to split the available space equally. + rcvAdvWndScale = 1 + + // SegOverheadFactor is used to multiply the value provided by the + // user on a SetSockOpt for setting the socket send/receive buffer sizes. + SegOverheadFactor = 2 +) + // connected returns true when s is one of the states representing an // endpoint connected to a peer. func (s EndpointState) connected() bool { @@ -149,7 +160,6 @@ func (s EndpointState) String() string { // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota - notifyReceiveWindowChanged notifyClose notifyMTUChanged notifyDrain @@ -384,13 +394,26 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + // rcvBufSize is the total size of the receive buffer. + rcvBufSize int + // rcvBufUsed is the actual number of payload bytes held in the receive buffer + // not counting any overheads of the segments itself. NOTE: This will always + // be strictly <= rcvMemUsed below. rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams + // rcvMemUsed tracks the total amount of memory in use by received segments + // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // compute the window and the actual available buffer space. This is distinct + // from rcvBufUsed above which is the actual number of payload bytes held in + // the buffer not including any segment overheads. + // + // rcvMemUsed must be accessed atomically. + rcvMemUsed int32 + // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. mu sync.Mutex `state:"nosave"` @@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.probe = p } - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) @@ -1129,10 +1152,16 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// wndFromSpace returns the window that we can advertise based on the available +// receive buffer space. +func wndFromSpace(space int) int { + return space / (1 << rcvAdvWndScale) +} + // initialReceiveWindow returns the initial receive window to advertise in the // SYN/SYN-ACK. func (e *endpoint) initialReceiveWindow() int { - rcvWnd := e.receiveBufferAvailable() + rcvWnd := wndFromSpace(e.receiveBufferAvailable()) if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } @@ -1209,14 +1238,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = rcvWnd - availAfter := e.receiveBufferAvailableLocked() - mask := uint32(notifyReceiveWindowChanged) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -1293,18 +1320,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { v := views[s.viewToDeliver] s.viewToDeliver++ + var delta int if s.viewToDeliver >= len(views) { e.rcvList.Remove(s) + // We only free up receive buffer space when the segment is released as the + // segment is still holding on to the views even though some views have been + // read out to the user. + delta = s.segMemSize() s.decRef() } e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1481,11 +1512,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } // windowCrossedACKThresholdLocked checks if the receive window to be announced -// now would be under aMSS or under half receive buffer, whichever smaller. This -// is useful as a receive side silly window syndrome prevention mechanism. If -// window grows to reasonable value, we should send ACK to the sender to inform -// the rx space is now large. We also want ensure a series of small read()'s -// won't trigger a flood of spurious tiny ACK's. +// would be under aMSS or under the window derived from half receive buffer, +// whichever smaller. This is useful as a receive side silly window syndrome +// prevention mechanism. If window grows to reasonable value, we should send ACK +// to the sender to inform the rx space is now large. We also want ensure a +// series of small read()'s won't trigger a flood of spurious tiny ACK's. // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of @@ -1496,17 +1527,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := e.receiveBufferAvailableLocked() + newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 } - threshold := int(e.amss) - if threshold > e.rcvBufSize/2 { - threshold = e.rcvBufSize / 2 + // rcvBufFraction is the inverse of the fraction of receive buffer size that + // is used to decide if the available buffer space is now above it. + const rcvBufFraction = 2 + if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + threshold = wndThreshold } - switch { case oldAvail < threshold && newAvail >= threshold: return true, true @@ -1636,17 +1668,23 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Make sure the receive buffer size is within the min and max // allowed. var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) + } + + if v > rs.Max { + v = rs.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < rs.Min { v = rs.Min } - if v > rs.Max { - v = rs.Max - } + } else { + v = math.MaxInt32 } - mask := uint32(notifyReceiveWindowChanged) - e.LockUser() e.rcvListMu.Lock() @@ -1660,14 +1698,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { v = 1 << scale } - // Make sure 2*size doesn't overflow. - if v > math.MaxInt32/2 { - v = math.MaxInt32 / 2 - } - - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = v - availAfter := e.receiveBufferAvailableLocked() + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvAutoParams.disabled = true @@ -1675,24 +1708,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } e.rcvListMu.Unlock() e.UnlockUser() - e.notifyProtocolGoroutine(mask) case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) + } + + if v > ss.Max { + v = ss.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < ss.Min { v = ss.Min } - if v > ss.Max { - v = ss.Max - } + } else { + v = math.MaxInt32 } e.sndBufMu.Lock() @@ -2699,13 +2739,8 @@ func (e *endpoint) updateSndBufferUsage(v int) { func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { + e.rcvBufUsed += s.payloadSize() s.incRef() - e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down below MSS - // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -2720,15 +2755,17 @@ func (e *endpoint) readyToRead(s *segment) { func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. - if e.rcvBufUsed >= e.rcvBufSize { + memUsed := e.receiveMemUsed() + if memUsed >= e.rcvBufSize { return 0 } - return e.rcvBufSize - e.rcvBufUsed + return e.rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. +// receive buffer based on the actual memory used by all segments held in +// receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvListMu.Lock() available := e.receiveBufferAvailableLocked() @@ -2736,14 +2773,35 @@ func (e *endpoint) receiveBufferAvailable() int { return available } +// receiveBufferUsed returns the amount of in-use receive buffer. +func (e *endpoint) receiveBufferUsed() int { + e.rcvListMu.Lock() + used := e.rcvBufUsed + e.rcvListMu.Unlock() + return used +} + +// receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { e.rcvListMu.Lock() size := e.rcvBufSize e.rcvListMu.Unlock() - return size } +// receiveMemUsed returns the total memory in use by segments held by this +// endpoint. +func (e *endpoint) receiveMemUsed() int { + return int(atomic.LoadInt32(&e.rcvMemUsed)) +} + +// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed. +func (e *endpoint) updateReceiveMemUsed(delta int) { + atomic.AddInt32(&e.rcvMemUsed, int32(delta)) +} + +// maxReceiveBufferSize returns the stack wide maximum receive buffer size for +// an endpoint. func (e *endpoint) maxReceiveBufferSize() int { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { @@ -2894,7 +2952,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RcvAcc: e.rcv.rcvAcc, RcvWndScale: e.rcv.rcvWndScale, PendingBufUsed: e.rcv.pendingBufUsed, - PendingBufSize: e.rcv.pendingBufSize, } // Copy sender state. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 41d0050f3..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() { // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. - e.segmentQueue.setLimit(0) + e.segmentQueue.freeze() e.mu.Lock() defer e.mu.Unlock() @@ -178,7 +178,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index cfd43b5e3..4aafb4d22 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -47,22 +47,24 @@ type receiver struct { closed bool + // pendingRcvdSegments is bounded by the receive buffer size of the + // endpoint. pendingRcvdSegments segmentHeap - pendingBufUsed seqnum.Size - pendingBufSize seqnum.Size + // pendingBufUsed tracks the total number of bytes (including segment + // overhead) currently queued in pendingRcvdSegments. + pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, lastRcvdAckTime: time.Now(), } } @@ -85,15 +87,23 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the available buffer space. - receiveBufferAvailable := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) - if r.rcvAcc.LessThan(acc) { - r.rcvAcc = acc + avail := wndFromSpace(r.ep.receiveBufferAvailable()) + acc := r.rcvNxt.Add(seqnum.Size(avail)) + newWnd := r.rcvNxt.Size(acc) + curWnd := r.rcvNxt.Size(r.rcvAcc) + + // Update rcvAcc only if new window is > previously advertised window. We + // should never shrink the acceptable sequence space once it has been + // advertised the peer. If we shrink the acceptable sequence space then we + // would end up dropping bytes that might already be in flight. + if newWnd > curWnd { + r.rcvAcc = r.rcvNxt.Add(newWnd) + } else { + newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. - r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + r.rcvWnd = newWnd return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } @@ -195,7 +205,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of // truncated items, thus truncated items must be set to nil to avoid // memory leaks. @@ -384,10 +396,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { - // We only store the segment if it's within our buffer - // size limit. - if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += seqnum.Size(s.segMemSize()) + // We only store the segment if it's within our buffer size limit. + // + // Only use 75% of the receive buffer queue for out-of-order + // segments. This ensures that we always leave some space for the inorder + // segments to arrive allowing pending segments to be processed and + // delivered to the user. + if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvListMu.Lock() + r.pendingBufUsed += s.segMemSize() + r.ep.rcvListMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -421,7 +439,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= seqnum.Size(s.segMemSize()) + r.ep.rcvListMu.Lock() + r.pendingBufUsed -= s.segMemSize() + r.ep.rcvListMu.Unlock() s.decRef() } return false, nil diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 94307d31a..13acaf753 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "sync/atomic" "time" @@ -24,6 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// queueFlags are used to indicate which queue of an endpoint a particular segment +// belongs to. This is used to track memory accounting correctly. +type queueFlags uint8 + +const ( + recvQ queueFlags = 1 << iota + sendQ +) + // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. // segment is mostly immutable, the only field allowed to change is viewToDeliver. @@ -32,6 +42,8 @@ import ( type segment struct { segmentEntry refCnt int32 + ep *endpoint + qFlags queueFlags id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -100,6 +112,8 @@ func (s *segment) clone() *segment { rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, + ep: s.ep, + qFlags: s.qFlags, } t.data = s.data.Clone(t.views[:]) return t @@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool { return s.flags&flags == flags } +// setOwner sets the owning endpoint for this segment. Its required +// to be called to ensure memory accounting for receive/send buffer +// queues is done properly. +func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) { + switch qFlags { + case recvQ: + ep.updateReceiveMemUsed(s.segMemSize()) + case sendQ: + // no memory account for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b", qFlags)) + } + s.ep = ep + s.qFlags = qFlags +} + func (s *segment) decRef() { if atomic.AddInt32(&s.refCnt, -1) == 0 { + if s.ep != nil { + switch s.qFlags { + case recvQ: + s.ep.updateReceiveMemUsed(-s.segMemSize()) + case sendQ: + // no memory accounting for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) + } + } s.route.Release() } } @@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// payloadSize is the size of s.data. +func (s *segment) payloadSize() int { + return s.data.Size() +} + // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 48a257137..54545a1b1 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -22,16 +22,16 @@ import ( // // +stateify savable type segmentQueue struct { - mu sync.Mutex `state:"nosave"` - list segmentList `state:"wait"` - limit int - used int + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + ep *endpoint + frozen bool } // emptyLocked determines if the queue is empty. // Preconditions: q.mu must be held. func (q *segmentQueue) emptyLocked() bool { - return q.used == 0 + return q.list.Empty() } // empty determines if the queue is empty. @@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool { return r } -// setLimit updates the limit. No segments are immediately dropped in case the -// queue becomes full due to the new limit. -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} - // enqueue adds the given segment to the queue. // // Returns true when the segment is successfully added to the queue, in which @@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) { // false if the queue is full, in which case ownership is retained by the // caller. func (q *segmentQueue) enqueue(s *segment) bool { + // q.ep.receiveBufferParams() must be called without holding q.mu to + // avoid lock order inversion. + bufSz := q.ep.receiveBufferSize() + used := q.ep.receiveMemUsed() q.mu.Lock() - r := q.used < q.limit - if r { + // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue + // is currently full). + allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + + if allow { q.list.PushBack(s) - q.used++ + // Set the owner now that the endpoint owns the segment. + s.setOwner(q.ep, recvQ) } q.mu.Unlock() - return r + return allow } // dequeue removes and returns the next segment from queue, if one exists. @@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment { s := q.list.Front() if s != nil { q.list.Remove(s) - q.used-- } q.mu.Unlock() return s } + +// freeze prevents any more segments from being added to the queue. i.e all +// future segmentQueue.enqueue will return false and not add the segment to the +// queue till the queue is unfroze with a corresponding segmentQueue.thaw call. +func (q *segmentQueue) freeze() { + q.mu.Lock() + q.frozen = true + q.mu.Unlock() +} + +// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and +// allows new segments to be queued again. +func (q *segmentQueue) thaw() { + q.mu.Lock() + q.frozen = false + q.mu.Unlock() +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 8b2217a98..8326736dc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -349,8 +350,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -380,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -479,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -521,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence number // of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -561,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -597,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -644,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -665,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -725,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -777,8 +778,8 @@ func TestSimpleReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1030,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxSynAckV6(t *testing.T) { @@ -1058,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, @@ -1095,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, @@ -1133,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestSendRstOnListenerRxAckV4(t *testing.T) { @@ -1162,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxAckV6(t *testing.T) { @@ -1190,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestListenShutdown tests for the listening endpoint replying with RST @@ -1306,8 +1307,8 @@ func TestTOSV4(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1355,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1546,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1597,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1608,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) + rcvBufSz := math.MaxUint16 + c.CreateConnected(789, 30000, rcvBufSz) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) @@ -1630,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1651,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1671,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(793), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1713,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1728,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1782,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1796,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { @@ -1815,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence // number of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), + checker.TCPSeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1861,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, 10) + const rcvBufSz = 10 + c.CreateConnected(789, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1872,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("Read failed: %s", err) } - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies + // the provided buffer value by tcp.SegOverheadFactor to calculate the actual + // receive buffer size. + data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) + for i := range data { + data[i] = byte(i % 255) + } c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -1894,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) @@ -1920,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), + checker.TCPWindow(10), ), ) } @@ -1932,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + // Start off with a certain receive buffer then cut it in half and verify that + // the right edge of the window does not shrink. + // NOTE: Netstack doubles the value specified here. + rcvBufSize := 65536 + iss := seqnum.Value(789) + // Enable window scaling with a scale of zero from our end. + c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1946,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ + // Send a 1 byte payload so that we can record the current receive window. + // Send a payload of half the size of rcvBufSize. + seqNum := iss.Add(1) + payload := []byte{1} + c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1965,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("Timed out waiting for data to arrive") } - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), + // Read the 1 byte payload we just sent. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if got, want := payload, v; !bytes.Equal(got, want) { + t.Fatalf("got data: %v, want: %v", got, want) + } + + seqNum = seqNum.Add(1) + // Verify that the ACK does not shrink the window. + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), ), ) + // Stash the initial window. + initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + // Now shrink the receive buffer to half its original size. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) + } - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ + data := generateRandomPayload(t, rcvBufSize) + // Send a payload of half the size of rcvBufSize. + c.SendPacket(data[:rcvBufSize/2], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") + // Verify that the ACK does not shrink the window. + pkt = c.GetPacket() + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { + t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } + // Send another payload of half the size of rcvBufSize. This should fill up the + // socket receive buffer and we should see a zero window. + c.SendPacket(data[rcvBufSize/2:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + // Receive data and check it. - read := make([]byte, 0, 10) + read := make([]byte, 0, rcvBufSize) for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { @@ -2018,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data = %v, want = %v", read, data) } - // Check that we get an ACK for the newly non-zero window, which is the - // new size. + // Check that we get an ACK for the newly non-zero window, which is the new + // receive buffer size we set after the connection was established. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), + checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), ) } @@ -2051,8 +2109,8 @@ func TestSimpleSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2093,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2115,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2155,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2194,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2230,7 +2288,8 @@ func TestScaledWindowAccept(t *testing.T) { } // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -2260,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2341,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2356,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ + // Set the buffer size such that a window scale of 5 will be used. + const bufSz = 65535 * 10 + const ws = uint32(5) + c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) // Write chunks of 50000 bytes. - remain := wnd + remain := 0 sent := 0 data := make([]byte, 50000) - for remain > len(data) { + // Keep writing till the window drops below len(data). + for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -2377,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Don't reduce window to zero here. + if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<= 16 { + for remain >= 16 { data = data[:remain-15] c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, @@ -2402,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Since the receive buffer is split between window advertisement and + // application data buffer the window does not always reflect the space + // available and actual space available can be a bit more than what is + // advertised in the window. + wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) + if wnd == 0 { + break + } + remain = wnd << ws } - // Read at least 1MSS of data. An ack should be sent in response to that. + // Read at least 2MSS of data. An ack should be sent in response to that. + // Since buffer space is now split in half between window and application + // data we need to read more than 1 MSS(65536) of data for a non-zero window + // update to be sent. For 1MSS worth of window to be available we need to + // read at least 128KB. Since our segments above were 50KB each it means + // we need to read at 3 packets. sz := 0 - for sz < defaultMTU { + for sz < defaultMTU*2 { v, _, err := c.EP.Read(nil) if err != nil { t.Fatalf("Read failed: %s", err) @@ -2429,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(sz>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2498,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize+1), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+uint32(i)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2521,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+11), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+11), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2569,8 +2646,8 @@ func TestDelay(t *testing.T) { checker.PayloadLen(len(want)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2616,8 +2693,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2639,8 +2716,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2701,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2753,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2996,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - const wndScale = 2 + const wndScale = 3 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } @@ -3031,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), + checker.TCPSeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -3052,8 +3129,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -3346,8 +3423,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3367,8 +3444,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3389,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3400,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3421,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3445,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3470,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3492,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3520,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3539,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3559,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3580,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3604,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3629,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3645,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3666,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3691,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3712,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3727,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3743,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3835,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.PayloadLen((1<> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 + payloadSize := receiveBufferSize * 2 + b := make([]byte, int(payloadSize)) + worker := (c.EP).(interface { StopWork() ResumeWork() @@ -6012,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Stop the worker goroutine. worker.StopWork() - start := offset - end := offset + payloadSize + start := 0 + end := payloadSize / 2 packetsSent := 0 for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + packetEnd := start + mss + if start+mss > end { + packetEnd = end + } + rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) packetsSent++ } @@ -6024,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // are waiting to be read. worker.ResumeWork() - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } + // Since we sent almost the full receive buffer worth of data (some may have + // been dropped due to segment overheads), we should get a zero window back. + pkt = c.GetPacket() + tcpHdr := header.TCP(header.IPv4(pkt).Payload()) + gotRcvWnd := tcpHdr.WindowSize() + wantAckNum := tcpHdr.AckNumber() + if got, want := int(gotRcvWnd), 0; got != want { + t.Fatalf("got rcvWnd: %d, want: %d", got, want) } - rawEP.VerifyACKRcvWnd(0) time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. + // Verify that sending more data when receiveBuffer is exhausted. rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { @@ -6059,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Verify that we receive a non-zero window update ACK. When running // under thread santizer this test can end up sending more than 1 // ack, 1 for the non-zero window - p = c.GetPacket() + p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), + checker.TCPAckNum(uint32(wantAckNum)), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { return } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) + // We use 10% here as the error margin upwards as the initial window we + // got was afer 1 segment was already in the receive buffer queue. + tolerance := 1.1 + if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { + t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) } }, )) } -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. +// This test verifies that the advertised window is auto-tuned up as the +// application is reading the data that is being received. func TestReceiveBufferAutoTuning(t *testing.T) { const mtu = 1500 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize @@ -6085,9 +6160,6 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 { @@ -6109,8 +6181,10 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize + tsVal := uint32(rawEP.TSVal) + rawEP.SendPacketWithTS([]byte{1}, tsVal) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { return uint16(rcvWnd >> uint16(c.WindowScale)) } @@ -6127,14 +6201,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) { StopWork() ResumeWork() }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false latency := 1 * time.Millisecond - for i := 0; !done; i++ { + for i := 0; i < 5; i++ { tsVal++ // Stop the worker goroutine. @@ -6156,15 +6224,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Give 1ms for the worker to process the packets. time.Sleep(1 * time.Millisecond) - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break } + lastACK = pkt + } + if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { + t.Fatalf("advertised window got: %d, want <= %d", got, want) } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) // Now read all the data from the endpoint and invoke the // moderation API to allow for receive buffer auto-tuning @@ -6189,35 +6262,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ - if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied + rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true + pkt := c.GetPacket() + curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // If thew new current window is close maxReceiveBufferSize then terminate + // the loop. This can happen before all iterations are done due to timing + // differences when running the test. + if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { + break } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied // Increase the latency after first two iterations to // establish a low RTT value in the receiver since it // only tracks the lowest value. This ensures that when @@ -6230,6 +6288,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) { offset += payloadSize payloadSize *= 2 } + // Check that at the end of our iterations the receive window grew close to the maximum + // permissible size of maxReceiveBufferSize/2 + if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { + t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) + } + } func TestDelayEnabled(t *testing.T) { @@ -6381,8 +6445,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6399,8 +6463,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now send a RST and this should be ignored and not @@ -6428,8 +6492,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6500,8 +6564,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6518,8 +6582,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Out of order ACK should generate an immediate ACK in @@ -6535,8 +6599,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6607,8 +6671,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6625,8 +6689,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Send a SYN request w/ sequence number lower than @@ -6764,8 +6828,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6782,8 +6846,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) time.Sleep(2 * time.Second) @@ -6797,8 +6861,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Sleep for 4 seconds so at this point we are 1 second past the @@ -6826,8 +6890,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { @@ -6926,8 +6990,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now write a few bytes and then close the endpoint. @@ -6945,8 +7009,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -6960,8 +7024,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), + checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.TCPAckNum(uint32(iss+2)), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) // First send a partial ACK. @@ -7006,8 +7070,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -7043,8 +7107,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -7078,8 +7142,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7140,8 +7204,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7166,8 +7230,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7183,9 +7247,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } } -func TestIncreaseWindowOnReceive(t *testing.T) { +func TestIncreaseWindowOnRead(t *testing.T) { // This test ensures that the endpoint sends an ack, - // after recv() when the window grows to more than 1 MSS. + // after read() when the window grows by more than 1 MSS. c := context.New(t, defaultMTU) defer c.Cleanup() @@ -7194,10 +7258,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7210,46 +7273,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) { }) sent += len(data) remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Break once the window drops below defaultMTU/2 + if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { + break + } } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // We now have < 1 MSS in the buffer space. Read the data! An - // ack should be sent in response to that. The window was not - // zero, but it grew to larger than MSS. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) - } - - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) + // We now have < 1 MSS in the buffer space. Read at least > 2 MSS + // worth of data as receive buffer space + read := 0 + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + for read < defaultMTU*2 { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + read += len(v) } - // After reading two packets, we surely crossed MSS. See the ack: + // After reading > MSS worth of data, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7266,10 +7326,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7283,38 +7342,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { sent += len(data) remain -= len(data) - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), ) } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - // After reading two packets, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7359,8 +7409,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) @@ -7374,8 +7424,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestTCPDeferAcceptTimeout(t *testing.T) { @@ -7412,7 +7462,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) // Send data. This should result in an acceptable endpoint. c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ @@ -7428,8 +7478,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) @@ -7444,8 +7494,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestResetDuringClose(t *testing.T) { @@ -7470,8 +7520,8 @@ func TestResetDuringClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(irs.Add(1))), - checker.AckNum(uint32(iss.Add(5))))) + checker.TCPSeqNum(uint32(irs.Add(1))), + checker.TCPAckNum(uint32(iss.Add(5))))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7552,3 +7602,14 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } } } + +// generateRandomPayload generates a random byte slice of the specified length +// causing a fatal test failure if it is unable to do so. +func generateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 44593ed98..0f9ed06cd 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -159,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.PayloadLen(len(data)+header.TCPMinimumSize+12), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(true, 0, tsVal+1), ), @@ -181,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) @@ -219,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(false, 0, 0), ), @@ -237,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of + // that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 059c13821..ebbae6e2f 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -145,6 +145,10 @@ type Context struct { // WindowScale is the expected window scale in SYN packets sent by // the stack. WindowScale uint8 + + // RcvdWindowScale is the actual window scale sent by the stack in + // SYN/SYN-ACK. + RcvdWindowScale uint8 } // New allocates and initializes a test context containing a new @@ -261,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) { c.CheckNoPacketTimeout(errMsg, 1*time.Second) } -// GetPacket reads a packet from the link layer endpoint and verifies +// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies // that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { +// addresses. If no packet is received in the specified timeout it will return +// nil. +func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { - c.t.Fatalf("Packet wasn't written out") return nil } @@ -299,6 +302,21 @@ func (c *Context) GetPacket() []byte { return b } +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. +func (c *Context) GetPacket() []byte { + c.t.Helper() + + p := c.GetPacketWithTimeout(5 * time.Second) + if p == nil { + c.t.Fatalf("Packet wasn't written out") + return nil + } + + return p +} + // GetPacketNonBlocking reads a packet from the link layer endpoint // and verifies that it is an IPv4 packet with the expected source // and destination address. If no packet is available it will return @@ -486,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.PayloadLen(size+header.TCPMinimumSize+optlen), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -513,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -652,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) } tcpHdr := header.TCP(header.IPv4(b).Payload()) + synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ @@ -669,8 +688,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) checker.TCP( checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -687,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } + c.RcvdWindowScale = uint8(synOpts.WS) c.Port = tcpHdr.SourcePort() } @@ -758,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) } -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { +// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches +// the provided tsVal as well as returns the original packet. +func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { + r.C.t.Helper() // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPTimestampChecker(true, 0, tsVal), ), ) @@ -776,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) opts := tcpSeg.ParsedOptions() r.RecentTS = opts.TSVal + return ackPacket +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + r.C.t.Helper() + _ = r.VerifyAndReturnACKWithTS(tsVal) } // VerifyACKRcvWnd verifies that the window advertised by the incoming ACK // matches the provided rcvWnd. func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { + r.C.t.Helper() ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), + checker.TCPWindow(rcvWnd), ), ) } @@ -807,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPSACKBlockChecker(sackBlocks), ), ) @@ -900,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * tcpCheckers := []checker.TransportChecker{ checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), + checker.TCPSeqNum(uint32(c.IRS) + 1), + checker.TCPAckNum(uint32(iss) + 1), } // Verify that tsEcr of ACK packet is wantOptions.TSVal if the @@ -936,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Mark in context that timestamp option is enabled for this endpoint. c.TimeStampEnabled = true - + c.RcvdWindowScale = uint8(synOptions.WS) return &RawEndpoint{ C: c, SrcPort: tcpSeg.DestinationPort(), @@ -1029,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 offset += header.EncodeMSSOption(uint32(maxPayload), opts) @@ -1067,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // are present. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) + rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) c.IRS = seqnum.Value(tcp.SequenceNumber()) tcpCheckers := []checker.TransportChecker{ checker.SrcPort(StackPort), checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), + checker.TCPAckNum(uint32(iss) + 1), checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), } @@ -1116,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // Send ACK. c.SendPacket(nil, ackHeaders) + c.RcvdWindowScale = uint8(rcvdSynOptions.WS) c.Port = StackPort return &RawEndpoint{ -- cgit v1.2.3