-rw-r--r--   pkg/tcpip/transport/tcp/accept.go              7
-rw-r--r--   pkg/tcpip/transport/tcp/connect.go            30
-rw-r--r--   pkg/tcpip/transport/tcp/endpoint.go           30
-rw-r--r--   pkg/tcpip/transport/tcp/tcp_test.go           37
-rw-r--r--   test/packetimpact/tests/BUILD                  2
-rw-r--r--   test/syscalls/linux/socket_inet_loopback.cc   62
6 files changed, 138 insertions(+), 30 deletions(-)
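This change makes closing a TCP listening endpoint reset the connections still sitting in its accept queue instead of leaving them dangling, and factors the "drain queued segments and re-match them" logic into a drainClosingSegmentQueue helper that runs whenever an endpoint finishes closing. Below is a minimal sketch of the user-visible behaviour the new tests expect, written against the Go standard library rather than netstack (the address, deadline, and names are illustrative only): a connection that completed its handshake but was never accepted observes a reset once the listener is closed.

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}

	// Complete the handshake; the connection now sits in the listener's
	// accept queue because Accept is never called.
	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Closing the listener is expected to reset the queued connection.
	ln.Close()

	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	buf := make([]byte, 1)
	_, err = conn.Read(buf)
	fmt.Println("read after listener close:", err) // typically ECONNRESET
}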
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 7a9dea4ac..e07b436c4 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -330,6 +330,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
+
+ ep.drainClosingSegmentQueue()
+
return nil, err
}
ep.isConnectNotified = true
@@ -378,7 +381,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
for {
if e.acceptedChan == nil {
e.acceptMu.Unlock()
- n.Close()
+ n.notifyProtocolGoroutine(notifyReset)
return
}
select {
@@ -656,6 +659,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
}
e.mu.Unlock()
+ e.drainClosingSegmentQueue()
+
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
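In accept.go, deliverAccepted now aborts the freshly handshaked endpoint with notifyReset when the listener's accept channel has already been torn down, rather than closing it quietly, and the closing segment queue is drained both when the handshake fails and when the listen loop exits. The following is a rough, standalone sketch of that deliver-or-reset shape; the types and the reset method are invented stand-ins for the real endpoint and notifyProtocolGoroutine(notifyReset), not gVisor's implementation.

package main

import (
	"fmt"
	"sync"
)

type endpoint struct{ name string }

// reset stands in for notifyProtocolGoroutine(notifyReset): the connection is
// aborted with a RST so the peer learns it will never be accepted.
func (e *endpoint) reset() { fmt.Println(e.name, "aborted with RST") }

type listener struct {
	mu       sync.Mutex
	accepted chan *endpoint // becomes nil once the listener is closed
}

func (l *listener) deliver(e *endpoint) {
	l.mu.Lock()
	ch := l.accepted
	l.mu.Unlock()
	if ch == nil {
		// Listener already closed: abort instead of silently dropping.
		e.reset()
		return
	}
	select {
	case ch <- e: // hand off to Accept
	default:
		// Queue full; the real code loops and waits on a condition
		// variable, but resetting keeps this sketch small.
		e.reset()
	}
}

func main() {
	l := &listener{} // accepted == nil models an already-closed listener
	l.deliver(&endpoint{name: "conn-1"})
}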
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 2ca3fb809..994ac52a3 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1062,6 +1062,20 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
}
}
+// drainClosingSegmentQueue drains the endpoint's segment queue and tries to
+// re-match each segment to a different endpoint. It is used once the endpoint
+// has transitioned to StateClose and been unregistered from the demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
@@ -1315,6 +1329,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.mu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1565,19 +1582,6 @@ loop:
// Lock released below.
epilogue()
- // epilogue removes the endpoint from the transport-demuxer and
- // unlocks e.mu. Now that no new segments can get enqueued to this
- // endpoint, try to re-match the segment to a different endpoint
- // as the current endpoint is closed.
- for {
- s := e.segmentQueue.dequeue()
- if s == nil {
- break
- }
-
- e.tryDeliverSegmentFromClosedEndpoint(s)
- }
-
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a8d443f73..7ed78d57f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -980,25 +980,22 @@ func (e *endpoint) closeNoShutdownLocked() {
// Mark endpoint as closed.
e.closed = true
+
+ switch e.EndpointState() {
+ case StateClose, StateError:
+ return
+ }
+
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
- switch e.EndpointState() {
- // Sockets in StateSynRecv state(passive connections) are closed when
- // the handshake fails or if the listening socket is closed while
- // handshake was in progress. In such cases the handshake goroutine
- // is already gone by the time Close is called and we need to cleanup
- // here.
- case StateInitial, StateBound, StateSynRecv:
- e.cleanupLocked()
- e.setEndpointState(StateClose)
- case StateError, StateClose:
- // do nothing.
- default:
+ if e.workerRunning {
e.workerCleanup = true
tcpip.AddDanglingEndpoint(e)
// Worker will remove the dangling endpoint when the endpoint
// goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
+ } else {
+ e.transitionToStateCloseLocked()
}
}
@@ -1010,13 +1007,18 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Unlock()
return
}
-
close(e.acceptedChan)
+ ch := e.acceptedChan
e.acceptedChan = nil
e.acceptCond.Broadcast()
e.acceptMu.Unlock()
- // Wait for all pending endpoints to close.
+ // Reset all connections that are waiting to be accepted.
+ for n := range ch {
+ n.notifyProtocolGoroutine(notifyReset)
+ }
+ // Wait for all endpoints that are still being delivered to the
+ // now-closed acceptedChan to be reset.
e.pendingAccepted.Wait()
}
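The rewritten closePendingAcceptableConnectionsLocked closes acceptedChan, keeps a local reference to it, and then ranges over that reference to reset every connection still queued. That relies on a Go channel property worth spelling out: closing a buffered channel does not discard its contents, and a subsequent range still yields the values queued before the close and only then terminates. A tiny standalone illustration (not gVisor code):

package main

import "fmt"

func main() {
	ch := make(chan int, 3)
	ch <- 1
	ch <- 2
	close(ch)

	// range drains the values buffered before the close, then exits.
	for v := range ch {
		fmt.Println("still queued:", v) // prints 1, then 2
	}
}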
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 41caa9ed4..a9f121c17 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1068,6 +1068,43 @@ func TestListenShutdown(t *testing.T) {
c.CheckNoPacket("Packet received when listening socket was shutdown")
}
+// TestListenCloseWhileConnect tests that closing a listening endpoint drains
+// its accept queue, resetting every pending connection that is still waiting
+// to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created by the handshake to be delivered to
+ // the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 308590162..1274d9f60 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -43,8 +43,6 @@ packetimpact_go_test(
packetimpact_go_test(
name = "tcp_noaccept_close_rst",
srcs = ["tcp_noaccept_close_rst_test.go"],
- # TODO(b/153380909): Fix netstack then remove the line below.
- netstack = False,
deps = [
"//pkg/tcpip/header",
"//test/packetimpact/testbench",
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 71bd7c14d..cd84e633a 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -365,6 +365,68 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
}
}
+TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kClients = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ clients.push_back(std::move(client));
+ }
+ }
+ // Close the listening socket.
+ listen_fd.reset();
+
+ for (auto& client : clients) {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = client.get(),
+ .events = POLLIN,
+ };
+ // When the listening socket is closed, we expect the remote to reset the
+ // connection.
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ char c;
+ // A subsequent read can fail with either:
+ // ECONNRESET: the connection was established and then reset by the remote.
+ // ECONNREFUSED: the connection failed to be established at all.
+ ASSERT_THAT(read(client.get(), &c, sizeof(c)),
+ AnyOf(SyscallFailsWithErrno(ECONNRESET),
+ SyscallFailsWithErrno(ECONNREFUSED)));
+ }
+}
+
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
auto const& param = GetParam();