diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 7 | ||||
-rw-r--r-- | test/syscalls/linux/tcp_socket.cc | 48 |
3 files changed, 66 insertions, 13 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 89154391b..cc49c8272 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -489,8 +489,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result |= waiter.EventIn } } - - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + } + if e.state.connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -1323,6 +1323,17 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return err } + if e.state.connected() { + // The endpoint is already connected. If caller hasn't been + // notified yet, return success. + if !e.isConnectNotified { + e.isConnectNotified = true + return nil + } + // Otherwise return that it's already connected. + return tcpip.ErrAlreadyConnected + } + nicid := addr.NIC switch e.state { case StateBound: @@ -1347,15 +1358,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // yet. return tcpip.ErrAlreadyConnecting - case StateEstablished: - // The endpoint is already connected. If caller hasn't been notified yet, return success. - if !e.isConnectNotified { - e.isConnectNotified = true - return nil - } - // Otherwise return that it's already connected. - return tcpip.ErrAlreadyConnected - case StateError: return e.hardError diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b93959034..b3f0f6c5d 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -50,6 +50,8 @@ func (e *endpoint) beforeSave() { switch e.state { case StateInitial, StateBound: + // TODO(b/138137272): this enumeration duplicates + // EndpointState.connected. remove it. case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { @@ -149,9 +151,10 @@ var connectingLoading sync.WaitGroup func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + if state.connected() { connectedLoading.Add(1) + } + switch state { case StateListen: listenLoading.Add(1) case StateConnecting, StateSynSent, StateSynRecv: diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 77aab1e7d..8f4d3f386 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -856,6 +856,54 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnect) { EXPECT_THAT(close(t), SyscallSucceeds()); } +TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) { + const FileDescriptor listener = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT( + bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + + ASSERT_THAT(getsockname(listener.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + ASSERT_THAT(RetryEINTR(connect)( + s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EINPROGRESS)); + + int t; + ASSERT_THAT(t = RetryEINTR(accept)(listener.get(), nullptr, nullptr), + SyscallSucceeds()); + + EXPECT_THAT(close(t), SyscallSucceeds()); + + // Now polling on the FD with a timeout should return 0 corresponding to no + // FDs ready. + struct pollfd poll_fd = {s.get(), POLLOUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + + ASSERT_THAT(RetryEINTR(connect)( + s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(RetryEINTR(connect)( + s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); +} + // Test that we get an ECONNREFUSED with a blocking socket when no one is // listening on the other end. TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) { |