diff options
author | Nick Brown <nickbrow@google.com> | 2021-04-19 16:41:56 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-04-19 16:43:30 -0700 |
commit | 7bfc76d946b6c3f02fc32831ddc282ac2816d5ed (patch) | |
tree | 606ee32bcf8b932d9675d18187cfeced3873105a /pkg/tcpip/transport/tcp | |
parent | 276ff149a4555b69c4c99fdcd4e1a22ccc8b9463 (diff) |
De-duplicate TCP state in TCPEndpointState vs tcp.endpoint
This change replaces individual private members in tcp.endpoint with a single
private TCPEndpointState member.
Some internal substructures within endpoint (receiver, sender) have been broken
into a public substructure (which is then copied into the TCPEndpointState
returned from completeState()) alongside other private fields.
Fixes #4466
PiperOrigin-RevId: 369329514
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 114 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic.go | 119 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic_state.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/dispatcher.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 648 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rack.go | 129 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rack_state.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 173 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/reno.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/reno_recovery.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/sack_recovery.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 440 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd_state.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_rack_test.go | 4 |
17 files changed, 740 insertions, 1118 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index b09a0ebbc..48417f192 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -34,14 +34,12 @@ go_library( "connect.go", "connect_unsafe.go", "cubic.go", - "cubic_state.go", "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", "protocol.go", "rack.go", - "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 4be306434..664cb9420 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -215,11 +215,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header n := newEndpoint(l.stack, netProto, queue) n.ops.SetV6Only(l.v6Only) - n.ID = s.id + n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID n.route = route n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} - n.rcvBufSize = int(l.rcvWnd) + n.rcvQueueInfo.RcvBufSize = int(l.rcvWnd) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -231,7 +231,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header // Bootstrap the auto tuning algorithm. Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. - n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + n.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = n.initialReceiveWindow() return n, nil } @@ -290,7 +290,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint( + ep.effectiveNetProtos, + ProtocolNumber, + ep.TransportEndpointInfo.ID, + ep, + ep.boundPortFlags, + ep.boundBindToDevice, + ); err != nil { ep.mu.Unlock() ep.Close() @@ -335,14 +342,14 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.ID] = n + l.pendingEndpoints[n.TransportEndpointInfo.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.ID) + delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -383,7 +390,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window // scaling. - e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + e.rcv.RcvWndScale = e.h.effectiveRcvWndScale() // Clean up handshake state stored in the endpoint so that it can be GCed. e.h = nil @@ -444,12 +451,15 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // * propagateInheritableOptionsLocked has been called. // * e.mu is held. func (e *endpoint) reserveTupleLocked() bool { - dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + dest := tcpip.FullAddress{ + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, + } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: dest, @@ -537,9 +547,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // @@ -689,7 +699,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint( + n.effectiveNetProtos, + ProtocolNumber, + n.TransportEndpointInfo.ID, + n, + n.boundPortFlags, + n.boundBindToDevice, + ); err != nil { n.mu.Unlock() n.Close() @@ -704,7 +721,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was // sent above. - n.tsOffset = 0 + n.TSOffset = 0 // Switch state to connected. n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 8f0f0c3e9..7bc6b08f0 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -156,7 +156,7 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the @@ -302,7 +302,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { ttl = h.ep.route.DefaultTTL() } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -358,14 +358,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, - TS: h.ep.sendTSOk, + TS: h.ep.SendTSOk, TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), - SACKPermitted: h.ep.sackPermitted, + SACKPermitted: h.ep.SACKPermitted, MSS: h.ep.amss, } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -390,7 +390,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. - if h.ep.sendTSOk && !s.parsedOptions.TS { + if h.ep.SendTSOk && !s.parsedOptions.TS { h.ep.stack.Stats().DroppedPackets.Increment() return nil } @@ -405,7 +405,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { } // Update timestamp if required. See RFC7323, section-4.3. - if h.ep.sendTSOk && s.parsedOptions.TS { + if h.ep.SendTSOk && s.parsedOptions.TS { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted @@ -495,8 +495,8 @@ func (h *handshake) start() { // start() is also called in a listen context so we want to make sure we only // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { - synOpts.TS = h.ep.sendTSOk - synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + synOpts.TS = h.ep.SendTSOk + synOpts.SACKPermitted = h.ep.SACKPermitted && bool(sackEnabled) if h.sndWndScale < 0 { // Disable window scaling if the peer did not send us // the window scaling option. @@ -506,7 +506,7 @@ func (h *handshake) start() { h.sendSYNOpts = synOpts h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -554,7 +554,7 @@ func (h *handshake) complete() tcpip.Error { // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -855,7 +855,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // N.B. the ordering here matches the ordering used by Linux internally // and described in the raw makeOptions function. We don't include // unnecessary cases here (post connection.) - if e.sendTSOk { + if e.SendTSOk { // Embed the timestamp if timestamp has been enabled. // // We only use the lower 32 bits of the unix time in @@ -872,7 +872,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } - if e.sackPermitted && len(sackBlocks) > 0 { + if e.SACKPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) @@ -894,7 +894,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } options := e.makeOptions(sackBlocks) err := e.sendTCP(e.route, tcpFields{ - id: e.ID, + id: e.TransportEndpointInfo.ID, ttl: e.ttl, tos: e.sendTOS, flags: flags, @@ -908,9 +908,9 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } func (e *endpoint) handleWrite() { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() next := e.drainSendQueueLocked() - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.sendData(next) } @@ -919,10 +919,10 @@ func (e *endpoint) handleWrite() { // // Precondition: e.sndBufMu must be locked. func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueue.Front() + first := e.sndQueueInfo.sndQueue.Front() if first != nil { - e.snd.writeList.PushBackList(&e.sndQueue) - e.sndBufInQueue = 0 + e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) + e.sndQueueInfo.SndBufInQueue = 0 } return first } @@ -946,7 +946,7 @@ func (e *endpoint) handleClose() { e.handleWrite() // Mark send side as closed. - e.snd.closed = true + e.snd.Closed = true } // resetConnectionLocked puts the endpoint in an error state with the given @@ -968,12 +968,12 @@ func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more // information. - sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + sndWndEnd := e.snd.SndUna.Add(e.snd.SndWnd) resetSeqNum := sndWndEnd - if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { - resetSeqNum = e.snd.sndNxt + if !sndWndEnd.LessThan(e.snd.SndNxt) || e.snd.SndNxt.Size(sndWndEnd) < (1<<e.snd.SndWndScale) { + resetSeqNum = e.snd.SndNxt } - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.RcvNxt, 0) } } @@ -999,13 +999,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. - e.rcvAutoParams.prevCopied = int(h.rcvWnd) - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) + e.rcvQueueInfo.rcvQueueMu.Unlock() e.setEndpointState(StateEstablished) } @@ -1036,10 +1036,15 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.TransportEndpointInfo.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) + ep = e.stack.FindTransportEndpoint( + header.IPv4ProtocolNumber, + e.TransProto, + e.TransportEndpointInfo.ID, + s.nicID, + ) } if ep == nil { replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) @@ -1118,7 +1123,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1130,7 +1137,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { break } - cont, err := e.handleSegment(s) + cont, err := e.handleSegmentLocked(s) s.decRef() if err != nil { return err @@ -1148,7 +1155,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { } // Send an ACK for all processed packets if needed. - if e.rcv.rcvNxt != e.snd.maxSentAck { + if e.rcv.RcvNxt != e.snd.MaxSentAck { e.snd.sendAck() } @@ -1157,18 +1164,21 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { return nil } -func (e *endpoint) probeSegment() { - if e.probe != nil { - e.probe(e.completeState()) +// Precondition: e.mu must be held. +func (e *endpoint) probeSegmentLocked() { + if fn := e.probe; fn != nil { + fn(e.completeStateLocked()) } } // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. - defer e.probeSegment() + defer e.probeSegmentLocked() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { @@ -1201,7 +1211,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. - s.window <<= e.snd.sndWndScale + s.window <<= e.snd.SndWndScale // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment @@ -1265,7 +1275,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // seg.seq = snd.nxt-1. e.keepalive.unacked++ e.keepalive.Unlock() - e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.SndNxt-1) e.resetKeepaliveTimer(false) return nil } @@ -1279,7 +1289,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. - if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.SndUna != e.snd.SndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1372,14 +1382,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ f func() tcpip.Error }{ { - w: &e.sndWaker, + w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { e.handleWrite() return nil }, }, { - w: &e.sndCloseWaker, + w: &e.sndQueueInfo.sndCloseWaker, f: func() tcpip.Error { e.handleClose() return nil @@ -1413,7 +1423,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.newSegmentWaker, f: func() tcpip.Error { - return e.handleSegments(false /* fastPath */) + return e.handleSegmentsLocked(false /* fastPath */) }, }, { @@ -1429,11 +1439,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyMTUChanged != 0 { - e.sndBufMu.Lock() - count := e.packetTooBigCount - e.packetTooBigCount = 0 - mtu := e.sndMTU - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + count := e.sndQueueInfo.PacketTooBigCount + e.sndQueueInfo.PacketTooBigCount = 0 + mtu := e.sndQueueInfo.SndMTU + e.sndQueueInfo.sndQueueMu.Unlock() e.snd.updateMaxPayloadSize(mtu, count) } @@ -1463,7 +1473,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(false /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(false /* fastPath */); err != nil { return err } } @@ -1514,11 +1524,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.newSegmentWaker.Assert() } - e.rcvListMu.Lock() - if !e.rcvList.Empty() { + e.rcvQueueInfo.rcvQueueMu.Lock() + if !e.rcvQueueInfo.rcvQueue.Empty() { e.waiterQueue.Notify(waiter.ReadableEvents) } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 1975f1a44..962f1d687 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -17,6 +17,8 @@ package tcp import ( "math" "time" + + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // cubicState stores the variables related to TCP CUBIC congestion @@ -25,47 +27,12 @@ import ( // See: https://tools.ietf.org/html/rfc8312. // +stateify savable type cubicState struct { - // wLastMax is the previous wMax value. - wLastMax float64 - - // wMax is the value of the congestion window at the - // time of last congestion event. - wMax float64 - - // t denotes the time when the current congestion avoidance - // was entered. - t time.Time `state:".(unixTime)"` + stack.TCPCubicState // numCongestionEvents tracks the number of congestion events since last // RTO. numCongestionEvents int - // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as - // per RFC. - c float64 - - // k is the time period that the above function takes to increase the - // current window size to W_max if there are no further congestion - // events and is calculated using the following equation: - // - // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2) - k float64 - - // beta is the CUBIC multiplication decrease factor. that is, when a - // congestion event is detected, CUBIC reduces its cwnd to - // W_cubic(0)=W_max*beta_cubic. - beta float64 - - // wC is window computed by CUBIC at time t. It's calculated using the - // formula: - // - // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) - wC float64 - - // wEst is the window computed by CUBIC at time t+RTT i.e - // W_cubic(t+RTT). - wEst float64 - s *sender } @@ -73,10 +40,12 @@ type cubicState struct { // beta and c set and t set to current time. func newCubicCC(s *sender) *cubicState { return &cubicState{ - t: time.Now(), - beta: 0.7, - c: 0.4, - s: s, + TCPCubicState: stack.TCPCubicState{ + T: time.Now(), + Beta: 0.7, + C: 0.4, + }, + s: s, } } @@ -90,10 +59,10 @@ func (c *cubicState) enterCongestionAvoidance() { // See: https://tools.ietf.org/html/rfc8312#section-4.7 & // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { - c.k = 0 - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.K = 0 + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) } } @@ -104,16 +73,16 @@ func (c *cubicState) enterCongestionAvoidance() { func (c *cubicState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. - newcwnd := c.s.sndCwnd + packetsAcked + newcwnd := c.s.SndCwnd + packetsAcked enterCA := false - if newcwnd >= c.s.sndSsthresh { - newcwnd = c.s.sndSsthresh - c.s.sndCAAckCount = 0 + if newcwnd >= c.s.Ssthresh { + newcwnd = c.s.Ssthresh + c.s.SndCAAckCount = 0 enterCA = true } - packetsAcked -= newcwnd - c.s.sndCwnd - c.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - c.s.SndCwnd + c.s.SndCwnd = newcwnd if enterCA { c.enterCongestionAvoidance() } @@ -124,49 +93,49 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int { // ACK received. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) Update(packetsAcked int) { - if c.s.sndCwnd < c.s.sndSsthresh { + if c.s.SndCwnd < c.s.Ssthresh { packetsAcked = c.updateSlowStart(packetsAcked) if packetsAcked == 0 { return } } else { c.s.rtt.Lock() - srtt := c.s.rtt.srtt + srtt := c.s.rtt.TCPRTTState.SRTT c.s.rtt.Unlock() - c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt) } } // cubicCwnd computes the CUBIC congestion window after t seconds from last // congestion event. func (c *cubicState) cubicCwnd(t float64) float64 { - return c.c*math.Pow(t, 3.0) + c.wMax + return c.C*math.Pow(t, 3.0) + c.WMax } // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.t).Seconds() + elapsed := time.Since(c.T).Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.wC = c.cubicCwnd(elapsed - c.k) + c.WC = c.cubicCwnd(elapsed - c.K) // Compute the TCP friendly estimate of the congestion window. - c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. - if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + if c.WC < c.WEst && float64(sndCwnd) < c.WEst { // TCP Friendly region of cubic. - return int(c.wEst) + return int(c.WEst) } // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.t) + srtt).Seconds() - wtRtt := c.cubicCwnd(tEst - c.k) + tEst := (time.Since(c.T) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. cwnd := float64(sndCwnd) @@ -182,9 +151,9 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() c.reduceSlowStartThreshold() @@ -193,10 +162,10 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.t = time.Now() + c.T = time.Now() c.numCongestionEvents = 0 - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() @@ -206,29 +175,29 @@ func (c *cubicState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - c.s.sndCwnd = 1 + c.s.SndCwnd = 1 } // fastConvergence implements the logic for Fast Convergence algorithm as // described in https://tools.ietf.org/html/rfc8312#section-4.6. func (c *cubicState) fastConvergence() { - if c.wMax < c.wLastMax { - c.wLastMax = c.wMax - c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + if c.WMax < c.WLastMax { + c.WLastMax = c.WMax + c.WMax = c.WMax * (1.0 + c.Beta) / 2.0 } else { - c.wLastMax = c.wMax + c.WLastMax = c.WMax } // Recompute k as wMax may have changed. - c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) + c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C) } // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.t = time.Now() + c.T = time.Now() } // reduceSlowStartThreshold returns new SsThresh as described in // https://tools.ietf.org/html/rfc8312#section-4.7. func (c *cubicState) reduceSlowStartThreshold() { - c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) + c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0)) } diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go deleted file mode 100644 index d0f58cfaf..000000000 --- a/pkg/tcpip/transport/tcp/cubic_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveT is invoked by stateify. -func (c *cubicState) saveT() unixTime { - return unixTime{c.t.Unix(), c.t.UnixNano()} -} - -// loadT is invoked by stateify. -func (c *cubicState) loadT(unix unixTime) { - c.t = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 21162f01a..512053a04 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -116,7 +116,7 @@ func (p *processor) start(wg *sync.WaitGroup) { if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { // If the endpoint is in a connected state then we do direct delivery // to ensure low latency and avoid scheduler interactions. - switch err := ep.handleSegments(true /* fastPath */); { + switch err := ep.handleSegmentsLocked(true /* fastPath */); { case err != nil: // Send any active resets if required. ep.resetConnectionLocked(err) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bc88e48e9..884332828 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -191,42 +191,6 @@ type SACKInfo struct { NumBlocks int } -// rcvBufAutoTuneParams are used to hold state variables to compute -// the auto tuned recv buffer size. -// -// +stateify savable -type rcvBufAutoTuneParams struct { - // measureTime is the time at which the current measurement - // was started. - measureTime time.Time `state:".(unixTime)"` - - // copied is the number of bytes copied out of the receive - // buffers since this measure began. - copied int - - // prevCopied is the number of bytes copied out of the receive - // buffers in the previous RTT period. - prevCopied int - - // rtt is the non-smoothed minimum RTT as measured by observing the time - // between when a byte is first acknowledged and the receipt of data - // that is at least one window beyond the sequence number that was - // acknowledged. - rtt time.Duration - - // rttMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - rttMeasureSeqNumber seqnum.Value - - // rttMeasureTime is the absolute time at which the current rtt - // measurement period began. - rttMeasureTime time.Time `state:".(unixTime)"` - - // disabled is true if an explicit receive buffer is set for the - // endpoint. - disabled bool -} - // ReceiveErrors collect segment receive errors within transport layer. type ReceiveErrors struct { tcpip.ReceiveErrors @@ -247,7 +211,7 @@ type ReceiveErrors struct { ListenOverflowAckDrop tcpip.StatCounter // ZeroRcvWindowState is the number of times we advertised - // a zero receive window when rcvList is full. + // a zero receive window when rcvQueue is full. ZeroRcvWindowState tcpip.StatCounter // WantZeroWindow is the number of times we wanted to advertise a @@ -310,18 +274,36 @@ type Stats struct { // marker interface. func (*Stats) IsEndpointStats() {} -// EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. This exists to allow tcp-only state to -// be exposed. +// sndQueueInfo implements a send queue. // // +stateify savable -type EndpointInfo struct { - stack.TransportEndpointInfo +type sndQueueInfo struct { + sndQueueMu sync.Mutex `state:"nosave"` + stack.TCPSndBufState + + // sndQueue holds segments that are ready to be sent. + sndQueue segmentList `state:"wait"` + + // sndWaker is used to signal the protocol goroutine when segments are + // added to the `sndQueue`. + sndWaker sleep.Waker `state:"manual"` + + // sndCloseWaker is used to notify the protocol goroutine when the send + // side is closed. + sndCloseWaker sleep.Waker `state:"manual"` } -// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo -// marker interface. -func (*EndpointInfo) IsEndpointInfo() {} +// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. +// +// +stateify savable +type rcvQueueInfo struct { + rcvQueueMu sync.Mutex `state:"nosave"` + stack.TCPRcvBufState + + // rcvQueue is the queue for ready-for-delivery segments. This struct's + // mutex must be held in order append segments to list. + rcvQueue segmentList `state:"wait"` +} // +stateify savable type accepted struct { @@ -348,8 +330,8 @@ type accepted struct { // acquired with e.mu then e.mu must be acquired first. // // e.acceptMu -> protects accepted. -// e.rcvListMu -> Protects the rcvList and associated fields. -// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.rcvQueueMu -> Protects e.rcvQueue and associated fields. +// e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. // // LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different @@ -372,7 +354,8 @@ type accepted struct { // // +stateify savable type endpoint struct { - EndpointInfo + stack.TCPEndpointStateInner + stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the @@ -405,38 +388,23 @@ type endpoint struct { // rcvReadMu synchronizes calls to Read. // - // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // mu and rcvQueueMu are temporarily released during data copying. rcvReadMu // must be held during each read to ensure atomicity, so that multiple reads // do not interleave. // // rcvReadMu should be held before holding mu. rcvReadMu sync.Mutex `state:"nosave"` - // rcvListMu synchronizes access to rcvList. - // - // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - - // rcvList is the queue for ready-for-delivery segments. - // - // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data - // and removing segments from list. A range of segment can be determined, then - // temporarily release mu and rcvListMu while processing the segment range. - // This allows new segments to be appended to the list while processing. - // - // rcvListMu must be held to append segments to list. - rcvList segmentList `state:"wait"` - rcvClosed bool - // rcvBufSize is the total size of the receive buffer. - rcvBufSize int - // rcvBufUsed is the actual number of payload bytes held in the receive buffer - // not counting any overheads of the segments itself. NOTE: This will always - // be strictly <= rcvMemUsed below. - rcvBufUsed int - rcvAutoParams rcvBufAutoTuneParams + // rcvQueueInfo holds the implementation of the endpoint's receive buffer. + // The data within rcvQueueInfo should only be accessed while rcvReadMu, mu, + // and rcvQueueMu are held, in that stated order. While processing the segment + // range, you can determine a range and then temporarily release mu and + // rcvQueueMu, which allows new segments to be appended to the queue while + // processing. + rcvQueueInfo rcvQueueInfo // rcvMemUsed tracks the total amount of memory in use by received segments - // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to // compute the window and the actual available buffer space. This is distinct // from rcvBufUsed above which is the actual number of payload bytes held in // the buffer not including any segment overheads. @@ -498,33 +466,16 @@ type endpoint struct { // also true, and they're both protected by the mutex. workerCleanup bool - // sendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1 - sendTSOk bool - - // recentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - recentTS uint32 - - // recentTSTime is the unix time when we updated recentTS last. + // recentTSTime is the unix time when we last updated + // TCPEndpointStateInner.RecentTS. recentTSTime time.Time `state:".(unixTime)"` - // tsOffset is a randomized offset added to the value of the - // TSVal field in the timestamp option. - tsOffset uint32 - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags // tcpRecovery is the loss deteoction algorithm used by TCP. tcpRecovery tcpip.TCPRecovery - // sackPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - sackPermitted bool - // sack holds TCP SACK related information for this endpoint. sack SACKInfo @@ -560,32 +511,13 @@ type endpoint struct { // this value. windowClamp uint32 - // The following fields are used to manage the send buffer. When - // segments are ready to be sent, they are added to sndQueue and the - // protocol goroutine is signaled via sndWaker. - // - // When the send side is closed, the protocol goroutine is notified via - // sndCloseWaker, and sndClosed is set to true. - sndBufMu sync.Mutex `state:"nosave"` - sndBufUsed int - sndClosed bool - sndBufInQueue seqnum.Size - sndQueue segmentList `state:"wait"` - sndWaker sleep.Waker `state:"manual"` - sndCloseWaker sleep.Waker `state:"manual"` + // sndQueueInfo contains the implementation of the endpoint's send queue. + sndQueueInfo sndQueueInfo // cc stores the name of the Congestion Control algorithm to use for // this endpoint. cc tcpip.CongestionControlOption - // The following are used when a "packet too big" control packet is - // received. They are protected by sndBufMu. They are used to - // communicate to the main protocol goroutine how many such control - // messages have been received since the last notification was processed - // and what was the smallest MTU seen. - packetTooBigCount int - sndMTU int - // newSegmentWaker is used to indicate to the protocol goroutine that // it needs to wake up and handle new segments queued to it. newSegmentWaker sleep.Waker `state:"manual"` @@ -782,7 +714,7 @@ func (e *endpoint) UnlockUser() { switch e.EndpointState() { case StateEstablished: - if err := e.handleSegments(true /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(true /* fastPath */); err != nil { e.notifyProtocolGoroutine(notifyTickleWorker) } default: @@ -842,13 +774,13 @@ func (e *endpoint) EndpointState() EndpointState { // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - e.recentTS = recentTS + e.RecentTS = recentTS e.recentTSTime = time.Now() } // recentTimestamp returns the value of the recentTS field. func (e *endpoint) recentTimestamp() uint32 { - return e.recentTS + return e.RecentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -868,16 +800,17 @@ type keepalive struct { func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ stack: s, - EndpointInfo: EndpointInfo{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.TCPProtocolNumber, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + sndQueueInfo: sndQueueInfo{ + TCPSndBufState: stack.TCPSndBufState{ + SndMTU: int(math.MaxInt32), }, }, waiterQueue: waiterQueue, state: StateInitial, - rcvBufSize: DefaultReceiveBufferSize, - sndMTU: math.MaxInt32, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -889,6 +822,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } + e.rcvQueueInfo.RcvBufSize = DefaultReceiveBufferSize e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) @@ -901,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - e.rcvBufSize = rs.Default + e.rcvQueueInfo.RcvBufSize = rs.Default } var cs tcpip.CongestionControlOption @@ -911,7 +845,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { - e.rcvAutoParams.disabled = !bool(mrb) + e.rcvQueueInfo.RcvAutoParams.Disabled = !bool(mrb) } var de tcpip.TCPDelayEnabled @@ -936,7 +870,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.tsOffset = timeStampOffset() + e.TSOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(&e.keepalive.waker) @@ -974,21 +908,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.WritableEvents) != 0 { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() sndBufSize := e.getSendBufferSize() - if e.sndClosed || e.sndBufUsed < sndBufSize { + if e.sndQueueInfo.SndClosed || e.sndQueueInfo.SndBufUsed < sndBufSize { result |= waiter.WritableEvents } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() } // Determine if the endpoint is readable if requested. if (mask & waiter.ReadableEvents) != 0 { - e.rcvListMu.Lock() - if e.rcvBufUsed > 0 || e.rcvClosed { + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvBufUsed > 0 || e.rcvQueueInfo.RcvClosed { result |= waiter.ReadableEvents } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() } } @@ -1096,15 +1030,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1179,7 +1113,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1187,8 +1121,8 @@ func (e *endpoint) cleanupLocked() { portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1250,19 +1184,19 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - if e.rcvAutoParams.disabled { - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvAutoParams.Disabled { + e.rcvQueueInfo.rcvQueueMu.Unlock() return } now := time.Now() - if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { - e.rcvAutoParams.copied += copied - e.rcvListMu.Unlock() + if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { + e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied + e.rcvQueueInfo.rcvQueueMu.Unlock() return } - prevRTTCopied := e.rcvAutoParams.copied + copied - prevCopied := e.rcvAutoParams.prevCopied + prevRTTCopied := e.rcvQueueInfo.RcvAutoParams.CopiedBytes + copied + prevCopied := e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes rcvWnd := 0 if prevRTTCopied > prevCopied { // The minimal receive window based on what was copied by the app @@ -1294,24 +1228,24 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // We do not adjust downwards as that can cause the receiver to // reject valid data that might already be in flight as the // acceptable window will shrink. - if rcvWnd > e.rcvBufSize { + if rcvWnd > e.rcvQueueInfo.RcvBufSize { availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = rcvWnd + e.rcvQueueInfo.RcvBufSize = rcvWnd availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - // We only update prevCopied when we grow the buffer because in cases - // where prevCopied > prevRTTCopied the existing buffer is already big + // We only update PrevCopiedBytes when we grow the buffer because in cases + // where PrevCopiedBytes > prevRTTCopied the existing buffer is already big // enough to handle the current rate and we don't need to do any // adjustments. - e.rcvAutoParams.prevCopied = prevRTTCopied + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = prevRTTCopied } - e.rcvAutoParams.measureTime = now - e.rcvAutoParams.copied = 0 - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.MeasureTime = now + e.rcvQueueInfo.RcvAutoParams.CopiedBytes = 0 + e.rcvQueueInfo.rcvQueueMu.Unlock() } // SetOwner implements tcpip.Endpoint.SetOwner. @@ -1360,7 +1294,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult defer e.rcvReadMu.Unlock() // N.B. Here we get a range of segments to be processed. It is safe to not - // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // hold rcvQueueMu when processing, since we hold rcvReadMu to ensure only we // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { @@ -1432,10 +1366,10 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - bufUsed := e.rcvBufUsed + bufUsed := e.rcvQueueInfo.RcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { if s == StateError { if err := e.hardErrorLocked(); err != nil { @@ -1447,14 +1381,14 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { return nil, nil, &tcpip.ErrNotConnected{} } - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { + if e.rcvQueueInfo.RcvBufUsed == 0 { + if e.rcvQueueInfo.RcvClosed || !e.EndpointState().connected() { return nil, nil, &tcpip.ErrClosedForReceive{} } return nil, nil, &tcpip.ErrWouldBlock{} } - return e.rcvList.Front(), e.rcvList.Back(), nil + return e.rcvQueueInfo.rcvQueue.Front(), e.rcvQueueInfo.rcvQueue.Back(), nil } // commitRead commits a read of done bytes and returns the next non-empty @@ -1470,20 +1404,20 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { func (e *endpoint) commitRead(done int) *segment { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() memDelta := 0 - s := e.rcvList.Front() + s := e.rcvQueueInfo.rcvQueue.Front() for s != nil && s.data.Size() == 0 { - e.rcvList.Remove(s) + e.rcvQueueInfo.rcvQueue.Remove(s) // Memory is only considered released when the whole segment has been // read. memDelta += s.segMemSize() s.decRef() - s = e.rcvList.Front() + s = e.rcvQueueInfo.rcvQueue.Front() } - e.rcvBufUsed -= done + e.rcvQueueInfo.RcvBufUsed -= done if memDelta > 0 { // If the window was small before this read and if the read freed up @@ -1495,14 +1429,14 @@ func (e *endpoint) commitRead(done int) *segment { } } - return e.rcvList.Front() + return e.rcvQueueInfo.rcvQueue.Front() } // isEndpointWritableLocked checks if a given endpoint is writable // and also returns the number of bytes that can be written at this // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. -// Caller must hold e.mu and e.sndBufMu +// Caller must hold e.mu and e.sndQueueMu func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { @@ -1522,12 +1456,12 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { } // Check if the connection has already been closed for sends. - if e.sndClosed { + if e.sndQueueInfo.SndClosed { return 0, &tcpip.ErrClosedForSend{} } sndBufSize := e.getSendBufferSize() - avail := sndBufSize - e.sndBufUsed + avail := sndBufSize - e.sndQueueInfo.SndBufUsed if avail <= 0 { return 0, &tcpip.ErrWouldBlock{} } @@ -1544,8 +1478,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp defer e.UnlockUser() nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndBufMu.Lock() - defer e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() avail, err := e.isEndpointWritableLocked() if err != nil { @@ -1560,8 +1494,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // available buffer space to be consumed by some other caller while we // are copying data in. if !opts.Atomic { - e.sndBufMu.Unlock() - defer e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() e.UnlockUser() defer e.LockUser() @@ -1603,10 +1537,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + e.sndQueueInfo.SndBufUsed += len(v) + e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) + e.sndQueueInfo.sndQueue.PushBack(s) return e.drainSendQueueLocked(), len(v), nil }() @@ -1621,11 +1555,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. -// Precondition: e.mu and e.rcvListMu must be held. +// Precondition: e.mu and e.rcvQueueMu must be held. func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) - maxWindow := wndFromSpace(e.rcvBufSize) - wndFromUsedBytes := maxWindow - e.rcvBufUsed + maxWindow := wndFromSpace(e.rcvQueueInfo.RcvBufSize) + wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in // cases where we receive a lot of small segments the segment overhead is a @@ -1643,11 +1577,11 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { return seqnum.Size(newWnd) } -// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu. +// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu. func (e *endpoint) selectWindow() (wnd seqnum.Size) { - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() wnd = e.selectWindowLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() return wnd } @@ -1665,7 +1599,7 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // above will be true if the new window is >= ACK threshold and false // otherwise. // -// Precondition: e.mu and e.rcvListMu must be held. +// Precondition: e.mu and e.rcvQueueMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { newAvail := int(e.selectWindowLocked()) oldAvail := newAvail - deltaBefore @@ -1676,7 +1610,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // rcvBufFraction is the inverse of the fraction of receive buffer size that // is used to decide if the available buffer space is now above it. const rcvBufFraction = 2 - if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + if wndThreshold := wndFromSpace(e.rcvQueueInfo.RcvBufSize / rcvBufFraction); threshold > wndThreshold { threshold = wndThreshold } switch { @@ -1711,7 +1645,7 @@ func (e *endpoint) OnKeepAliveSet(bool) { func (e *endpoint) OnDelayOptionSet(v bool) { if !v { // Handle delayed data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1719,7 +1653,7 @@ func (e *endpoint) OnDelayOptionSet(v bool) { func (e *endpoint) OnCorkOptionSet(v bool) { if !v { // Handle the corked data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1792,23 +1726,23 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { } e.LockUser() - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() // Make sure the receive buffer size allows us to send a // non-zero window size. scale := uint8(0) if e.rcv != nil { - scale = e.rcv.rcvWndScale + scale = e.rcv.RcvWndScale } if v>>scale == 0 { v = 1 << scale } availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = v + e.rcvQueueInfo.RcvBufSize = v availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvAutoParams.disabled = true + e.rcvQueueInfo.RcvAutoParams.Disabled = true // Immediately send an ACK to uncork the sender silly window // syndrome prevetion, when our available space grows above aMSS @@ -1817,7 +1751,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() e.UnlockUser() case tcpip.TTLOption: @@ -1962,10 +1896,10 @@ func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { return 0, &tcpip.ErrInvalidEndpointState{} } - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - return e.rcvBufUsed, nil + return e.rcvQueueInfo.RcvBufUsed, nil } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -2006,9 +1940,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { return e.readyReceiveSize() case tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - v := e.rcvBufSize - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + v := e.rcvQueueInfo.RcvBufSize + e.rcvQueueInfo.rcvQueueMu.Unlock() return v, nil case tcpip.TTLOption: @@ -2046,15 +1980,15 @@ func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { // the connection did not send and receive data, then RTT will // be zero. snd.rtt.Lock() - info.RTT = snd.rtt.srtt - info.RTTVar = snd.rtt.rttvar + info.RTT = snd.rtt.TCPRTTState.SRTT + info.RTTVar = snd.rtt.TCPRTTState.RTTVar snd.rtt.Unlock() - info.RTO = snd.rto + info.RTO = snd.RTO info.CcState = snd.state - info.SndSsthresh = uint32(snd.sndSsthresh) - info.SndCwnd = uint32(snd.sndCwnd) - info.ReorderSeen = snd.rc.reorderSeen + info.SndSsthresh = uint32(snd.Ssthresh) + info.SndCwnd = uint32(snd.SndCwnd) + info.ReorderSeen = snd.rc.Reord } e.UnlockUser() return info @@ -2099,7 +2033,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) e.UnlockUser() if err != nil { return err @@ -2207,20 +2141,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.TransportEndpointInfo.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.ID.LocalAddress = r.LocalAddress() - e.ID.RemoteAddress = r.RemoteAddress() - e.ID.RemotePort = addr.Port + e.TransportEndpointInfo.ID.LocalAddress = r.LocalAddress() + e.TransportEndpointInfo.ID.RemoteAddress = r.RemoteAddress() + e.TransportEndpointInfo.ID.RemotePort = addr.Port - if e.ID.LocalPort != 0 { + if e.TransportEndpointInfo.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2229,7 +2163,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + sameAddr := e.TransportEndpointInfo.ID.LocalAddress == e.TransportEndpointInfo.ID.RemoteAddress // Calculate a port offset based on the destination IP/port and // src IP to ensure that for a given tuple (srcIP, destIP, @@ -2262,21 +2196,21 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { switch netProto { case header.IPv4ProtocolNumber: - reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + reuse = header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.LocalAddress) && header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.RemoteAddress) case header.IPv6ProtocolNumber: - reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + reuse = e.TransportEndpointInfo.ID.LocalAddress == header.IPv6Loopback && e.TransportEndpointInfo.ID.RemoteAddress == header.IPv6Loopback } } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { - if sameAddr && p == e.ID.RemotePort { + if sameAddr && p == e.TransportEndpointInfo.ID.RemotePort { return false, nil } portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2286,7 +2220,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } - transEPID := e.ID + transEPID := e.TransportEndpointInfo.ID transEPID.LocalPort = p // Check if an endpoint is registered with demuxer in TIME-WAIT and if // we can reuse it. If we can't find a transport endpoint then we just @@ -2323,7 +2257,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2334,13 +2268,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } } - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2355,7 +2289,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // Port picking successful. Save the details of // the selected port. - e.ID = id + e.TransportEndpointInfo.ID = id e.isPortReserved = true e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags @@ -2381,10 +2315,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.ID - e.sndWaker.Assert() + s.id = e.TransportEndpointInfo.ID + e.sndQueueInfo.sndWaker.Assert() } } e.segmentQueue.mu.Unlock() @@ -2426,10 +2360,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for read. if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. - e.rcvListMu.Lock() - e.rcvClosed = true - rcvBufUsed := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + rcvBufUsed := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() // If we're fully closed and we have unread data we need to abort // the connection with a RST. @@ -2443,10 +2377,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for write. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - e.sndBufMu.Lock() - if e.sndClosed { + e.sndQueueInfo.sndQueueMu.Lock() + if e.sndQueueInfo.SndClosed { // Already closed. - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() if e.EndpointState() == StateTimeWait { return &tcpip.ErrNotConnected{} } @@ -2454,12 +2388,12 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.ID, nil) - e.sndQueue.PushBack(s) - e.sndBufInQueue++ + s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) + e.sndQueueInfo.sndQueue.PushBack(s) + e.sndQueueInfo.SndBufInQueue++ // Mark endpoint as closed. - e.sndClosed = true - e.sndBufMu.Unlock() + e.sndQueueInfo.SndClosed = true + e.sndQueueInfo.sndQueueMu.Unlock() e.handleClose() } @@ -2472,9 +2406,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // // By not removing this endpoint from the demuxer mapping, we // ensure that any other bind to the same port fails, as on Linux. - e.rcvListMu.Lock() - e.rcvClosed = true - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + e.rcvQueueInfo.rcvQueueMu.Unlock() e.closePendingAcceptableConnectionsLocked() // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) @@ -2513,9 +2447,9 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // listen is called after shutdown. e.accepted.cap = backlog e.shutdownFlags = 0 - e.rcvListMu.Lock() - e.rcvClosed = false - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() } else { // Adjust the size of the backlog iff we can fit // existing pending connections into the new one. @@ -2548,7 +2482,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } @@ -2588,9 +2522,9 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { return nil, nil, &tcpip.ErrInvalidEndpointState{} @@ -2656,7 +2590,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { if nic == 0 { return &tcpip.ErrBadLocalAddress{} } - e.ID.LocalAddress = addr.Addr + e.TransportEndpointInfo.ID.LocalAddress = addr.Addr } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) @@ -2670,7 +2604,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { Dest: tcpip.FullAddress{}, } port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a // listening endpoint bound with the same id and portFlags and bindToDevice @@ -2696,7 +2630,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { e.boundNICID = nic e.isPortReserved = true e.effectiveNetProtos = netProtos - e.ID.LocalPort = port + e.TransportEndpointInfo.ID.LocalPort = port // Mark endpoint as bound. e.setEndpointState(StateBound) @@ -2710,8 +2644,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { defer e.UnlockUser() return tcpip.FullAddress{ - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -2730,8 +2664,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, NIC: e.boundNICID, } } @@ -2770,13 +2704,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p Payload: pkt.Data().AsRange().ToOwnedView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -2789,12 +2723,12 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p // HandleError implements stack.TransportEndpoint. func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { handlePacketTooBig := func(mtu uint32) { - e.sndBufMu.Lock() - e.packetTooBigCount++ - if v := int(mtu); v < e.sndMTU { - e.sndMTU = v + e.sndQueueInfo.sndQueueMu.Lock() + e.sndQueueInfo.PacketTooBigCount++ + if v := int(mtu); v < e.sndQueueInfo.SndMTU { + e.sndQueueInfo.SndMTU = v } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) } @@ -2813,14 +2747,14 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { sendBufferSize := e.getSendBufferSize() - e.sndBufMu.Lock() - notify := e.sndBufUsed >= sendBufferSize>>1 - e.sndBufUsed -= v + e.sndQueueInfo.sndQueueMu.Lock() + notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1 + e.sndQueueInfo.SndBufUsed -= v // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < sendBufferSize>>1 - e.sndBufMu.Unlock() + notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + e.sndQueueInfo.sndQueueMu.Unlock() if notify { e.waiterQueue.Notify(waiter.WritableEvents) @@ -2831,55 +2765,55 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() if s != nil { - e.rcvBufUsed += s.payloadSize() + e.rcvQueueInfo.RcvBufUsed += s.payloadSize() s.incRef() - e.rcvList.PushBack(s) + e.rcvQueueInfo.rcvQueue.PushBack(s) } else { - e.rcvClosed = true + e.rcvQueueInfo.RcvClosed = true } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() e.waiterQueue.Notify(waiter.ReadableEvents) } // receiveBufferAvailableLocked calculates how many bytes are still available // in the receive buffer. -// rcvListMu must be held when this function is called. +// rcvQueueMu must be held when this function is called. func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. memUsed := e.receiveMemUsed() - if memUsed >= e.rcvBufSize { + if memUsed >= e.rcvQueueInfo.RcvBufSize { return 0 } - return e.rcvBufSize - memUsed + return e.rcvQueueInfo.RcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the // receive buffer based on the actual memory used by all segments held in // receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() available := e.receiveBufferAvailableLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() return available } // receiveBufferUsed returns the amount of in-use receive buffer. func (e *endpoint) receiveBufferUsed() int { - e.rcvListMu.Lock() - used := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + used := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() return used } // receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { - e.rcvListMu.Lock() - size := e.rcvBufSize - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + size := e.rcvQueueInfo.RcvBufSize + e.rcvQueueInfo.rcvQueueMu.Unlock() return size } @@ -2913,9 +2847,9 @@ func (e *endpoint) maxReceiveBufferSize() int { func (e *endpoint) rcvWndScaleForHandshake() int { bufSizeForScale := e.receiveBufferSize() - e.rcvListMu.Lock() - autoTuningDisabled := e.rcvAutoParams.disabled - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled + e.rcvQueueInfo.rcvQueueMu.Unlock() if autoTuningDisabled { return FindWndScale(seqnum.Size(bufSizeForScale)) } @@ -2926,7 +2860,7 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + if e.SendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { e.setRecentTimestamp(tsVal) } } @@ -2936,7 +2870,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, // initializes the recentTS with the value provided in synOpts.TSval. func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { - e.sendTSOk = true + e.SendTSOk = true e.setRecentTimestamp(synOpts.TSVal) } } @@ -2944,7 +2878,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.tsOffset) + return tcpTimeStamp(time.Now(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is @@ -2983,7 +2917,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { return } if bool(v) && synOpts.SACKPermitted { - e.sackPermitted = true + e.SACKPermitted = true } } @@ -2997,118 +2931,46 @@ func (e *endpoint) maxOptionSize() (size int) { return size } -// completeState makes a full copy of the endpoint and returns it. This is used -// before invoking the probe. The state returned may not be fully consistent if -// there are intervening syscalls when the state is being copied. -func (e *endpoint) completeState() stack.TCPEndpointState { - var s stack.TCPEndpointState - s.SegTime = time.Now() - - // Copy EndpointID. - s.ID = stack.TCPEndpointID(e.ID) - - // Copy endpoint rcv state. - e.rcvListMu.Lock() - s.RcvBufSize = e.rcvBufSize - s.RcvBufUsed = e.rcvBufUsed - s.RcvClosed = e.rcvClosed - s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime - s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied - s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied - s.RcvAutoParams.RTT = e.rcvAutoParams.rtt - s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber - s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime - s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled - e.rcvListMu.Unlock() - - // Endpoint TCP Option state. - s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTimestamp() - s.TSOffset = e.tsOffset - s.SACKPermitted = e.sackPermitted +// completeStateLocked makes a full copy of the endpoint and returns it. This is +// used before invoking the probe. +// +// Precondition: e.mu must be held. +func (e *endpoint) completeStateLocked() stack.TCPEndpointState { + s := stack.TCPEndpointState{ + TCPEndpointStateInner: e.TCPEndpointStateInner, + ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), + SegTime: time.Now(), + Receiver: e.rcv.TCPReceiverState, + Sender: e.snd.TCPSenderState, + } + + sndBufSize := e.getSendBufferSize() + // Copy the send buffer atomically. + e.sndQueueInfo.sndQueueMu.Lock() + s.SndBufState = e.sndQueueInfo.TCPSndBufState + s.SndBufState.SndBufSize = sndBufSize + e.sndQueueInfo.sndQueueMu.Unlock() + + // Copy the receive buffer atomically. + e.rcvQueueInfo.rcvQueueMu.Lock() + s.RcvBufState = e.rcvQueueInfo.TCPRcvBufState + e.rcvQueueInfo.rcvQueueMu.Unlock() + + // Copy the endpoint TCP Option state. s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks]) s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() - // Copy endpoint send state. - sndBufSize := e.getSendBufferSize() - e.sndBufMu.Lock() - s.SndBufSize = sndBufSize - s.SndBufUsed = e.sndBufUsed - s.SndClosed = e.sndClosed - s.SndBufInQueue = e.sndBufInQueue - s.PacketTooBigCount = e.packetTooBigCount - s.SndMTU = e.sndMTU - e.sndBufMu.Unlock() - - // Copy receiver state. - s.Receiver = stack.TCPReceiverState{ - RcvNxt: e.rcv.rcvNxt, - RcvAcc: e.rcv.rcvAcc, - RcvWndScale: e.rcv.rcvWndScale, - PendingBufUsed: e.rcv.pendingBufUsed, - } - - // Copy sender state. - s.Sender = stack.TCPSenderState{ - LastSendTime: e.snd.lastSendTime, - DupAckCount: e.snd.dupAckCount, - FastRecovery: stack.TCPFastRecoveryState{ - Active: e.snd.fr.active, - First: e.snd.fr.first, - Last: e.snd.fr.last, - MaxCwnd: e.snd.fr.maxCwnd, - HighRxt: e.snd.fr.highRxt, - RescueRxt: e.snd.fr.rescueRxt, - }, - SndCwnd: e.snd.sndCwnd, - Ssthresh: e.snd.sndSsthresh, - SndCAAckCount: e.snd.sndCAAckCount, - Outstanding: e.snd.outstanding, - SackedOut: e.snd.sackedOut, - SndWnd: e.snd.sndWnd, - SndUna: e.snd.sndUna, - SndNxt: e.snd.sndNxt, - RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, - RTTMeasureTime: e.snd.rttMeasureTime, - Closed: e.snd.closed, - RTO: e.snd.rto, - MaxPayloadSize: e.snd.maxPayloadSize, - SndWndScale: e.snd.sndWndScale, - MaxSentAck: e.snd.maxSentAck, - } e.snd.rtt.Lock() - s.Sender.SRTT = e.snd.rtt.srtt - s.Sender.SRTTInited = e.snd.rtt.srttInited + s.Sender.RTTState = e.snd.rtt.TCPRTTState e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { - s.Sender.Cubic = stack.TCPCubicState{ - WMax: cubic.wMax, - WLastMax: cubic.wLastMax, - T: cubic.t, - TimeSinceLastCongestion: time.Since(cubic.t), - C: cubic.c, - K: cubic.k, - Beta: cubic.beta, - WC: cubic.wC, - WEst: cubic.wEst, - } - } - - rc := &e.snd.rc - s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, - ReoWnd: rc.reoWnd, - ReoWndIncr: rc.reoWndIncr, - ReoWndPersist: rc.reoWndPersist, - RTTSeq: rc.rttSeq, + s.Sender.Cubic = cubic.TCPCubicState + s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) } + + s.Sender.RACKState = e.snd.rc.TCPRACKState return s } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 590775434..034eacd72 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -58,7 +58,7 @@ func (e *endpoint) beforeSave() { if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { panic(&tcpip.ErrSaveRejection{ - Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), }) } e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) @@ -88,7 +88,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if e.workerRunning { - panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID)) } default: panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) @@ -180,14 +180,14 @@ func (e *endpoint) Resume(s *stack.Stack) { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) + if e.rcvQueueInfo.RcvBufSize < rs.Min || e.rcvQueueInfo.RcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvQueueInfo.RcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvQueueInfo.RcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) if err != nil { panic("unable to parse BindAddr: " + err.String()) } @@ -213,19 +213,19 @@ func (e *endpoint) Resume(s *stack.Stack) { case epState.connected(): bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.ID.RemoteAddress + e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -263,7 +263,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -310,23 +310,3 @@ func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) } - -// saveMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.measureTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime { - return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) { - r.rttMeasureTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 0a0d5f7a1..9e332dcf7 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -46,54 +47,16 @@ const ( // // +stateify savable type rackControl struct { - // dsackSeen indicates if the connection has seen a DSACK. - dsackSeen bool - - // endSequence is the ending TCP sequence number of the most recent - // acknowledged segment. - endSequence seqnum.Value + stack.TCPRACKState // exitedRecovery indicates if the connection is exiting loss recovery. // This flag is set if the sender is leaving the recovery after // receiving an ACK and is reset during updating of reorder window. exitedRecovery bool - // fack is the highest selectively or cumulatively acknowledged - // sequence. - fack seqnum.Value - // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool - - // reoWnd is the reordering window time used for recording packet - // transmission times. It is used to defer the moment at which RACK - // marks a packet lost. - reoWnd time.Duration - - // reoWndIncr is the multiplier applied to adjust reorder window. - reoWndIncr uint8 - - // reoWndPersist is the number of loss recoveries before resetting - // reorder window. - reoWndPersist int8 - - // rtt is the RTT of the most recently delivered packet on the - // connection (either cumulatively acknowledged or selectively - // acknowledged) that was not marked invalid as a possible spurious - // retransmission. - rtt time.Duration - - // rttSeq is the SND.NXT when rtt is updated. - rttSeq seqnum.Value - - // xmitTime is the latest transmission timestamp of the most recent - // acknowledged segment. - xmitTime time.Time `state:".(unixTime)"` - // tlpRxtOut indicates whether there is an unacknowledged // TLP retransmission. tlpRxtOut bool @@ -108,8 +71,8 @@ type rackControl struct { // init initializes RACK specific fields. func (rc *rackControl) init(snd *sender, iss seqnum.Value) { - rc.fack = iss - rc.reoWndIncr = 1 + rc.FACK = iss + rc.ReoWndIncr = 1 rc.snd = snd } @@ -117,7 +80,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) - tsOffset := rc.snd.ep.tsOffset + tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: @@ -138,7 +101,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { } } - rc.rtt = rtt + rc.RTT = rtt // The sender can either track a simple global minimum of all RTT // measurements from the connection, or a windowed min-filtered value @@ -152,9 +115,9 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - rc.xmitTime = seg.xmitTime - rc.endSequence = endSeq + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + rc.XmitTime = seg.xmitTime + rc.EndSequence = endSeq } } @@ -171,18 +134,18 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // is identified. func (rc *rackControl) detectReorder(seg *segment) { endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.fack.LessThan(endSeq) { - rc.fack = endSeq + if rc.FACK.LessThan(endSeq) { + rc.FACK = endSeq return } - if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 { - rc.reorderSeen = true + if endSeq.LessThan(rc.FACK) && seg.xmitCount == 1 { + rc.Reord = true } } func (rc *rackControl) setDSACKSeen(dsackSeen bool) { - rc.dsackSeen = dsackSeen + rc.DSACKSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -191,7 +154,7 @@ func (s *sender) shouldSchedulePTO() bool { // Schedule PTO only if RACK loss detection is enabled. return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 && // The connection supports SACK. - s.ep.sackPermitted && + s.ep.SACKPermitted && // The connection is not in loss recovery. (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. @@ -203,9 +166,9 @@ func (s *sender) shouldSchedulePTO() bool { func (s *sender) schedulePTO() { pto := time.Second s.rtt.Lock() - if s.rtt.srttInited && s.rtt.srtt > 0 { - pto = s.rtt.srtt * 2 - if s.outstanding == 1 { + if s.rtt.TCPRTTState.SRTTInited && s.rtt.TCPRTTState.SRTT > 0 { + pto = s.rtt.TCPRTTState.SRTT * 2 + if s.Outstanding == 1 { pto += wcDelayedACKTimeout } } @@ -230,10 +193,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } var dataSent bool - if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.outstanding < s.sndCwnd { - dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.Outstanding < s.SndCwnd { + dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { - s.outstanding += s.pCount(s.writeNext, s.maxPayloadSize) + s.Outstanding += s.pCount(s.writeNext, s.MaxPayloadSize) s.writeNext = s.writeNext.Next() } } @@ -255,10 +218,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } if highestSeqXmit != nil { - dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { s.rc.tlpRxtOut = true - s.rc.tlpHighRxt = s.sndNxt + s.rc.tlpHighRxt = s.SndNxt } } } @@ -274,7 +237,7 @@ func (s *sender) probeTimerExpired() tcpip.Error { // and updates TLP state accordingly. // See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3. func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { - if !(s.ep.sackPermitted && s.rc.tlpRxtOut) { + if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) { return } @@ -317,13 +280,13 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { - dsackSeen := rc.dsackSeen + dsackSeen := rc.DSACKSeen snd := rc.snd // React to DSACK once per round trip. // If SND.UNA < RACK.rtt_seq: // RACK.dsack = false - if snd.sndUna.LessThan(rc.rttSeq) { + if snd.SndUna.LessThan(rc.RTTSeq) { dsackSeen = false } @@ -333,18 +296,18 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.rtt_seq = SND.NXT // RACK.reo_wnd_persist = 16 if dsackSeen { - rc.reoWndIncr++ + rc.ReoWndIncr++ dsackSeen = false - rc.rttSeq = snd.sndNxt - rc.reoWndPersist = tcpRACKRecoveryThreshold + rc.RTTSeq = snd.SndNxt + rc.ReoWndPersist = tcpRACKRecoveryThreshold } else if rc.exitedRecovery { // Else if exiting loss recovery: // RACK.reo_wnd_persist -= 1 // If RACK.reo_wnd_persist <= 0: // RACK.reo_wnd_incr = 1 - rc.reoWndPersist-- - if rc.reoWndPersist <= 0 { - rc.reoWndIncr = 1 + rc.ReoWndPersist-- + if rc.ReoWndPersist <= 0 { + rc.ReoWndIncr = 1 } rc.exitedRecovery = false } @@ -358,14 +321,14 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // Else if RACK.pkts_sacked >= RACK.dupthresh: // RACK.reo_wnd = 0 // return - if !rc.reorderSeen { + if !rc.Reord { if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { - rc.reoWnd = 0 + rc.ReoWnd = 0 return } - if snd.sackedOut >= nDupAckThreshold { - rc.reoWnd = 0 + if snd.SackedOut >= nDupAckThreshold { + rc.ReoWnd = 0 return } } @@ -374,11 +337,11 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) snd.rtt.Lock() - srtt := snd.rtt.srtt + srtt := snd.rtt.TCPRTTState.SRTT snd.rtt.Unlock() - rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) - if srtt < rc.reoWnd { - rc.reoWnd = srtt + rc.ReoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.ReoWndIncr)) + if srtt < rc.ReoWnd { + rc.ReoWnd = srtt } } @@ -403,8 +366,8 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true numLost++ @@ -435,7 +398,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { } fastRetransmit := false - if !rc.snd.fr.active { + if !rc.snd.FastRecovery.Active { rc.snd.cc.HandleLossDetected() rc.snd.enterRecovery() fastRetransmit = true @@ -471,15 +434,15 @@ func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) { } // Check the congestion window after entering recovery. - if snd.outstanding >= snd.sndCwnd { + if snd.Outstanding >= snd.SndCwnd { break } - if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.sndUna.Add(snd.sndWnd)); !sent { + if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.SndUna.Add(snd.SndWnd)); !sent { break } dataSent = true - snd.outstanding += snd.pCount(seg, snd.maxPayloadSize) + snd.Outstanding += snd.pCount(seg, snd.MaxPayloadSize) } snd.postXmit(dataSent, true /* shouldScheduleProbe */) diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go deleted file mode 100644 index c9dc7e773..000000000 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveXmitTime is invoked by stateify. -func (rc *rackControl) saveXmitTime() unixTime { - return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (rc *rackControl) loadXmitTime(unix unixTime) { - rc.xmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index bc6793fc6..fc11b4ba9 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // receiver holds the state necessary to receive TCP segments and turn them @@ -29,26 +30,15 @@ import ( // // +stateify savable type receiver struct { + stack.TCPReceiverState ep *endpoint - rcvNxt seqnum.Value - - // rcvAcc is one beyond the last acceptable sequence number. That is, - // the "largest" sequence value that the receiver has announced to the - // its peer that it's willing to accept. This may be different than - // rcvNxt + rcvWnd if the receive window is reduced; in that case we - // have to reduce the window as we receive more data instead of - // shrinking it. - rcvAcc seqnum.Value - // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size - // rcvWUP is the rcvNxt value at the last window update sent. + // rcvWUP is the RcvNxt value at the last window update sent. rcvWUP seqnum.Value - rcvWndScale uint8 - // prevBufused is the snapshot of endpoint rcvBufUsed taken when we // advertise a receive window. prevBufUsed int @@ -58,9 +48,6 @@ type receiver struct { // pendingRcvdSegments is bounded by the receive buffer size of the // endpoint. pendingRcvdSegments segmentHeap - // pendingBufUsed tracks the total number of bytes (including segment - // overhead) currently queued in pendingRcvdSegments. - pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` @@ -68,12 +55,14 @@ type receiver struct { func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), + ep: ep, + TCPReceiverState: stack.TCPReceiverState{ + RcvNxt: irs + 1, + RcvAcc: irs.Add(rcvWnd + 1), + RcvWndScale: rcvWndScale, + }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } } @@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // r.rcvWnd could be much larger than the window size we advertised in our // outgoing packets, we should use what we have advertised for acceptability // test. - scaledWindowSize := r.rcvWnd >> r.rcvWndScale + scaledWindowSize := r.rcvWnd >> r.RcvWndScale if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. scaledWindowSize = math.MaxUint16 } - advertisedWindowSize := scaledWindowSize << r.rcvWndScale - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) + advertisedWindowSize := scaledWindowSize << r.RcvWndScale + return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize)) } // currentWindow returns the available space in the window that was advertised // last to our peer. func (r *receiver) currentWindow() (curWnd seqnum.Size) { endOfWnd := r.rcvWUP.Add(r.rcvWnd) - if endOfWnd.LessThan(r.rcvNxt) { - // return 0 if r.rcvNxt is past the end of the previously advertised window. + if endOfWnd.LessThan(r.RcvNxt) { + // return 0 if r.RcvNxt is past the end of the previously advertised window. // This can happen because we accept a large segment completely even if // accepting it causes it to partially exceed the advertised window. return 0 } - return r.rcvNxt.Size(endOfWnd) + return r.RcvNxt.Size(endOfWnd) } // getSendParams returns the parameters needed by the sender when building // segments to send. -func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { +func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() - unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt)) bufUsed := r.ep.receiveBufferUsed() // Grow the right edge of the window only for payloads larger than the @@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // edge, as we are still advertising a window that we think can be serviced. toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed - // Update rcvAcc only if new window is > previously advertised window. We + // Update RcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. // ==================================================== sequence space. // ^ ^ ^ ^ - // rcvWUP rcvNxt rcvAcc new rcvAcc + // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { - // If the new window moves the right edge, then update rcvAcc. - r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) + if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + // If the new window moves the right edge, then update RcvAcc. + r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - r.rcvWUP = r.rcvNxt + r.rcvWUP = r.RcvNxt r.prevBufUsed = bufUsed - scaledWnd := r.rcvWnd >> r.rcvWndScale + scaledWnd := r.rcvWnd >> r.RcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() @@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Ensure that the stashed receive window always reflects what // is being advertised. - r.rcvWnd = scaledWnd << r.rcvWndScale + r.rcvWnd = scaledWnd << r.RcvWndScale } - return r.rcvNxt, scaledWnd + return r.RcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; @@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // If the segment doesn't include the seqnum we're expecting to // consume now, we're missing a segment. We cannot proceed until // we receive that segment though. - if !r.rcvNxt.InWindow(segSeq, segLen) { + if !r.RcvNxt.InWindow(segSeq, segLen) { return false } // Trim segment to eliminate already acknowledged data. - if segSeq.LessThan(r.rcvNxt) { - diff := segSeq.Size(r.rcvNxt) + if segSeq.LessThan(r.RcvNxt) { + diff := segSeq.Size(r.RcvNxt) segLen -= diff segSeq.UpdateForward(diff) s.sequenceNumber.UpdateForward(diff) @@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Move segment to ready-to-deliver list. Wakeup any waiters. r.ep.readyToRead(s) - } else if segSeq != r.rcvNxt { + } else if segSeq != r.RcvNxt { return false } // Update the segment that we're expecting to consume. - r.rcvNxt = segSeq.Add(segLen) + r.RcvNxt = segSeq.Add(segLen) // In cases of a misbehaving sender which could send more than the // advertised window, we could end up in a situation where we get a // segment that exceeds the window advertised. Instead of partially // accepting the segment and discarding bytes beyond the advertised - // window, we accept the whole segment and make sure r.rcvAcc is moved - // forward to match r.rcvNxt to indicate that the window is now closed. + // window, we accept the whole segment and make sure r.RcvAcc is moved + // forward to match r.RcvNxt to indicate that the window is now closed. // // In absence of this check the r.acceptable() check fails and accepts // segments that should be dropped because rcvWnd is calculated as - // the size of the interval (rcvNxt, rcvAcc] which becomes extremely - // large if rcvAcc is ever less than rcvNxt. - if r.rcvAcc.LessThan(r.rcvNxt) { - r.rcvAcc = r.rcvNxt + // the size of the interval (RcvNxt, RcvAcc] which becomes extremely + // large if RcvAcc is ever less than RcvNxt. + if r.RcvAcc.LessThan(r.RcvNxt) { + r.RcvAcc = r.RcvNxt } // Trim SACK Blocks to remove any SACK information that covers // sequence numbers that have been consumed. - TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { - r.rcvNxt++ + r.RcvNxt++ // Send ACK immediately. r.ep.snd.sendAck() @@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { - r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() + r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() // Note that slice truncation does not allow garbage collection of @@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -323,40 +312,40 @@ func (r *receiver) updateRTT() { // estimate the round-trip time by observing the time between when a byte // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. - r.ep.rcvListMu.Lock() - if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { // New measurement. - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { - r.ep.rcvListMu.Unlock() + if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) { + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. - if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { - r.ep.rcvAutoParams.rtt = rtt + if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { + r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { - r.ep.rcvListMu.Lock() - rcvClosed := r.ep.rcvClosed || r.closed - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { case StateCloseWait, StateClosing, StateLastAck: - if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + if !s.sequenceNumber.LessThanEq(r.RcvNxt) { // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // The ESTABLISHED state processing is here where if the ACK check // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + if r.ep.snd.SndNxt.LessThan(s.ackNumber) { r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., - // SHUT_RD) then any data past the rcvNxt should + // SHUT_RD) then any data past the RcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) { return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { @@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If it's a retransmission of an old data segment // or a pure ACK then allow it. - if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) || s.logicalLen() == 0 { break } @@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // then the only acceptable segment is a // FIN. Since FIN can technically also carry // data we verify that the segment carrying a - // FIN ends at exactly e.rcvNxt+1. + // FIN ends at exactly e.RcvNxt+1. // // From RFC793 page 25. // @@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // end has closed and the peer is yet to send a FIN. Hence we // compare only the payload. segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) { return true, nil } return false, nil @@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { - r.ep.rcvListMu.Lock() - r.pendingBufUsed += s.segMemSize() - r.ep.rcvListMu.Unlock() + if r.ep.receiveBufferAvailable() > 0 && r.PendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed += s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) - UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt) } // Immediately send an ack so that the peer knows it may @@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { segSeq := s.sequenceNumber // Skip segment altogether if it has already been acknowledged. - if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) && !r.consumeSegment(s, segSeq, segLen) { break } heap.Pop(&r.pendingRcvdSegments) - r.ep.rcvListMu.Lock() - r.pendingBufUsed -= s.segMemSize() - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed -= s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.decRef() } return false, nil @@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } @@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn } // Update Timestamp if required. See RFC7323, section-4.3. - if r.ep.sendTSOk && s.parsedOptions.TS { - r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + if r.ep.SendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() @@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // carries data then just send an ACK. This is according to RFC 793, // page 37. // - // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. - if segSeq != r.rcvNxt || segLen != 0 { + // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt. + if segSeq != r.RcvNxt || segLen != 0 { r.ep.snd.sendAck() } return false, false diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index ff39780a5..063552c7f 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -34,14 +34,14 @@ func newRenoCC(s *sender) *renoState { func (r *renoState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. - newcwnd := r.s.sndCwnd + packetsAcked - if newcwnd >= r.s.sndSsthresh { - newcwnd = r.s.sndSsthresh - r.s.sndCAAckCount = 0 + newcwnd := r.s.SndCwnd + packetsAcked + if newcwnd >= r.s.Ssthresh { + newcwnd = r.s.Ssthresh + r.s.SndCAAckCount = 0 } - packetsAcked -= newcwnd - r.s.sndCwnd - r.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - r.s.SndCwnd + r.s.SndCwnd = newcwnd return packetsAcked } @@ -49,19 +49,19 @@ func (r *renoState) updateSlowStart(packetsAcked int) int { // avoidance mode as described in RFC5681 section 3.1 func (r *renoState) updateCongestionAvoidance(packetsAcked int) { // Consume the packets in congestion avoidance mode. - r.s.sndCAAckCount += packetsAcked - if r.s.sndCAAckCount >= r.s.sndCwnd { - r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd - r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + r.s.SndCAAckCount += packetsAcked + if r.s.SndCAAckCount >= r.s.SndCwnd { + r.s.SndCwnd += r.s.SndCAAckCount / r.s.SndCwnd + r.s.SndCAAckCount = r.s.SndCAAckCount % r.s.SndCwnd } } // reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, // page 6, eq. 4. It is called when we detect congestion in the network. func (r *renoState) reduceSlowStartThreshold() { - r.s.sndSsthresh = r.s.outstanding / 2 - if r.s.sndSsthresh < 2 { - r.s.sndSsthresh = 2 + r.s.Ssthresh = r.s.Outstanding / 2 + if r.s.Ssthresh < 2 { + r.s.Ssthresh = 2 } } @@ -70,7 +70,7 @@ func (r *renoState) reduceSlowStartThreshold() { // were acknowledged. // Update implements congestionControl.Update. func (r *renoState) Update(packetsAcked int) { - if r.s.sndCwnd < r.s.sndSsthresh { + if r.s.SndCwnd < r.s.Ssthresh { packetsAcked = r.updateSlowStart(packetsAcked) if packetsAcked == 0 { return @@ -94,7 +94,7 @@ func (r *renoState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - r.s.sndCwnd = 1 + r.s.SndCwnd = 1 } // PostRecovery implements congestionControl.PostRecovery. diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go index 2aa708e97..d368a29fc 100644 --- a/pkg/tcpip/transport/tcp/reno_recovery.go +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -31,25 +31,25 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { snd := rr.s // We are in fast recovery mode. Ignore the ack if it's out of range. - if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // Don't count this as a duplicate if it is carrying data or // updating the window. - if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + if rcvdSeg.logicalLen() != 0 || snd.SndWnd != rcvdSeg.window { return } // Inflate the congestion window if we're getting duplicate acks // for the packet we retransmitted. - if !fastRetransmit && ack == snd.fr.first { + if !fastRetransmit && ack == snd.FastRecovery.First { // We received a dup, inflate the congestion window by 1 packet // if we're not at the max yet. Only inflate the window if // regular FastRecovery is in use, RFC6675 does not require // inflating cwnd on duplicate ACKs. - if snd.sndCwnd < snd.fr.maxCwnd { - snd.sndCwnd++ + if snd.SndCwnd < snd.FastRecovery.MaxCwnd { + snd.SndCwnd++ } return } @@ -61,7 +61,7 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { // back onto the wire. // // N.B. The retransmit timer will be reset by the caller. - snd.fr.first = ack - snd.dupAckCount = 0 + snd.FastRecovery.First = ack + snd.DupAckCount = 0 snd.resendSegment() } diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go index 9d406b0bc..cd860b5e8 100644 --- a/pkg/tcpip/transport/tcp/sack_recovery.go +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -42,14 +42,14 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen } nextSegHint := snd.writeList.Front() - for snd.outstanding < snd.sndCwnd { + for snd.Outstanding < snd.SndCwnd { var nextSeg *segment var rescueRtx bool nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) if nextSeg == nil { return dataSent } - if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + if !snd.isAssignedSequenceNumber(nextSeg) || snd.SndNxt.LessThanEq(nextSeg.sequenceNumber) { // New data being sent. // Step C.3 described below is handled by @@ -67,7 +67,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen return dataSent } dataSent = true - snd.outstanding++ + snd.Outstanding++ snd.writeNext = nextSeg.Next() continue } @@ -79,7 +79,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // "The estimate of the amount of data outstanding in the network // must be updated by incrementing pipe by the number of octets // transmitted in (C.1)." - snd.outstanding++ + snd.Outstanding++ dataSent = true snd.sendSegment(nextSeg) @@ -88,7 +88,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // We do the last part of rule (4) of NextSeg here to update // RescueRxt as until this point we don't know if we are going // to use the rescue transmission. - snd.fr.rescueRxt = snd.fr.last + snd.FastRecovery.RescueRxt = snd.FastRecovery.Last } else { // RFC 6675, Step C.2 // @@ -96,7 +96,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // HighData, HighRxt MUST be set to the highest sequence // number of the retransmitted segment unless NextSeg () // rule (4) was invoked for this retransmission." - snd.fr.highRxt = segEnd - 1 + snd.FastRecovery.HighRxt = segEnd - 1 } } return dataSent @@ -109,12 +109,12 @@ func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { } // We are in fast recovery mode. Ignore the ack if it's out of range. - if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // RFC 6675 recovery algorithm step C 1-5. - end := snd.sndUna.Add(snd.sndWnd) - dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + end := snd.SndUna.Add(snd.SndWnd) + dataSent := sr.handleSACKRecovery(snd.MaxPayloadSize, end) snd.postXmit(dataSent, true /* shouldScheduleProbe */) } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index faca35892..cf2e8dcd8 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -85,56 +86,12 @@ type lossRecovery interface { // // +stateify savable type sender struct { + stack.TCPSenderState ep *endpoint - // lastSendTime is the timestamp when the last packet was sent. - lastSendTime time.Time `state:".(unixTime)"` - - // dupAckCount is the number of duplicated acks received. It is used for - // fast retransmit. - dupAckCount int - - // fr holds state related to fast recovery. - fr fastRecovery - // lr is the loss recovery algorithm used by the sender. lr lossRecovery - // sndCwnd is the congestion window, in packets. - sndCwnd int - - // sndSsthresh is the threshold between slow start and congestion - // avoidance. - sndSsthresh int - - // sndCAAckCount is the number of packets acknowledged during congestion - // avoidance. When enough packets have been ack'd (typically cwnd - // packets), the congestion window is incremented by one. - sndCAAckCount int - - // outstanding is the number of outstanding packets, that is, packets - // that have been sent but not yet acknowledged. - outstanding int - - // sackedOut is the number of packets which are selectively acked. - sackedOut int - - // sndWnd is the send window size. - sndWnd seqnum.Size - - // sndUna is the next unacknowledged sequence number. - sndUna seqnum.Value - - // sndNxt is the sequence number of the next segment to be sent. - sndNxt seqnum.Value - - // rttMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - rttMeasureSeqNum seqnum.Value - - // rttMeasureTime is the time when the rttMeasureSeqNum was sent. - rttMeasureTime time.Time `state:".(unixTime)"` - // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` @@ -147,17 +104,15 @@ type sender struct { // window probes. unackZeroWindowProbes uint32 `state:"nosave"` - closed bool writeNext *segment writeList segmentList resendTimer timer `state:"nosave"` resendWaker sleep.Waker `state:"nosave"` - // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", - // "round-trip time variation" and "retransmit timeout", as defined in + // rtt.TCPRTTState.SRTT and rtt.TCPRTTState.RTTVar are the "smoothed + // round-trip time", and "round-trip time variation", as defined in // section 2 of RFC 6298. rtt rtt - rto time.Duration // minRTO is the minimum permitted value for sender.rto. minRTO time.Duration @@ -168,20 +123,9 @@ type sender struct { // maxRetries is the maximum permitted retransmissions. maxRetries uint32 - // maxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - maxPayloadSize int - // gso is set if generic segmentation offload is enabled. gso bool - // sndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - sndWndScale uint8 - - // maxSentAck is the maxium acknowledgement actually sent. - maxSentAck seqnum.Value - // state is the current state of congestion control for this endpoint. state tcpip.CongestionControlState @@ -209,41 +153,7 @@ type sender struct { type rtt struct { sync.Mutex `state:"nosave"` - srtt time.Duration - rttvar time.Duration - srttInited bool -} - -// fastRecovery holds information related to fast recovery from a packet loss. -// -// +stateify savable -type fastRecovery struct { - // active whether the endpoint is in fast recovery. The following fields - // are only meaningful when active is true. - active bool - - // first and last represent the inclusive sequence number range being - // recovered. - first seqnum.Value - last seqnum.Value - - // maxCwnd is the maximum value the congestion window may be inflated to - // due to duplicate acks. This exists to avoid attacks where the - // receiver intentionally sends duplicate acks to artificially inflate - // the sender's cwnd. - maxCwnd int - - // highRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - highRxt seqnum.Value - - // rescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. - rescueRxt seqnum.Value + stack.TCPRTTState } func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { @@ -253,20 +163,22 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint maxPayloadSize := int(mss) - ep.maxOptionSize() s := &sender{ - ep: ep, - sndWnd: sndWnd, - sndUna: iss + 1, - sndNxt: iss + 1, - rto: 1 * time.Second, - rttMeasureSeqNum: iss + 1, - lastSendTime: time.Now(), - maxPayloadSize: maxPayloadSize, - maxSentAck: irs + 1, - fr: fastRecovery{ - // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. - last: iss, - highRxt: iss, - rescueRxt: iss, + ep: ep, + TCPSenderState: stack.TCPSenderState{ + SndWnd: sndWnd, + SndUna: iss + 1, + SndNxt: iss + 1, + RTTMeasureSeqNum: iss + 1, + LastSendTime: time.Now(), + MaxPayloadSize: maxPayloadSize, + MaxSentAck: irs + 1, + FastRecovery: stack.TCPFastRecoveryState{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + Last: iss, + HighRxt: iss, + RescueRxt: iss, + }, + RTO: 1 * time.Second, }, gso: ep.gso != nil, } @@ -282,7 +194,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { - s.sndWndScale = uint8(sndWndScale) + s.SndWndScale = uint8(sndWndScale) } s.resendTimer.init(&s.resendWaker) @@ -294,7 +206,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // Initialize SACK Scoreboard after updating max payload size as we use // the maxPayloadSize as the smss when determining if a segment is lost // etc. - s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + s.ep.scoreboard = NewSACKScoreboard(uint16(s.MaxPayloadSize), iss) // Get Stack wide config. var minRTO tcpip.TCPMinRTOOption @@ -322,10 +234,10 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // returns a handle to it. It also initializes the sndCwnd and sndSsThresh to // their initial values. func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { - s.sndCwnd = InitialCwnd + s.SndCwnd = InitialCwnd // Set sndSsthresh to the maximum int value, which depends on the // platform. - s.sndSsthresh = int(^uint(0) >> 1) + s.Ssthresh = int(^uint(0) >> 1) switch congestionControlName { case ccCubic: @@ -339,7 +251,7 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon // initLossRecovery initiates the loss recovery algorithm for the sender. func (s *sender) initLossRecovery() lossRecovery { - if s.ep.sackPermitted { + if s.ep.SACKPermitted { return newSACKRecovery(s) } return newRenoRecovery(s) @@ -355,7 +267,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m -= s.ep.maxOptionSize() // We don't adjust up for now. - if m >= s.maxPayloadSize { + if m >= s.MaxPayloadSize { return } @@ -364,8 +276,8 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } - oldMSS := s.maxPayloadSize - s.maxPayloadSize = m + oldMSS := s.MaxPayloadSize + s.MaxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) } @@ -380,9 +292,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // maxPayloadSize. s.ep.scoreboard.smss = uint16(m) - s.outstanding -= count - if s.outstanding < 0 { - s.outstanding = 0 + s.Outstanding -= count + if s.Outstanding < 0 { + s.Outstanding = 0 } // Rewind writeNext to the first segment exceeding the MTU. Do nothing @@ -401,10 +313,10 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { nextSeg = seg } - if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Update sackedOut for new maximum payload size. - s.sackedOut -= s.pCount(seg, oldMSS) - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, oldMSS) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } } @@ -416,32 +328,32 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // sendAck sends an ACK segment. func (s *sender) sendAck() { - s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) + s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.SndNxt) } // updateRTO updates the retransmit timeout when a new roud-trip time is // available. This is done in accordance with section 2 of RFC 6298. func (s *sender) updateRTO(rtt time.Duration) { s.rtt.Lock() - if !s.rtt.srttInited { - s.rtt.rttvar = rtt / 2 - s.rtt.srtt = rtt - s.rtt.srttInited = true + if !s.rtt.TCPRTTState.SRTTInited { + s.rtt.TCPRTTState.RTTVar = rtt / 2 + s.rtt.TCPRTTState.SRTT = rtt + s.rtt.TCPRTTState.SRTTInited = true } else { - diff := s.rtt.srtt - rtt + diff := s.rtt.TCPRTTState.SRTT - rtt if diff < 0 { diff = -diff } - // Use RFC6298 standard algorithm to update rttvar and srtt when + // Use RFC6298 standard algorithm to update TCPRTTState.RTTVar and TCPRTTState.SRTT when // no timestamps are available. - if !s.ep.sendTSOk { - s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 - s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 + if !s.ep.SendTSOk { + s.rtt.TCPRTTState.RTTVar = (3*s.rtt.TCPRTTState.RTTVar + diff) / 4 + s.rtt.TCPRTTState.SRTT = (7*s.rtt.TCPRTTState.SRTT + rtt) / 8 } else { // When we are taking RTT measurements of every ACK then // we need to use a modified method as specified in // https://tools.ietf.org/html/rfc7323#appendix-G - if s.outstanding == 0 { + if s.Outstanding == 0 { s.rtt.Unlock() return } @@ -449,7 +361,7 @@ func (s *sender) updateRTO(rtt time.Duration) { // terms of packets and not bytes. This is similar to // how linux also does cwnd and inflight. In practice // this approximation works as expected. - expectedSamples := math.Ceil(float64(s.outstanding) / 2) + expectedSamples := math.Ceil(float64(s.Outstanding) / 2) // alpha & beta values are the original values as recommended in // https://tools.ietf.org/html/rfc6298#section-2.3. @@ -458,17 +370,17 @@ func (s *sender) updateRTO(rtt time.Duration) { alphaPrime := alpha / expectedSamples betaPrime := beta / expectedSamples - rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds() - srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds() - s.rtt.rttvar = time.Duration(rttVar * float64(time.Second)) - s.rtt.srtt = time.Duration(srtt * float64(time.Second)) + rttVar := (1-betaPrime)*s.rtt.TCPRTTState.RTTVar.Seconds() + betaPrime*diff.Seconds() + srtt := (1-alphaPrime)*s.rtt.TCPRTTState.SRTT.Seconds() + alphaPrime*rtt.Seconds() + s.rtt.TCPRTTState.RTTVar = time.Duration(rttVar * float64(time.Second)) + s.rtt.TCPRTTState.SRTT = time.Duration(srtt * float64(time.Second)) } } - s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar s.rtt.Unlock() - if s.rto < s.minRTO { - s.rto = s.minRTO + if s.RTO < s.minRTO { + s.RTO = s.minRTO } } @@ -476,20 +388,20 @@ func (s *sender) updateRTO(rtt time.Duration) { func (s *sender) resendSegment() { // Don't use any segments we already sent to measure RTT as they may // have been affected by packets being lost. - s.rttMeasureSeqNum = s.sndNxt + s.RTTMeasureSeqNum = s.SndNxt // Resend the segment. if seg := s.writeList.Front(); seg != nil { - if seg.data.Size() > s.maxPayloadSize { - s.splitSeg(seg, s.maxPayloadSize) + if seg.data.Size() > s.MaxPayloadSize { + s.splitSeg(seg, s.MaxPayloadSize) } // See: RFC 6675 section 5 Step 4.3 // // To prevent retransmission, set both the HighRXT and RescueRXT // to the highest sequence number in the retransmitted segment. - s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 - s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.HighRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.RescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() s.ep.stats.SendErrors.FastRetransmit.Increment() @@ -554,15 +466,15 @@ func (s *sender) retransmitTimerExpired() bool { // Set new timeout. The timer will be restarted by the call to sendData // below. - s.rto *= 2 + s.RTO *= 2 // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 - if s.rto > s.maxRTO { - s.rto = s.maxRTO + if s.RTO > s.maxRTO { + s.RTO = s.maxRTO } // Cap RTO to remaining time. - if s.rto > remaining { - s.rto = remaining + if s.RTO > remaining { + s.RTO = remaining } // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. @@ -571,9 +483,9 @@ func (s *sender) retransmitTimerExpired() bool { // After a retransmit timeout, record the highest sequence number // transmitted in the variable recover, and exit the fast recovery // procedure if applicable. - s.fr.last = s.sndNxt - 1 + s.FastRecovery.Last = s.SndNxt - 1 - if s.fr.active { + if s.FastRecovery.Active { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. @@ -589,7 +501,7 @@ func (s *sender) retransmitTimerExpired() bool { // // We'll keep on transmitting (or retransmitting) as we get acks for // the data we transmit. - s.outstanding = 0 + s.Outstanding = 0 // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1 // @@ -663,7 +575,7 @@ func (s *sender) splitSeg(seg *segment, size int) { // window space. // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() - if seg.data.Size() > s.maxPayloadSize { + if seg.data.Size() > s.MaxPayloadSize { seg.flags ^= header.TCPFlagPsh } @@ -689,7 +601,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // transmitted (i.e. either it has no assigned sequence number // or if it does have one, it's >= the next sequence number // to be sent [i.e. >= s.sndNxt]). - if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) { + if !s.isAssignedSequenceNumber(seg) || s.SndNxt.LessThanEq(seg.sequenceNumber) { hint = nil break } @@ -710,7 +622,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // (1.a) S2 is greater than HighRxt // (1.b) S2 is less than highest octect covered by // any received SACK. - if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { + if s.FastRecovery.HighRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { @@ -743,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // unSACKed sequence number SHOULD be returned, and // RescueRxt set to RecoveryPoint. HighRxt MUST NOT // be updated. - if s.fr.rescueRxt.LessThan(s.sndUna - 1) { + if s.FastRecovery.RescueRxt.LessThan(s.SndUna - 1) { if s4 != nil { if s4.sequenceNumber.LessThan(segSeq) { s4 = seg @@ -763,7 +675,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // previously unsent data starting with sequence number // HighData+1 MUST be returned." for seg := s.writeNext; seg != nil; seg = seg.Next() { - if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.SndNxt) { continue } // We do not split the segment here to <= smss as it has @@ -788,7 +700,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(s.sndNxt.Size(end)) + available := int(s.SndNxt.Size(end)) if available > limit { available = limit } @@ -816,7 +728,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && s.ep.ops.GetDelayOption() { + if s.Outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -835,7 +747,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { + if seg.data.Size() < s.MaxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -843,7 +755,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Assign flags. We don't do it above so that we can merge // additional data if Nagle holds the segment. - seg.sequenceNumber = s.sndNxt + seg.sequenceNumber = s.SndNxt seg.flags = header.TCPFlagAck | header.TCPFlagPsh } @@ -893,12 +805,12 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // the segment right here if there are no pending segments. If // there are pending segments, segment transmits are deferred to // the retransmit timer handler. - if s.sndUna != s.sndNxt { + if s.SndUna != s.SndNxt { switch { case available >= seg.data.Size(): // OK to send, the whole segments fits in the // receiver's advertised window. - case available >= s.maxPayloadSize: + case available >= s.MaxPayloadSize: // OK to send, at least 1 MSS sized segment fits // in the receiver's advertised window. default: @@ -918,8 +830,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // If GSO is not in use then cap available to // maxPayloadSize. When GSO is in use the gVisor GSO logic or // the host GSO logic will cap the segment to the correct size. - if s.ep.gso == nil && available > s.maxPayloadSize { - available = s.maxPayloadSize + if s.ep.gso == nil && available > s.MaxPayloadSize { + available = s.MaxPayloadSize } if seg.data.Size() > available { @@ -933,8 +845,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Update sndNxt if we actually sent new data (as opposed to // retransmitting some previously sent data). - if s.sndNxt.LessThan(segEnd) { - s.sndNxt = segEnd + if s.SndNxt.LessThan(segEnd) { + s.SndNxt = segEnd } return true @@ -945,9 +857,9 @@ func (s *sender) sendZeroWindowProbe() { s.unackZeroWindowProbes++ // Send a zero window probe with sequence number pointing to // the last acknowledged byte. - s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.SndUna-1, ack, win) // Rearm the timer to continue probing. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) enableZeroWindowProbing() { @@ -958,7 +870,7 @@ func (s *sender) enableZeroWindowProbing() { if s.firstRetransmittedSegXmitTime.IsZero() { s.firstRetransmittedSegXmitTime = time.Now() } - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) disableZeroWindowProbing() { @@ -978,12 +890,12 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // If the sender has advertized zero receive window and we have // data to be sent out, start zero window probing to query the // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { + if s.writeNext != nil && s.SndWnd == 0 { s.enableZeroWindowProbing() } // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { + if s.SndUna == s.SndNxt { s.ep.resetKeepaliveTimer(false) } else { // Enable timers if we have pending data. @@ -992,10 +904,10 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { s.schedulePTO() } else if !s.resendTimer.enabled() { s.probeTimer.disable() - if s.outstanding > 0 { + if s.Outstanding > 0 { // Enable the resend timer if it's not enabled yet and there is // outstanding data. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1004,29 +916,29 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { - limit := s.maxPayloadSize + limit := s.MaxPayloadSize if s.gso { limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) } - end := s.sndUna.Add(s.sndWnd) + end := s.SndUna.Add(s.SndWnd) // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { - if s.sndCwnd > InitialCwnd { - s.sndCwnd = InitialCwnd + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if s.SndCwnd > InitialCwnd { + s.SndCwnd = InitialCwnd } } var dataSent bool - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + for seg := s.writeNext; seg != nil && s.Outstanding < s.SndCwnd; seg = seg.Next() { + cwndLimit := (s.SndCwnd - s.Outstanding) * s.MaxPayloadSize if cwndLimit < limit { limit = cwndLimit } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.isAssignedSequenceNumber(seg) && s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Move writeNext along so that we don't try and scan data that // has already been SACKED. s.writeNext = seg.Next() @@ -1036,7 +948,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg, s.maxPayloadSize) + s.Outstanding += s.pCount(seg, s.MaxPayloadSize) s.writeNext = seg.Next() } @@ -1044,21 +956,21 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { - s.fr.active = true + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. - s.sndCwnd = s.sndSsthresh + 3 - s.sackedOut = 0 - s.dupAckCount = 0 - s.fr.first = s.sndUna - s.fr.last = s.sndNxt - 1 - s.fr.maxCwnd = s.sndCwnd + s.outstanding - s.fr.highRxt = s.sndUna - s.fr.rescueRxt = s.sndUna - if s.ep.sackPermitted { + s.SndCwnd = s.Ssthresh + 3 + s.SackedOut = 0 + s.DupAckCount = 0 + s.FastRecovery.First = s.SndUna + s.FastRecovery.Last = s.SndNxt - 1 + s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding + s.FastRecovery.HighRxt = s.SndUna + s.FastRecovery.RescueRxt = s.SndUna + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to @@ -1075,12 +987,12 @@ func (s *sender) enterRecovery() { } func (s *sender) leaveRecovery() { - s.fr.active = false - s.fr.maxCwnd = 0 - s.dupAckCount = 0 + s.FastRecovery.Active = false + s.FastRecovery.MaxCwnd = 0 + s.DupAckCount = 0 // Deflate cwnd. It had been artificially inflated when new dups arrived. - s.sndCwnd = s.sndSsthresh + s.SndCwnd = s.Ssthresh s.cc.PostRecovery() } @@ -1099,7 +1011,7 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool { func (s *sender) SetPipe() { // If SACK isn't permitted or it is permitted but recovery is not active // then ignore pipe calculations. - if !s.ep.sackPermitted || !s.fr.active { + if !s.ep.SACKPermitted || !s.FastRecovery.Active { return } pipe := 0 @@ -1119,7 +1031,7 @@ func (s *sender) SetPipe() { // After initializing pipe to zero, the following steps are // taken for each octet 'S1' in the sequence space between // HighACK and HighData that has not been SACKed: - if !s1.sequenceNumber.LessThan(s.sndNxt) { + if !s1.sequenceNumber.LessThan(s.SndNxt) { break } if s.ep.scoreboard.IsSACKED(sb) { @@ -1138,20 +1050,20 @@ func (s *sender) SetPipe() { } // SetPipe(): // (b) If S1 <= HighRxt, Pipe is incremented by 1. - if s1.sequenceNumber.LessThanEq(s.fr.highRxt) { + if s1.sequenceNumber.LessThanEq(s.FastRecovery.HighRxt) { pipe++ } } } - s.outstanding = pipe + s.Outstanding = pipe } // shouldEnterRecovery returns true if the sender should enter fast recovery // based on dupAck count and sack scoreboard. // See RFC 6675 section 5. func (s *sender) shouldEnterRecovery() bool { - return s.dupAckCount >= nDupAckThreshold || - (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna)) + return s.DupAckCount >= nDupAckThreshold || + (s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna)) } // detectLoss is called when an ack is received and returns whether a loss is @@ -1163,24 +1075,24 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // If RACK is enabled and there is no reordering we should honor the // three duplicate ACK rule to enter recovery. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4 - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - if s.rc.reorderSeen { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.rc.Reord { return false } } if !s.isDupAck(seg) { - s.dupAckCount = 0 + s.DupAckCount = 0 return false } - s.dupAckCount++ + s.DupAckCount++ // Do not enter fast recovery until we reach nDupAckThreshold or the // first unacknowledged byte is considered lost as per SACK scoreboard. if !s.shouldEnterRecovery() { // RFC 6675 Step 3. - s.fr.highRxt = s.sndUna - 1 + s.FastRecovery.HighRxt = s.SndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() s.state = tcpip.Disorder @@ -1196,8 +1108,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // Note that we only enter recovery when at least one more byte of data // beyond s.fr.last (the highest byte that was outstanding when fast // retransmit was last entered) is acked. - if !s.fr.last.LessThan(seg.ackNumber - 1) { - s.dupAckCount = 0 + if !s.FastRecovery.Last.LessThan(seg.ackNumber - 1) { + s.DupAckCount = 0 return false } s.cc.HandleLossDetected() @@ -1212,22 +1124,22 @@ func (s *sender) isDupAck(seg *segment) bool { // can leverage the SACK information to determine when an incoming ACK is a // "duplicate" (e.g., if the ACK contains previously unknown SACK // information). - if s.ep.sackPermitted && !seg.hasNewSACKInfo { + if s.ep.SACKPermitted && !seg.hasNewSACKInfo { return false } // (a) The receiver of the ACK has outstanding data. - return s.sndUna != s.sndNxt && + return s.SndUna != s.SndNxt && // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). - seg.ackNumber == s.sndUna && + seg.ackNumber == s.SndUna && // (e) the advertised window in the incoming acknowledgment equals the // advertised window in the last incoming acknowledgment. - s.sndWnd == seg.window + s.SndWnd == seg.window } // Iterate the writeList and update RACK for each segment which is newly acked @@ -1267,7 +1179,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } seg = seg.Next() } @@ -1322,18 +1234,18 @@ func checkDSACK(rcvdSeg *segment) bool { // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.rttMeasureTime)) - s.rttMeasureSeqNum = s.sndNxt + if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { + s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.RTTMeasureSeqNum = s.SndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { - s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.MaxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. - if s.ep.sackPermitted { + if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: @@ -1347,7 +1259,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.SndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.SndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) rcvdSeg.hasNewSACKInfo = true } @@ -1375,10 +1287,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ack := rcvdSeg.ackNumber fastRetransmit := false // Do not leave fast recovery, if the ACK is out of range. - if s.fr.active { + if s.FastRecovery.Active { // Leave fast recovery if it acknowledges all the data covered by // this fast recovery session. - if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) { + if (ack-1).InRange(s.SndUna, s.SndNxt) && s.FastRecovery.Last.LessThan(ack) { s.leaveRecovery() } } else { @@ -1392,28 +1304,28 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Stash away the current window size. - s.sndWnd = rcvdSeg.window + s.SndWnd = rcvdSeg.window // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. if s.zeroWindowProbing && rcvdSeg.window > 0 && - (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { + (ack == s.SndUna || (ack-1).InRange(s.SndUna, s.SndNxt)) { s.disableZeroWindowProbing() } // On receiving the ACK for the zero window probe, account for it and // skip trying to send any segment as we are still probing for // receive window to become non-zero. - if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna { + if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.SndUna { s.unackZeroWindowProbes-- return } // Ignore ack if it doesn't acknowledge any new data. - if (ack - 1).InRange(s.sndUna, s.sndNxt) { - s.dupAckCount = 0 + if (ack - 1).InRange(s.SndUna, s.SndNxt) { + s.DupAckCount = 0 // See : https://tools.ietf.org/html/rfc1323#section-3.3. // Specifically we should only update the RTO using TSEcr if the @@ -1423,7 +1335,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond @@ -1438,12 +1350,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // When an ack is received we must rearm the timer. // RFC 6298 5.3 s.probeTimer.disable() - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } // Remove all acknowledged data from the write list. - acked := s.sndUna.Size(ack) - s.sndUna = ack + acked := s.SndUna.Size(ack) + s.SndUna = ack // The remote ACK-ing at least 1 byte is an indication that we have a // full-duplex connection to the remote as the only way we will receive an @@ -1457,7 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } ackLeft := acked - originalOutstanding := s.outstanding + originalOutstanding := s.Outstanding for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1466,10 +1378,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg, s.maxPayloadSize) + prevCount := s.pCount(seg, s.MaxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) + s.Outstanding -= prevCount - s.pCount(seg, s.MaxPayloadSize) break } @@ -1478,7 +1390,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1488,10 +1400,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). - if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg, s.maxPayloadSize) + if !s.ep.SACKPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + s.Outstanding -= s.pCount(seg, s.MaxPayloadSize) } else { - s.sackedOut -= s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, s.MaxPayloadSize) } seg.decRef() ackLeft -= datalen @@ -1501,13 +1413,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.ep.updateSndBufferUsage(int(acked)) // Clear SACK information for all acked data. - s.ep.scoreboard.Delete(s.sndUna) + s.ep.scoreboard.Delete(s.SndUna) // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. - if !s.fr.active { - s.cc.Update(originalOutstanding - s.outstanding) - if s.fr.last.LessThan(s.sndUna) { + if !s.FastRecovery.Active { + s.cc.Update(originalOutstanding - s.Outstanding) + if s.FastRecovery.Last.LessThan(s.SndUna) { s.state = tcpip.Open // Update RACK when we are exiting fast or RTO // recovery as described in the RFC @@ -1522,16 +1434,16 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // It is possible for s.outstanding to drop below zero if we get // a retransmit timeout, reset outstanding to zero but later // get an ack that cover previously sent data. - if s.outstanding < 0 { - s.outstanding = 0 + if s.Outstanding < 0 { + s.Outstanding = 0 } s.SetPipe() // If all outstanding data was acknowledged the disable the timer. // RFC 6298 Rule 5.3 - if s.sndUna == s.sndNxt { - s.outstanding = 0 + if s.SndUna == s.SndNxt { + s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() @@ -1539,7 +1451,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { // Update RACK reorder window. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: @@ -1549,7 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the // reorder window. - if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active { + if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.FastRecovery.Active { // If any segment is marked as lost by // RACK, enter recovery and retransmit // the lost segments. @@ -1558,19 +1470,19 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { fastRetransmit = true } - if s.fr.active { + if s.FastRecovery.Active { s.rc.DoRecovery(nil, fastRetransmit) } } // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { + if s.FastRecovery.Active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { s.lr.DoRecovery(rcvdSeg, fastRetransmit) // When SACK is enabled data sending is governed by steps in // RFC 6675 Section 5 recovery steps A-C. // See: https://tools.ietf.org/html/rfc6675#section-5. - if s.ep.sackPermitted { + if s.ep.SACKPermitted { return } } @@ -1587,7 +1499,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() - if s.sndCwnd < s.sndSsthresh { + if s.SndCwnd < s.Ssthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } @@ -1601,11 +1513,11 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // then use the conservative timer described in RFC6675 Section 6.0, // otherwise follow the standard time described in RFC6298 Section 5.1. if err != nil && seg.data.Size() != 0 { - if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted { - s.resendTimer.enable(s.rto) + if s.FastRecovery.Active && seg.xmitCount > 1 && s.ep.SACKPermitted { + s.resendTimer.enable(s.RTO) } else { if !s.resendTimer.enabled() { - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1616,15 +1528,15 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.lastSendTime = time.Now() - if seq == s.rttMeasureSeqNum { - s.rttMeasureTime = s.lastSendTime + s.LastSendTime = time.Now() + if seq == s.RTTMeasureSeqNum { + s.RTTMeasureTime = s.LastSendTime } rcvNxt, rcvWnd := s.ep.rcv.getSendParams() // Remember the max sent ack. - s.maxSentAck = rcvNxt + s.MaxSentAck = rcvNxt return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index ba41cff6d..2f805d8ce 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -24,26 +24,6 @@ type unixTime struct { nano int64 } -// saveLastSendTime is invoked by stateify. -func (s *sender) saveLastSendTime() unixTime { - return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (s *sender) loadLastSendTime(unix unixTime) { - s.lastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (s *sender) saveRttMeasureTime() unixTime { - return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (s *sender) loadRttMeasureTime(unix unixTime) { - s.rttMeasureTime = time.Unix(unix.second, unix.nano) -} - // afterLoad is invoked by stateify. func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 81e7dc36e..c58361bc1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -856,8 +856,8 @@ func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan return } - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT) return } |