diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 77 |
4 files changed, 104 insertions, 20 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 9aaabe0b1..b0cf0eaf6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -938,6 +938,10 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.snd.updateMaxPayloadSize(mtu, count) } + if n¬ifyReset != 0 { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + } + if n¬ifyClose != 0 && closeTimer == nil { // Reset the connection 3 seconds after the // endpoint has been closed. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b21c2b4ab..191dc1acc 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -40,6 +40,7 @@ const ( notifyClose notifyMTUChanged notifyDrain + notifyReset ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -919,7 +920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { switch e.state { case stateConnected: // Close for write. - if (flags & tcpip.ShutdownWrite) != 0 { + if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { + if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + // We're fully closed, if we have unread data we need to abort + // the connection with a RST. + e.rcvListMu.Lock() + rcvBufUsed := e.rcvBufUsed + e.rcvListMu.Unlock() + + if rcvBufUsed > 0 { + e.notifyProtocolGoroutine(notifyReset) + return nil + } + } + e.sndBufMu.Lock() if e.sndClosed { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 95bea4d88..085973c02 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -333,28 +333,17 @@ func (s *sender) sendData() { var segEnd seqnum.Value if seg.data.Size() == 0 { - seg.flags = flagAck - - s.ep.rcvListMu.Lock() - rcvBufUsed := s.ep.rcvBufUsed - s.ep.rcvListMu.Unlock() - - s.ep.mu.Lock() - // We're sending a FIN by default - fl := flagFin - segEnd = seg.sequenceNumber - if (s.ep.shutdownFlags&tcpip.ShutdownRead) != 0 && rcvBufUsed > 0 { - // If there is unread data we must send a RST. - // For more information see RFC 2525 section 2.17. - fl = flagRst - } else { - segEnd = seg.sequenceNumber.Add(1) + if s.writeList.Back() != seg { + panic("FIN segments must be the final segment in the write list.") } - - s.ep.mu.Unlock() - seg.flags |= uint8(fl) + seg.flags = flagAck | flagFin + segEnd = seg.sequenceNumber.Add(1) } else { // We're sending a non-FIN segment. + if seg.flags&flagFin != 0 { + panic("Netstack queues FIN segments without data.") + } + if !seg.sequenceNumber.LessThan(end) { break } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fa2ef52f9..e564af8c0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -416,6 +416,83 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { }) } +func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that ACK is received, this happens regardless of the read. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Cause a FIN to be generated. + c.EP.Shutdown(tcpip.ShutdownWrite) + + // Make sure we get the FIN but DON't ACK IT. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + checker.SeqNum(uint32(c.IRS)+1), + )) + + // Cause a RST to be generated by closing the read end now since we have + // unread data. + c.EP.Shutdown(tcpip.ShutdownRead) + + // Make sure we get the RST + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // We shouldn't consume a sequence number on RST. + checker.SeqNum(uint32(c.IRS)+1), + )) + + // The ACK to the FIN should now be rejected since the connection has been + // closed by a RST. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + len(data)), + AckNum: c.IRS.Add(seqnum.Size(2)), + RcvWnd: 30000, + }) +} + func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() |