diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 37 | ||||
-rw-r--r-- | test/packetimpact/tests/BUILD | 2 | ||||
-rw-r--r-- | test/syscalls/linux/socket_inet_loopback.cc | 62 |
6 files changed, 138 insertions, 30 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7a9dea4ac..e07b436c4 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -330,6 +330,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.removePendingEndpoint(ep) } + + ep.drainClosingSegmentQueue() + return nil, err } ep.isConnectNotified = true @@ -378,7 +381,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { for { if e.acceptedChan == nil { e.acceptMu.Unlock() - n.Close() + n.notifyProtocolGoroutine(notifyReset) return } select { @@ -656,6 +659,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { } e.mu.Unlock() + e.drainClosingSegmentQueue() + // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 2ca3fb809..994ac52a3 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1062,6 +1062,20 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { } } +// Drain segment queue from the endpoint and try to re-match the segment to a +// different endpoint. This is used when the current endpoint is transitioned to +// StateClose and has been unregistered from the transport demuxer. +func (e *endpoint) drainClosingSegmentQueue() { + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } +} + func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states @@ -1315,6 +1329,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.mu.Unlock() + + e.drainClosingSegmentQueue() + // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1565,19 +1582,6 @@ loop: // Lock released below. epilogue() - // epilogue removes the endpoint from the transport-demuxer and - // unlocks e.mu. Now that no new segments can get enqueued to this - // endpoint, try to re-match the segment to a different endpoint - // as the current endpoint is closed. - for { - s := e.segmentQueue.dequeue() - if s == nil { - break - } - - e.tryDeliverSegmentFromClosedEndpoint(s) - } - // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a8d443f73..7ed78d57f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -980,25 +980,22 @@ func (e *endpoint) closeNoShutdownLocked() { // Mark endpoint as closed. e.closed = true + + switch e.EndpointState() { + case StateClose, StateError: + return + } + // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. - switch e.EndpointState() { - // Sockets in StateSynRecv state(passive connections) are closed when - // the handshake fails or if the listening socket is closed while - // handshake was in progress. In such cases the handshake goroutine - // is already gone by the time Close is called and we need to cleanup - // here. - case StateInitial, StateBound, StateSynRecv: - e.cleanupLocked() - e.setEndpointState(StateClose) - case StateError, StateClose: - // do nothing. - default: + if e.workerRunning { e.workerCleanup = true tcpip.AddDanglingEndpoint(e) // Worker will remove the dangling endpoint when the endpoint // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) + } else { + e.transitionToStateCloseLocked() } } @@ -1010,13 +1007,18 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Unlock() return } - close(e.acceptedChan) + ch := e.acceptedChan e.acceptedChan = nil e.acceptCond.Broadcast() e.acceptMu.Unlock() - // Wait for all pending endpoints to close. + // Reset all connections that are waiting to be accepted. + for n := range ch { + n.notifyProtocolGoroutine(notifyReset) + } + // Wait for reset of all endpoints that are still waiting to be delivered to + // the now closed acceptedChan. e.pendingAccepted.Wait() } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 41caa9ed4..a9f121c17 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1068,6 +1068,43 @@ func TestListenShutdown(t *testing.T) { c.CheckNoPacket("Packet received when listening socket was shutdown") } +// TestListenCloseWhileConnect tests for the listening endpoint to +// drain the accept-queue when closed. This should reset all of the +// pending connections that are waiting to be accepted. +func TestListenCloseWhileConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventIn) + defer c.WQ.EventUnregister(&waitEntry) + + executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + // Wait for the new endpoint created because of handshake to be delivered + // to the listening endpoint's accept queue. + <-notifyCh + + // Close the listening endpoint. + c.EP.Close() + + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 308590162..1274d9f60 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -43,8 +43,6 @@ packetimpact_go_test( packetimpact_go_test( name = "tcp_noaccept_close_rst", srcs = ["tcp_noaccept_close_rst_test.go"], - # TODO(b/153380909): Fix netstack then remove the line below. - netstack = False, deps = [ "//pkg/tcpip/header", "//test/packetimpact/testbench", diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 71bd7c14d..cd84e633a 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -365,6 +365,68 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kClients = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + clients.push_back(std::move(client)); + } + } + // Close the listening socket. + listen_fd.reset(); + + for (auto& client : clients) { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = client.get(), + .events = POLLIN, + }; + // When the listening socket is closed, then we expect the remote to reset + // the connection. + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + char c; + // Subsequent read can fail with: + // ECONNRESET: If the client connection was established and was reset by the + // remote. ECONNREFUSED: If the client connection failed to be established. + ASSERT_THAT(read(client.get(), &c, sizeof(c)), + AnyOf(SyscallFailsWithErrno(ECONNRESET), + SyscallFailsWithErrno(ECONNREFUSED))); + } +} + TEST_P(SocketInetLoopbackTest, TCPbacklog) { auto const& param = GetParam(); |