summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go16
-rw-r--r--pkg/tcpip/transport/tcp/snd.go27
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go77
4 files changed, 104 insertions, 20 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 9aaabe0b1..b0cf0eaf6 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -938,6 +938,10 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.snd.updateMaxPayloadSize(mtu, count)
}
+ if n&notifyReset != 0 {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ }
+
if n&notifyClose != 0 && closeTimer == nil {
// Reset the connection 3 seconds after the
// endpoint has been closed.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b21c2b4ab..191dc1acc 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -40,6 +40,7 @@ const (
notifyClose
notifyMTUChanged
notifyDrain
+ notifyReset
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -919,7 +920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
switch e.state {
case stateConnected:
// Close for write.
- if (flags & tcpip.ShutdownWrite) != 0 {
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // We're fully closed, if we have unread data we need to abort
+ // the connection with a RST.
+ e.rcvListMu.Lock()
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ if rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
+ }
+ }
+
e.sndBufMu.Lock()
if e.sndClosed {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 95bea4d88..085973c02 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -333,28 +333,17 @@ func (s *sender) sendData() {
var segEnd seqnum.Value
if seg.data.Size() == 0 {
- seg.flags = flagAck
-
- s.ep.rcvListMu.Lock()
- rcvBufUsed := s.ep.rcvBufUsed
- s.ep.rcvListMu.Unlock()
-
- s.ep.mu.Lock()
- // We're sending a FIN by default
- fl := flagFin
- segEnd = seg.sequenceNumber
- if (s.ep.shutdownFlags&tcpip.ShutdownRead) != 0 && rcvBufUsed > 0 {
- // If there is unread data we must send a RST.
- // For more information see RFC 2525 section 2.17.
- fl = flagRst
- } else {
- segEnd = seg.sequenceNumber.Add(1)
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
}
-
- s.ep.mu.Unlock()
- seg.flags |= uint8(fl)
+ seg.flags = flagAck | flagFin
+ segEnd = seg.sequenceNumber.Add(1)
} else {
// We're sending a non-FIN segment.
+ if seg.flags&flagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
+ }
+
if !seg.sequenceNumber.LessThan(end) {
break
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fa2ef52f9..e564af8c0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -416,6 +416,83 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
})
}
+func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that ACK is received, this happens regardless of the read.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Cause a FIN to be generated.
+ c.EP.Shutdown(tcpip.ShutdownWrite)
+
+ // Make sure we get the FIN but DON't ACK IT.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ // Cause a RST to be generated by closing the read end now since we have
+ // unread data.
+ c.EP.Shutdown(tcpip.ShutdownRead)
+
+ // Make sure we get the RST
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // We shouldn't consume a sequence number on RST.
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ // The ACK to the FIN should now be rejected since the connection has been
+ // closed by a RST.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()