diff options
author | Mithun Iyer <iyerm@google.com> | 2019-11-15 11:44:02 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-11-15 12:11:36 -0800 |
commit | 3e534f2974f469a889534221b83c3bbbd1b0318c (patch) | |
tree | d1c94f32370f7843166608fd48ece498a46feea3 | |
parent | 76039f895995c3fe0deef5958f843868685ecc38 (diff) |
Handle in-flight TCP segments when moving to CLOSE.
As we move to CLOSE state from LAST-ACK or TIME-WAIT,
ensure that we re-match all in-flight segments to any
listening endpoint.
Also fix LISTEN state handling of any ACK segments as per RFC793.
Fixes #1153
PiperOrigin-RevId: 280703556
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 60 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 196 |
4 files changed, 261 insertions, 11 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f24b51b91..023045ec1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -419,8 +419,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // TODO(b/143300739): Use the userMSS of the listening socket // for accepted sockets. - switch s.flags { - case header.TCPFlagSyn: + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { // Only handle the syn if the following conditions hold @@ -464,7 +464,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -478,6 +478,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } if !synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 49f2b9685..364067731 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -865,6 +865,33 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + if e.state == StateClose { + return + } + e.cleanupLocked() + e.state = StateClose +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil { + replyWithReset(s) + s.decRef() + return + } + ep.(*endpoint).enqueueSegment(s) +} + func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states @@ -894,12 +921,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // general "connection reset" signal. Enter the CLOSED state, // delete the TCB, and return. case StateCloseWait: - e.state = StateClose + e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted - // We need to set this explicitly here because otherwise - // the port registrations will not be released till the - // endpoint is actively closed by the application. - e.workerCleanup = true e.mu.Unlock() return false, nil default: @@ -915,6 +938,20 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + return nil + } + s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false @@ -1160,7 +1197,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // to the TCP_FIN_WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() - e.state = StateClose + e.transitionToStateCloseLocked() e.mu.Unlock() return nil }, @@ -1321,12 +1358,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if e.state != StateError { e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() - e.state = StateClose + e.transitionToStateCloseLocked() } // Lock released below. epilogue() + // epilogue removes the endpoint from the transport-demuxer and + // unlocks e.mu. Now that no new segments can get enqueued to this + // endpoint, try to re-match the segment to a different endpoint + // as the current endpoint is closed. + for !e.segmentQueue.empty() { + s := e.segmentQueue.dequeue() + e.tryDeliverSegmentFromClosedEndpoint(s) + } + // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 068b90fb6..857dc445f 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -218,7 +218,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateClosing: r.ep.state = StateTimeWait case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } r.ep.mu.Unlock() } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b443fe9dc..64f765c70 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -454,6 +454,112 @@ func TestConnectResetAfterClose(t *testing.T) { } } +// TestClosingWithEnqueuedSegments tests handling of +// still enqueued segments when the endpoint transitions +// to StateClose. The in-flight segments would be re-enqueued +// to a any listening endpoint. +func TestClosingWithEnqueuedSegments(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Get the ACK for the FIN we sent. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() + + // Get the FIN + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Pause the endpoint`s protocolMainLoop. + ep.(interface{ StopWork() }).StopWork() + + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 792, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Unpause endpoint`s protocolMainLoop. + ep.(interface{ ResumeWork() }).ResumeWork() + + // Wait for the protocolMainLoop to resume and update state. + time.Sleep(1 * time.Millisecond) + + // Expect the endpoint to be closed. + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Check if the endpoint was moved to CLOSED and netstack a reset in + // response to the ACK packet that we sent after last-ACK. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + ), + ) +} + func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -686,6 +792,96 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.SeqNum(200))) } +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint not processing +// any receive when it is on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + // Wait for the endpoint state to be propagated. + time.Sleep(10 * time.Millisecond) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + c.CheckNoPacket("Packet received when listening socket was shutdown") +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() |