author    Mithun Iyer <iyerm@google.com>  2019-11-15 11:44:02 -0800
committer gVisor bot <gvisor-bot@google.com>  2019-11-15 12:11:36 -0800
commit    3e534f2974f469a889534221b83c3bbbd1b0318c (patch)
tree      d1c94f32370f7843166608fd48ece498a46feea3
parent    76039f895995c3fe0deef5958f843868685ecc38 (diff)
Handle in-flight TCP segments when moving to CLOSE.
As we move to CLOSE state from LAST-ACK or TIME-WAIT, ensure that we re-match all in-flight segments to any listening endpoint. Also fix LISTEN state handling of any ACK segments as per RFC793. Fixes #1153 PiperOrigin-RevId: 280703556
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go   |  14
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go  |  60
-rw-r--r--  pkg/tcpip/transport/tcp/rcv.go      |   2
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go | 196
4 files changed, 261 insertions(+), 11 deletions(-)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index f24b51b91..023045ec1 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -419,8 +419,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// TODO(b/143300739): Use the userMSS of the listening socket
// for accepted sockets.
- switch s.flags {
- case header.TCPFlagSyn:
+ switch {
+ case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
if incSynRcvdCount() {
// Only handle the syn if the following conditions hold
@@ -464,7 +464,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
- case header.TCPFlagAck:
+ case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -478,6 +478,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
if !synCookiesInUse() {
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
// Send a reset as this is an ACK for which there is no
// half open connections and we are not using cookies
// yet.
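The switch above moves from an exact match on the segment flags to a bitwise test, so a segment that carries ACK together with other flags (for example FIN|ACK, as the new listener tests below send) now reaches the ACK branch and gets the RFC 793 reset. A small standalone sketch of the difference, using locally defined flag constants rather than the header package:

package main

import "fmt"

// Local stand-ins for the TCP flag bits (header.TCPFlagFin etc. in netstack).
const (
	flagFin uint8 = 1 << 0
	flagSyn uint8 = 1 << 1
	flagAck uint8 = 1 << 4
)

func main() {
	flags := flagFin | flagAck // a FIN+ACK segment, as in the new listener tests

	// Old behavior: an exact match only fires for a bare ACK, so FIN|ACK
	// silently missed the branch.
	fmt.Println("exact match sees ACK:", flags == flagAck) // false

	// New behavior: any ACK-bearing segment takes the branch, per RFC 793.
	fmt.Println("bitwise test sees ACK:", flags&flagAck != 0) // true
}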
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 49f2b9685..364067731 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -865,6 +865,33 @@ func (e *endpoint) completeWorkerLocked() {
}
}
+// transitionToStateCloseLocked ensures that the endpoint is cleaned up
+// from the transport demuxer before moving to StateClose. This ensures
+// that no packet will be delivered to this endpoint from the demuxer
+// once the endpoint transitions to StateClose.
+func (e *endpoint) transitionToStateCloseLocked() {
+ if e.state == StateClose {
+ return
+ }
+ e.cleanupLocked()
+ e.state = StateClose
+}
+
+// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
+// segment to an endpoint other than the current one. This is called
+// only when the endpoint is in StateClose and we want to deliver the
+// segment to any other listening endpoint. We reply with RST if we
+// cannot find one.
+func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil {
+ replyWithReset(s)
+ s.decRef()
+ return
+ }
+ ep.(*endpoint).enqueueSegment(s)
+}
+
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
@@ -894,12 +921,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// general "connection reset" signal. Enter the CLOSED state,
// delete the TCB, and return.
case StateCloseWait:
- e.state = StateClose
+ e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
- // We need to set this explicitly here because otherwise
- // the port registrations will not be released till the
- // endpoint is actively closed by the application.
- e.workerCleanup = true
e.mu.Unlock()
return false, nil
default:
@@ -915,6 +938,20 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments() *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainLoop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ return nil
+ }
+
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
@@ -1160,7 +1197,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// to the TCP_FIN_WAIT2 timeout was hit. Just
// mark the socket as closed.
e.mu.Lock()
- e.state = StateClose
+ e.transitionToStateCloseLocked()
e.mu.Unlock()
return nil
},
@@ -1321,12 +1358,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if e.state != StateError {
e.stack.Stats().TCP.EstablishedResets.Increment()
e.stack.Stats().TCP.CurrentEstablished.Decrement()
- e.state = StateClose
+ e.transitionToStateCloseLocked()
}
// Lock released below.
epilogue()
+ // epilogue removes the endpoint from the transport-demuxer and
+ // unlocks e.mu. Now that no new segments can get enqueued to this
+ // endpoint, try to re-match the segment to a different endpoint
+ // as the current endpoint is closed.
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
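When a drained segment cannot be re-matched to any endpoint, the reply is the reset described in the accept.go comment above, and the test added below (TestClosingWithEnqueuedSegments) checks its exact shape. A hedged sketch of that shape, with buildReset and the field names being assumptions rather than the netstack replyWithReset API:

package main

import "fmt"

// Assumed, simplified views of the incoming segment and the reset reply.
type segment struct {
	seqNum, ackNum uint32
	fin            bool
}

type reset struct {
	seqNum, ackNum uint32
	flags          string
}

// buildReset follows the behavior the new test expects: the reset's sequence
// number is taken from the incoming acknowledgment number
// (<SEQ=SEG.ACK><CTL=RST>, RFC 793), and it acknowledges everything the
// incoming segment consumed, including one extra for a FIN.
func buildReset(s segment) reset {
	ack := s.seqNum
	if s.fin {
		ack++
	}
	return reset{seqNum: s.ackNum, ackNum: ack, flags: "RST|ACK"}
}

func main() {
	// Mirrors the in-flight FIN|ACK from the test: SeqNum 792, AckNum set to
	// an arbitrary stand-in for the endpoint's IRS+2.
	in := segment{seqNum: 792, ackNum: 1001, fin: true}
	fmt.Printf("%+v\n", buildReset(in)) // {seqNum:1001 ackNum:793 flags:RST|ACK}
}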
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 068b90fb6..857dc445f 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -218,7 +218,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateClosing:
r.ep.state = StateTimeWait
case StateLastAck:
- r.ep.state = StateClose
+ r.ep.transitionToStateCloseLocked()
}
r.ep.mu.Unlock()
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index b443fe9dc..64f765c70 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -454,6 +454,112 @@ func TestConnectResetAfterClose(t *testing.T) {
}
}
+// TestClosingWithEnqueuedSegments tests handling of
+// still enqueued segments when the endpoint transitions
+// to StateClose. The in-flight segments would be re-enqueued
+// to any listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Send a FIN for ESTABLISHED --> CLOSE_WAIT
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Get the ACK for the FIN we sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
+ ep.Close()
+
+ // Get the FIN
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Pause the endpoint's protocolMainLoop.
+ ep.(interface{ StopWork() }).StopWork()
+
+ // Enqueue last ACK followed by an ACK matching the endpoint
+ //
+ // Send Last ACK for LAST_ACK --> CLOSED
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Send a packet with ACK set; this should generate a RST when
+ // SYN cookies are not in use, as in this test.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 792,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Unpause the endpoint's protocolMainLoop.
+ ep.(interface{ ResumeWork() }).ResumeWork()
+
+ // Wait for the protocolMainLoop to resume and update state.
+ time.Sleep(1 * time.Millisecond)
+
+ // Expect the endpoint to be closed.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Check that the endpoint was moved to CLOSED and that netstack sent a
+ // reset in response to the ACK packet that we sent after the last ACK.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(793),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ),
+ )
+}
+
func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -686,6 +792,96 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
checker.SeqNum(200)))
}
+func TestSendRstOnListenerRxAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true /* v6Only */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(200)))
+}
+
+// TestListenShutdown tests that a listening endpoint does not process
+// incoming segments after a read shutdown.
+func TestListenShutdown(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatal("Shutdown failed:", err)
+ }
+
+ // Wait for the endpoint state to be propagated.
+ time.Sleep(10 * time.Millisecond)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ c.CheckNoPacket("Packet received when listening socket was shutdown")
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()