summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/accept.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go39
1 files changed, 22 insertions, 17 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d469758eb..85049e54e 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
- n := newEndpoint(l.stack, netProto, nil)
+ n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6only
n.ID = s.id
n.boundNICID = s.route.NICID()
@@ -236,6 +236,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = mssForRoute(&n.route)
+ n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
@@ -271,18 +272,19 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n, nil
}
-// createEndpoint creates a new endpoint in connected state and then performs
-// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
if err != nil {
return nil, err
}
// listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
@@ -290,15 +292,21 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+ deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
// Perform the 3-way handshake.
- h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
-
- h.resetToSynRcvd(isn, irs, opts)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
@@ -377,16 +385,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer e.decSynRcvdCount()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(n)
@@ -546,7 +552,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -576,8 +582,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// space available in the backlog.
// Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
@@ -610,7 +615,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}