summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go7
-rw-r--r--pkg/tcpip/transport/tcp/connect.go30
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go37
4 files changed, 76 insertions, 28 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 7a9dea4ac..e07b436c4 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -330,6 +330,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
+
+ ep.drainClosingSegmentQueue()
+
return nil, err
}
ep.isConnectNotified = true
@@ -378,7 +381,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
for {
if e.acceptedChan == nil {
e.acceptMu.Unlock()
- n.Close()
+ n.notifyProtocolGoroutine(notifyReset)
return
}
select {
@@ -656,6 +659,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
}
e.mu.Unlock()
+ e.drainClosingSegmentQueue()
+
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 2ca3fb809..994ac52a3 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1062,6 +1062,20 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
}
}
+// Drain segment queue from the endpoint and try to re-match the segment to a
+// different endpoint. This is used when the current endpoint is transitioned to
+// StateClose and has been unregistered from the transport demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
@@ -1315,6 +1329,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.mu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1565,19 +1582,6 @@ loop:
// Lock released below.
epilogue()
- // epilogue removes the endpoint from the transport-demuxer and
- // unlocks e.mu. Now that no new segments can get enqueued to this
- // endpoint, try to re-match the segment to a different endpoint
- // as the current endpoint is closed.
- for {
- s := e.segmentQueue.dequeue()
- if s == nil {
- break
- }
-
- e.tryDeliverSegmentFromClosedEndpoint(s)
- }
-
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a8d443f73..7ed78d57f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -980,25 +980,22 @@ func (e *endpoint) closeNoShutdownLocked() {
// Mark endpoint as closed.
e.closed = true
+
+ switch e.EndpointState() {
+ case StateClose, StateError:
+ return
+ }
+
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
- switch e.EndpointState() {
- // Sockets in StateSynRecv state(passive connections) are closed when
- // the handshake fails or if the listening socket is closed while
- // handshake was in progress. In such cases the handshake goroutine
- // is already gone by the time Close is called and we need to cleanup
- // here.
- case StateInitial, StateBound, StateSynRecv:
- e.cleanupLocked()
- e.setEndpointState(StateClose)
- case StateError, StateClose:
- // do nothing.
- default:
+ if e.workerRunning {
e.workerCleanup = true
tcpip.AddDanglingEndpoint(e)
// Worker will remove the dangling endpoint when the endpoint
// goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
+ } else {
+ e.transitionToStateCloseLocked()
}
}
@@ -1010,13 +1007,18 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Unlock()
return
}
-
close(e.acceptedChan)
+ ch := e.acceptedChan
e.acceptedChan = nil
e.acceptCond.Broadcast()
e.acceptMu.Unlock()
- // Wait for all pending endpoints to close.
+ // Reset all connections that are waiting to be accepted.
+ for n := range ch {
+ n.notifyProtocolGoroutine(notifyReset)
+ }
+ // Wait for reset of all endpoints that are still waiting to be delivered to
+ // the now closed acceptedChan.
e.pendingAccepted.Wait()
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 41caa9ed4..a9f121c17 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1068,6 +1068,43 @@ func TestListenShutdown(t *testing.T) {
c.CheckNoPacket("Packet received when listening socket was shutdown")
}
+// TestListenCloseWhileConnect tests for the listening endpoint to
+// drain the accept-queue when closed. This should reset all of the
+// pending connections that are waiting to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created because of handshake to be delivered
+ // to the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()