diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 55 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 1 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_tcp_generic.cc | 33 |
5 files changed, 98 insertions, 2 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7b4a87a2d..272e8f570 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -91,6 +91,7 @@ go_test( tags = ["flaky"], deps = [ ":tcp", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8d52414b7..b5a8e15ee 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2047,8 +2047,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // work mutex is available. if e.workMu.TryLock() { e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) + // We need to double check here to make + // sure worker has not transitioned the + // endpoint out of a connected state + // before trying to send a reset. + if e.EndpointState().connected() { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.notifyProtocolGoroutine(notifyTickleWorker) + } e.mu.Unlock() e.workMu.Unlock() } else { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a12336d47..2c1505067 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -6913,3 +6914,57 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SeqNum(uint32(iss+1)), checker.AckNum(uint32(irs+5)))) } + +func TestResetDuringClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + iss := seqnum.Value(789) + c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + // Send some data to make sure there is some unread + // data to trigger a reset on c.Close. + irs := c.IRS + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: irs.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(irs.Add(1))), + checker.AckNum(uint32(iss.Add(5))))) + + // Close in a separate goroutine so that we can trigger + // a race with the RST we send below. This should not + // panic due to the route being released depeding on + // whether Close() sends an active RST or the RST sent + // below is processed by the worker first. + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(5), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() + + wg.Wait() +} diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 74bf068ec..7958fd0d7 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2173,6 +2173,7 @@ cc_library( ":socket_test_util", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/memory", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 57ce8e169..27779e47c 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,6 +24,7 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" @@ -875,5 +876,37 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { EXPECT_EQ(get, kAbove); } +TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { + DisableSave ds; // Too many syscalls. + constexpr int kThreadCount = 1000; + std::unique_ptr<ScopedThread> instances[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + instances[i] = absl::make_unique<ScopedThread>([&]() { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ScopedThread t([&]() { + // Close one end to trigger sending of a FIN. + struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + }); + + // Send some data then close. + constexpr char kStr[] = "abc"; + ASSERT_THAT(write(sockets->first_fd(), kStr, 3), + SyscallSucceedsWithValue(3)); + absl::SleepFor(absl::Milliseconds(10)); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + t.Join(); + }); + } + for (int i = 0; i < kThreadCount; i++) { + instances[i]->Join(); + } +} + } // namespace testing } // namespace gvisor |