summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/accept.go89
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go86
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go49
-rw-r--r--test/syscalls/BUILD4
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc61
5 files changed, 261 insertions, 28 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 52fd1bfa3..e9c5099ea 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -96,6 +96,17 @@ type listenContext struct {
hasher hash.Hash
v6only bool
netProto tcpip.NetworkProtocolNumber
+ // pendingMu protects pendingEndpoints. This should only be accessed
+ // by the listening endpoint's worker goroutine.
+ //
+ // Lock Ordering: listenEP.workerMu -> pendingMu
+ pendingMu sync.Mutex
+ // pending is used to wait for all pendingEndpoints to finish when
+ // a socket is closed.
+ pending sync.WaitGroup
+ // pendingEndpoints is a map of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -133,14 +144,15 @@ func decSynRcvdCount() {
}
// newListenContext creates a new listen context.
-func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stack,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6only: v6only,
- netProto: netProto,
- listenEP: listenEP,
+ stack: stk,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
rand.Read(l.nonce[0][:])
@@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
+ // listenEP is nil when listenContext is used by tcp.Forwarder.
+ if l.listenEP != nil {
+ l.listenEP.mu.Lock()
+ if l.listenEP.state != StateListen {
+ l.listenEP.mu.Unlock()
+ return nil, tcpip.ErrConnectionAborted
+ }
+ l.addPendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
@@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
return nil, err
}
ep.mu.Lock()
@@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return ep, nil
}
+func (l *listenContext) addPendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ l.pendingEndpoints[n.id] = n
+ l.pending.Add(1)
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) removePendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ delete(l.pendingEndpoints, n.id)
+ l.pending.Done()
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) closeAllPendingEndpoints() {
+ l.pendingMu.Lock()
+ for _, n := range l.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ l.pendingMu.Unlock()
+ l.pending.Wait()
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is closed
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
- e.mu.RLock()
+ e.mu.Lock()
state := e.state
- e.mu.RUnlock()
+ e.pendingAccepted.Add(1)
+ defer e.pendingAccepted.Done()
+ acceptedChan := e.acceptedChan
+ e.mu.Unlock()
if state == StateListen {
- e.acceptedChan <- n
+ acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
n.Close()
@@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
return
}
-
+ ctx.removePendingEndpoint(n)
e.deliverAccepted(n)
}
@@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
@@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
e.state = StateClose
+ // close any endpoints in SYN-RCVD state.
+ ctx.closeAllPendingEndpoints()
+
// Do cleanup if needed.
e.completeWorkerLocked()
@@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
}()
- e.mu.Lock()
- v6only := e.v6only
- e.mu.Unlock()
-
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
-
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
@@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.handleListenSegment(ctx, s)
s.decRef()
}
- synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index d9f79e8c5..c54610a87 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) {
// Test acceptance.
testV4Accept(t, c)
}
+
+func testV4ListenClose(t *testing.T, c *context.Context) {
+ // Set the SynRcvd threshold to zero to force a syn cookie based accept
+ // to happen.
+ saved := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = saved
+ }()
+ tcp.SynRcvdCountThreshold = 0
+ const n = uint16(32)
+
+ // Start listening.
+ if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ irs := seqnum.Value(789)
+ for i := uint16(0); i < n; i++ {
+ // Send a SYN request.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + i,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Each of these ACK's will cause a syn-cookie based connection to be
+ // accepted and delivered to the listening endpoint.
+ for i := uint16(0); i < n; i++ {
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ nep.Close()
+ c.EP.Close()
+}
+
+func TestV4ListenCloseOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4ListenClose(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 353e2efaf..0e16877e7 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -362,6 +362,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // pendingAccepted is a synchronization primitive used to track number
+ // of connections that are queued up to be delivered to the accepted
+ // channel. We use this to ensure that all goroutines blocked on writing
+ // to the acceptedChan below terminate before we close acceptedChan.
+ pendingAccepted sync.WaitGroup `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -375,7 +381,11 @@ type endpoint struct {
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
- // The goroutine undrain notification channel.
+ // The goroutine undrain notification channel. This is currently used as
+ // a way to block the worker goroutines. Today nothing closes/writes
+ // this channel and this causes any goroutines waiting on this to just
+ // block. This is used during save/restore to prevent worker goroutines
+ // from mutating state as it's being saved.
undrain chan struct{} `state:"nosave"`
// probe if not nil is invoked on every received segment. It is passed
@@ -575,6 +585,34 @@ func (e *endpoint) Close() {
e.mu.Unlock()
}
+// closePendingAcceptableConnections closes all connections that have completed
+// handshake but not yet been delivered to the application.
+func (e *endpoint) closePendingAcceptableConnectionsLocked() {
+ done := make(chan struct{})
+ // Spin a goroutine up as ranging on e.acceptedChan will just block when
+ // there are no more connections in the channel. Using a non-blocking
+ // select does not work as it can potentially select the default case
+ // even when there are pending writes but that are not yet written to
+ // the channel.
+ go func() {
+ defer close(done)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ }()
+ // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
+ // endpoints which have completed handshake but are not yet written to
+ // the e.acceptedChan. We wait here till the goroutine above can drain
+ // all such connections from e.acceptedChan.
+ e.pendingAccepted.Wait()
+ close(e.acceptedChan)
+ <-done
+ e.acceptedChan = nil
+}
+
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
@@ -582,14 +620,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
- close(e.acceptedChan)
- for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
- n.Close()
- }
- e.acceptedChan = nil
+ e.closePendingAcceptableConnectionsLocked()
}
e.workerCleanup = false
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index a3e43cad2..aa1e33fb4 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -440,7 +440,9 @@ syscall_test(
)
syscall_test(
- size = "medium",
+ size = "large",
+ parallel = False,
+ shard_count = 10,
test = "//test/syscalls/linux:socket_inet_loopback_test",
)
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index df31d25b5..322ee07ad 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -145,6 +145,67 @@ TEST_P(SocketInetLoopbackTest, TCP) {
ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
}
+TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ DisableSave ds; // Too many system calls.
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ constexpr int kFDs = 2048;
+ constexpr int kThreadCount = 4;
+ constexpr int kFDsPerThread = kFDs / kThreadCount;
+ FileDescriptor clients[kFDs];
+ std::unique_ptr<ScopedThread> threads[kThreadCount];
+ for (int i = 0; i < kFDs; i++) {
+ clients[i] = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ }
+ for (int i = 0; i < kThreadCount; i++) {
+ threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr,
+ &clients, i]() {
+ for (int j = 0; j < kFDsPerThread; j++) {
+ int k = i * kFDsPerThread + j;
+ int ret =
+ connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ }
+ });
+ }
+ for (int i = 0; i < kThreadCount; i++) {
+ threads[i]->Join();
+ }
+ for (int i = 0; i < 32; i++) {
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ }
+ // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked
+ // before function end.
+ // ds.reset()
+}
+
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
auto const& param = GetParam();