diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 80 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect_unsafe.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/dispatcher.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 139 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_heap.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 195 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 7 |
12 files changed, 477 insertions, 70 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4acd9fb9a..a32f9eacf 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -32,6 +32,7 @@ go_library( srcs = [ "accept.go", "connect.go", + "connect_unsafe.go", "cubic.go", "cubic_state.go", "dispatcher.go", @@ -57,6 +58,7 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", @@ -90,6 +92,7 @@ go_test( tags = ["flaky"], deps = [ ":tcp", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d469758eb..85049e54e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6only n.ID = s.id n.boundNICID = s.route.NICID() @@ -236,6 +236,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) n.amss = mssForRoute(&n.route) + n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) @@ -271,18 +272,19 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n, nil } -// createEndpoint creates a new endpoint in connected state and then performs -// the TCP 3-way handshake. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) if err != nil { return nil, err } // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { @@ -290,15 +292,21 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - - h.resetToSynRcvd(isn, irs, opts) + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -377,16 +385,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer e.decSynRcvdCount() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(n) @@ -546,7 +552,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -576,8 +582,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // space available in the backlog. // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } @@ -610,7 +615,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4e3c5419c..c0f73ef16 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -86,6 +86,19 @@ type handshake struct { // rcvWndScale is the receive window scale, as defined in RFC 1323. rcvWndScale int + + // startTime is the time at which the first SYN/SYN-ACK was sent. + startTime time.Time + + // deferAccept if non-zero will drop the final ACK for a passive + // handshake till an ACK segment with data is received or the timeout is + // hit. + deferAccept time.Duration + + // acked is true if the the final ACK for a 3-way handshake has + // been received. This is required to stop retransmitting the + // original SYN-ACK when deferAccept is enabled. + acked bool } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { @@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { return h } +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) + h.resetToSynRcvd(isn, irs, opts, deferAccept) + return h +} + // FindWndScale determines the window scale to use for the given maximum window // size. func FindWndScale(wnd seqnum.Size) int { @@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS + h.deferAccept = deferAccept h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() @@ -275,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() ttl := h.ep.ttl + amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ @@ -287,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // permits SACK. This is not explicitly defined in the RFC but // this is the behaviour implemented by Linux. SACKPermitted: rcvSynOpts.SACKPermitted, - MSS: h.ep.amss, + MSS: amss, } if ttl == 0 { ttl = s.route.DefaultTTL() @@ -336,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + h.ep.mu.RLock() + amss := h.ep.amss + h.ep.mu.RUnlock() + h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -343,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: h.ep.amss, + MSS: amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -352,6 +377,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. + if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. @@ -365,10 +398,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } h.ep.mu.Unlock() - return nil } @@ -471,6 +510,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.startTime = time.Now() // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) @@ -495,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. + h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -505,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } + h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the @@ -524,15 +566,25 @@ func (h *handshake) execute() *tcpip.Error { switch index, _ := s.Fetch(true); index { case wakerForResend: timeOut *= 2 - if timeOut > 60*time.Second { + if timeOut > MaxRTO { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // Resend the SYN/SYN-ACK only if the following conditions hold. + // - It's an active handshake (deferAccept does not apply) + // - It's a passive handshake and we have not yet got the final-ACK. + // - It's a passive handshake and we got an ACK but deferAccept is + // enabled and we are now past the deferAccept duration. + // The last is required to provide a way for the peer to complete + // the connection with another ACK or data (as ACKs are never + // retransmitted on their own). + if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + } case wakerForNotification: n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { + if (n¬ifyClose)|(n¬ifyAbort) != 0 { return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -572,17 +624,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions { var optionPool = sync.Pool{ New: func() interface{} { - return make([]byte, maxOptionSize) + return &[maxOptionSize]byte{} }, } func getOptions() []byte { - return optionPool.Get().([]byte) + return (*optionPool.Get().(*[maxOptionSize]byte))[:] } func putOptions(options []byte) { // Reslice to full capacity. - optionPool.Put(options[0:cap(options)]) + optionPool.Put(optionsToArray(options)) } func makeSynOptions(opts header.TCPSynOptions) []byte { @@ -944,6 +996,10 @@ func (e *endpoint) transitionToStateCloseLocked() { // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + // Dual-stack socket, try IPv4. + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + } if ep == nil { replyWithReset(s) s.decRef() @@ -1323,7 +1379,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1606,7 +1662,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go new file mode 100644 index 000000000..cfc304616 --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect_unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "reflect" + "unsafe" +) + +// optionsToArray converts a slice of capacity >-= maxOptionSize to an array. +// +// optionsToArray panics if the capacity of options is smaller than +// maxOptionSize. +func optionsToArray(options []byte) *[maxOptionSize]byte { + // Reslice to full capacity. + options = options[0:maxOptionSize] + return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data)) +} diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 13718ff55..dc9c18b6f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -498,6 +500,13 @@ type endpoint struct { // without any data being acked. userTimeout time.Duration + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -778,6 +787,38 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -822,9 +863,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } @@ -909,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { + e.mu.RLock() e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() + e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() + e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -958,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvBufSize = rcvWnd availAfter := e.receiveBufferAvailableLocked() mask := uint32(notifyReceiveWindowChanged) - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.notifyProtocolGoroutine(mask) @@ -973,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() + e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -996,13 +1050,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + e.stats.ReadErrors.NotConnected.Increment() + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected } v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() if err == tcpip.ErrClosedForReceive { @@ -1035,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1253,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// windowCrossedACKThreshold checks if the receive window to be announced now -// would be under aMSS or under half receive buffer, whichever smaller. This is -// useful as a receive side silly window syndrome prevention mechanism. If +// windowCrossedACKThresholdLocked checks if the receive window to be announced +// now would be under aMSS or under half receive buffer, whichever smaller. This +// is useful as a receive side silly window syndrome prevention mechanism. If // window grows to reasonable value, we should send ACK to the sender to inform // the rx space is now large. We also want ensure a series of small read()'s // won't trigger a flood of spurious tiny ACK's. @@ -1266,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // crossed will be true if the window size crossed the ACK threshold. // above will be true if the new window is >= ACK threshold and false // otherwise. -func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) { +// +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { newAvail := e.receiveBufferAvailableLocked() oldAvail := newAvail - deltaBefore if oldAvail < 0 { @@ -1329,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) + e.mu.RLock() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1355,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Immediately send an ACK to uncork the sender silly window // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.notifyProtocolGoroutine(mask) return nil @@ -1574,6 +1630,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPDeferAcceptOption: + e.mu.Lock() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1798,18 +1863,25 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.TCPDeferAcceptOption: + e.mu.Lock() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1839,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2025,8 +2097,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // work mutex is available. if e.workMu.TryLock() { e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) + // We need to double check here to make + // sure worker has not transitioned the + // endpoint out of a connected state + // before trying to send a reset. + if e.EndpointState().connected() { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.notifyProtocolGoroutine(notifyTickleWorker) + } e.mu.Unlock() e.workMu.Unlock() } else { @@ -2039,10 +2117,13 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Close for write. if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { e.sndBufMu.Lock() - if e.sndClosed { // Already closed. e.sndBufMu.Unlock() + if e.EndpointState() == StateTimeWait { + e.mu.Unlock() + return tcpip.ErrNotConnected + } break } @@ -2138,6 +2219,9 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { e.isRegistered = true e.setEndpointState(StateListen) + // The channel may be non-nil when we're restoring the endpoint, and it + // may be pre-populated with some previously accepted (but not Accepted) + // endpoints. if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } @@ -2149,9 +2233,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { +func (e *endpoint) startAcceptedLoop() { e.mu.Lock() - e.waiterQueue = waiterQueue e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2177,7 +2260,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } - return n, n.waiterQueue, nil } @@ -2198,7 +2280,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2342,13 +2424,14 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { + e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() // Increase counter if the receive window falls down below MSS // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above { + if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } e.rcvList.PushBack(s) @@ -2356,7 +2439,7 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 7eb613be5..c9ee5bf06 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }) + }, queue) if err != nil { return nil, err } // Start the protocol goroutine. - ep.startAcceptedLoop(queue) + ep.startAcceptedLoop() return ep, nil } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 958f03ac1..d80aff1b6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -195,6 +195,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum for i := first; i < len(r.pendingRcvdSegments); i++ { r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of + // truncated items, thus truncated items must be set to nil to avoid + // memory leaks. + r.pendingRcvdSegments[i] = nil } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go index 9fd061d7d..e28f213ba 100644 --- a/pkg/tcpip/transport/tcp/segment_heap.go +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -41,6 +41,7 @@ func (h *segmentHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = nil *h = old[:n-1] return x } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index df2fb1071..5b2b16afa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -542,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { ), ) - // Wait for the TIME-WAIT state to transition to CLOSED. - time.Sleep(1 * time.Second) + // Wait for a little more than the TIME-WAIT duration for the socket to + // transition to CLOSED state. + time.Sleep(1200 * time.Millisecond) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) @@ -5404,12 +5406,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } - // Expect InvalidEndpointState errors on a read at this point. - if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { + t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected) } - if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { - t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { + t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1) } if err := ep.Listen(10); err != nil { @@ -6787,3 +6788,183 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { ), ) } + +func TestTCPDeferAccept(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give a bit of time for the socket to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestTCPDeferAcceptTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) + + // On timeout expiry we should get a SYN-ACK retransmission. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestResetDuringClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + iss := seqnum.Value(789) + c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + // Send some data to make sure there is some unread + // data to trigger a reset on c.Close. + irs := c.IRS + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: irs.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(irs.Add(1))), + checker.AckNum(uint32(iss.Add(5))))) + + // Close in a separate goroutine so that we can trigger + // a race with the RST we send below. This should not + // panic due to the route being released depeding on + // whether Close() sends an active RST or the RST sent + // below is processed by the worker first. + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(5), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() + + wg.Wait() +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 730ac4292..8cea20fb5 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -204,6 +204,7 @@ func (c *Context) Cleanup() { if c.EP != nil { c.EP.Close() } + c.Stack().Close() } // Stack returns a reference to the stack in the Context. @@ -1082,7 +1083,11 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { - c.linkEP.GSO = enable + if enable { + c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + } else { + c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + } } // MSSWithoutOptions returns the value for the MSS used by the stack when no |