summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go24
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go7
-rw-r--r--test/syscalls/linux/tcp_socket.cc48
3 files changed, 66 insertions, 13 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 89154391b..cc49c8272 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -489,8 +489,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result |= waiter.EventIn
}
}
-
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ }
+ if e.state.connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -1323,6 +1323,17 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
return err
}
+ if e.state.connected() {
+ // The endpoint is already connected. If caller hasn't been
+ // notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
+ return tcpip.ErrAlreadyConnected
+ }
+
nicid := addr.NIC
switch e.state {
case StateBound:
@@ -1347,15 +1358,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// yet.
return tcpip.ErrAlreadyConnecting
- case StateEstablished:
- // The endpoint is already connected. If caller hasn't been notified yet, return success.
- if !e.isConnectNotified {
- e.isConnectNotified = true
- return nil
- }
- // Otherwise return that it's already connected.
- return tcpip.ErrAlreadyConnected
-
case StateError:
return e.hardError
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b93959034..b3f0f6c5d 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -50,6 +50,8 @@ func (e *endpoint) beforeSave() {
switch e.state {
case StateInitial, StateBound:
+ // TODO(b/138137272): this enumeration duplicates
+ // EndpointState.connected. remove it.
case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
@@ -149,9 +151,10 @@ var connectingLoading sync.WaitGroup
func (e *endpoint) loadState(state EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ if state.connected() {
connectedLoading.Add(1)
+ }
+ switch state {
case StateListen:
listenLoading.Add(1)
case StateConnecting, StateSynSent, StateSynRecv:
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 77aab1e7d..8f4d3f386 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -856,6 +856,54 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnect) {
EXPECT_THAT(close(t), SyscallSucceeds());
}
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) {
+ const FileDescriptor listener =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(
+ bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+
+ ASSERT_THAT(getsockname(listener.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ int t;
+ ASSERT_THAT(t = RetryEINTR(accept)(listener.get(), nullptr, nullptr),
+ SyscallSucceeds());
+
+ EXPECT_THAT(close(t), SyscallSucceeds());
+
+ // Now polling on the FD with a timeout should return 0 corresponding to no
+ // FDs ready.
+ struct pollfd poll_fd = {s.get(), POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EISCONN));
+}
+
// Test that we get an ECONNREFUSED with a blocking socket when no one is
// listening on the other end.
TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) {