summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMithun Iyer <iyerm@google.com>2020-04-16 17:57:06 -0700
committergVisor bot <gvisor-bot@google.com>2020-04-16 17:58:08 -0700
commit3b05f576d73be644daa17203d9ed64481c45b4a8 (patch)
treed856cc675e646fbab32be15c2c3f32eaa48f27bc
parentb33c3bb4a73974bbae4274da5100a3cd3f5deef8 (diff)
Reset pending connections on listener shutdown.
When the listening socket is read shutdown, we need to reset all pending and incoming connections. Ensure that the endpoint is not cleaned up from the demuxer and subsequent bind to same port does not go through. PiperOrigin-RevId: 306958038
-rw-r--r--pkg/tcpip/transport/tcp/accept.go20
-rw-r--r--pkg/tcpip/transport/tcp/connect.go7
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go16
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc94
7 files changed, 138 insertions, 39 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index b61c2a8c3..5bb243e3b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -433,19 +432,16 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
- if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
+ if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ // If the endpoint is shutdown, reply with reset.
+ //
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- e.sendTCP(&s.route, tcpFields{
- id: s.id,
- ttl: e.ttl,
- tos: e.sendTOS,
- flags: header.TCPFlagRst,
- seq: s.ackNumber,
- ack: 0,
- rcvWnd: 0,
- }, buffer.VectorisedView{}, nil)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
@@ -534,7 +530,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 994ac52a3..368865911 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1053,10 +1053,15 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
}
if ep == nil {
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
s.decRef()
return
}
+
+ if e == ep {
+ panic("current endpoint not removed from demuxer, enqueing segments to itself")
+ }
+
if ep.(*endpoint).enqueueSegment(s) {
ep.(*endpoint).newSegmentWaker.Assert()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index bffc59e9f..5d0ea9e93 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2101,7 +2101,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
switch {
case e.EndpointState().connected():
// Close for read.
- if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
// Mark read side as closed.
e.rcvListMu.Lock()
e.rcvClosed = true
@@ -2110,7 +2110,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
- if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
// Wake up worker to terminate loop.
e.notifyProtocolGoroutine(notifyTickleWorker)
@@ -2119,7 +2119,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Close for write.
- if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
e.sndBufMu.Lock()
if e.sndClosed {
// Already closed.
@@ -2142,12 +2142,23 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
return nil
case e.EndpointState() == StateListen:
- // Tell protocolListenLoop to stop.
- if flags&tcpip.ShutdownRead != 0 {
- e.notifyProtocolGoroutine(notifyClose)
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Reset all connections from the accept queue and keep the
+ // worker running so that it can continue handling incoming
+ // segments by replying with RST.
+ //
+ // By not removing this endpoint from the demuxer mapping, we
+ // ensure that any other bind to the same port fails, as on Linux.
+ // TODO(gvisor.dev/issue/2468): We need to enable applications to
+ // start listening on this endpoint again similar to Linux.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ e.rcvListMu.Unlock()
+ e.closePendingAcceptableConnectionsLocked()
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}
return nil
-
default:
return tcpip.ErrNotConnected
}
@@ -2251,8 +2262,11 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
- if e.EndpointState() != StateListen {
+ if rcvClosed || e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 808410c92..704d01c64 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -130,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
// If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment)
+ replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
}
// Release all resources.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index effbf203f..cfd9a4e8e 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -223,12 +223,12 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
return true
}
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
return true
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment) {
+func replyWithReset(s *segment, tos, ttl uint8) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -252,8 +252,8 @@ func replyWithReset(s *segment) {
}
sendTCP(&s.route, tcpFields{
id: s.id,
- ttl: s.route.DefaultTTL(),
- tos: stack.DefaultTOS,
+ ttl: ttl,
+ tos: tos,
flags: flags,
seq: seq,
ack: ack,
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 74fb6e064..ab1014c7f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1034,8 +1034,8 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.SeqNum(200)))
}
-// TestListenShutdown tests for the listening endpoint not processing
-// any receive when it is on read shutdown.
+// TestListenShutdown tests for the listening endpoint replying with RST
+// on read shutdown.
func TestListenShutdown(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -1046,7 +1046,7 @@ func TestListenShutdown(t *testing.T) {
t.Fatal("Bind failed:", err)
}
- if err := c.EP.Listen(10 /* backlog */); err != nil {
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
t.Fatal("Listen failed:", err)
}
@@ -1054,9 +1054,6 @@ func TestListenShutdown(t *testing.T) {
t.Fatal("Shutdown failed:", err)
}
- // Wait for the endpoint state to be propagated.
- time.Sleep(10 * time.Millisecond)
-
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -1065,7 +1062,12 @@ func TestListenShutdown(t *testing.T) {
AckNum: 200,
})
- c.CheckNoPacket("Packet received when listening socket was shutdown")
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
}
// TestListenCloseWhileConnect tests for the listening endpoint to
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index cd84e633a..d3000dbc6 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -319,6 +319,75 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) {
tcpSimpleConnectTest(listener, connector, false);
}
+TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kFDs = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // Shutdown the write of the listener, expect to not have any effect.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds());
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
+ }
+
+ // Shutdown the read of the listener, expect to fail subsequent
+ // server accepts, binds and client connects.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that shutdown did not release the port.
+ FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Check that subsequent connection attempts receive a RST.
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ }
+}
+
TEST_P(SocketInetLoopbackTest, TCPListenClose) {
auto const& param = GetParam();
@@ -365,9 +434,8 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
}
}
-TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
- auto const& param = GetParam();
-
+void TestListenWhileConnect(const TestParam& param,
+ void (*stopListen)(FileDescriptor&)) {
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -404,8 +472,8 @@ TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
clients.push_back(std::move(client));
}
}
- // Close the listening socket.
- listen_fd.reset();
+
+ stopListen(listen_fd);
for (auto& client : clients) {
const int kTimeout = 10000;
@@ -420,13 +488,26 @@ TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
char c;
// Subsequent read can fail with:
// ECONNRESET: If the client connection was established and was reset by the
- // remote. ECONNREFUSED: If the client connection failed to be established.
+ // remote.
+ // ECONNREFUSED: If the client connection failed to be established.
ASSERT_THAT(read(client.get(), &c, sizeof(c)),
AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(ECONNREFUSED)));
}
}
+TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
+}
+
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
auto const& param = GetParam();
@@ -1134,6 +1215,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
if (connects_received >= kConnectAttempts) {
// Another thread have shutdown our read side causing the
// accept to fail.
+ ASSERT_EQ(errno, EINVAL);
break;
}
ASSERT_NO_ERRNO(fd);