summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go5
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc43
3 files changed, 71 insertions, 20 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 45f2aa78b..07d3e64c8 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2158,8 +2158,6 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
//
// By not removing this endpoint from the demuxer mapping, we
// ensure that any other bind to the same port fails, as on Linux.
- // TODO(gvisor.dev/issue/2468): We need to enable applications to
- // start listening on this endpoint again similar to Linux.
e.rcvListMu.Lock()
e.rcvClosed = true
e.rcvListMu.Unlock()
@@ -2188,26 +2186,31 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
e.LockUser()
defer e.UnlockUser()
- // Allow the backlog to be adjusted if the endpoint is not shutting down.
- // When the endpoint shuts down, it sets workerCleanup to true, and from
- // that point onward, acceptedChan is the responsibility of the cleanup()
- // method (and should not be touched anywhere else, including here).
- if e.EndpointState() == StateListen && !e.workerCleanup {
- // Adjust the size of the channel iff we can fix existing
- // pending connections into the new one.
+ if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if len(e.acceptedChan) > backlog {
- return tcpip.ErrInvalidEndpointState
- }
- if cap(e.acceptedChan) == backlog {
- return nil
- }
- origChan := e.acceptedChan
- e.acceptedChan = make(chan *endpoint, backlog)
- close(origChan)
- for ep := range origChan {
- e.acceptedChan <- ep
+ if e.acceptedChan == nil {
+ // listen is called after shutdown.
+ e.acceptedChan = make(chan *endpoint, backlog)
+ e.shutdownFlags = 0
+ e.rcvListMu.Lock()
+ e.rcvClosed = false
+ e.rcvListMu.Unlock()
+ } else {
+ // Adjust the size of the channel iff we can fix
+ // existing pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
}
// Notify any blocked goroutines that they can attempt to
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index c3c692555..8b7562396 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -247,6 +247,11 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
+ e.LockUser()
+ if e.shutdownFlags != 0 {
+ e.shutdownLocked(e.shutdownFlags)
+ }
+ e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index d3000dbc6..9400ffaeb 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -319,6 +319,49 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) {
tcpSimpleConnectTest(listener, connector, false);
}
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
+ const auto& param = GetParam();
+
+ const TestAddress& listener = param.listener;
+ const TestAddress& connector = param.connector;
+
+ constexpr int kBacklog = 5;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ for (int i = 0; i < kBacklog; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ }
+ for (int i = 0; i < kBacklog; i++) {
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
+ }
+}
+
TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
auto const& param = GetParam();