From bc5e18c9d1004ee324a794446416a6b108999b9c Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 5 Sep 2018 11:47:21 -0700 Subject: Implement TCP keepalives PiperOrigin-RevId: 211670620 Change-Id: Ia8a3d8ae53a7fece1dee08ee9c74964bd7f71bb7 --- pkg/tcpip/transport/tcp/connect.go | 58 +++++++++++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 79 +++++++++++++++++++++++++ pkg/tcpip/transport/tcp/snd.go | 4 ++ pkg/tcpip/transport/tcp/tcp_test.go | 112 ++++++++++++++++++++++++++++++++++++ 4 files changed, 253 insertions(+) (limited to 'pkg/tcpip/transport/tcp') diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 14282d399..558dbc50a 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -827,9 +827,56 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } + e.resetKeepaliveTimer(true) + + return nil +} + +// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP +// keepalive packets periodically when the connection is idle. If we don't hear +// from the other side after a number of tries, we terminate the connection. +func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.keepalive.Lock() + if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + e.keepalive.Unlock() + return nil + } + + if e.keepalive.unacked >= e.keepalive.count { + e.keepalive.Unlock() + return tcpip.ErrConnectionReset + } + + // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with + // seg.seq = snd.nxt-1. + e.keepalive.unacked++ + e.keepalive.Unlock() + e.snd.sendSegment(nil, flagAck, e.snd.sndNxt-1) + e.resetKeepaliveTimer(false) return nil } +// resetKeepaliveTimer restarts or stops the keepalive timer, depending on +// whether it is enabled for this endpoint. +func (e *endpoint) resetKeepaliveTimer(receivedData bool) { + e.keepalive.Lock() + defer e.keepalive.Unlock() + if receivedData { + e.keepalive.unacked = 0 + } + // Start the keepalive timer IFF it's enabled and there is no pending + // data to send. + if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + e.keepalive.timer.disable() + return + } + if e.keepalive.unacked > 0 { + e.keepalive.timer.enable(e.keepalive.interval) + } else { + e.keepalive.timer.enable(e.keepalive.idle) + } +} + // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. @@ -892,6 +939,9 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.rcvListMu.Unlock() } + e.keepalive.timer.init(&e.keepalive.waker) + defer e.keepalive.timer.cleanup() + // Tell waiters that the endpoint is connected and writable. e.mu.Lock() e.state = stateConnected @@ -937,6 +987,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil }, }, + { + w: &e.keepalive.waker, + f: e.keepaliveTimerExpired, + }, { w: &e.notificationWaker, f: func() *tcpip.Error { @@ -982,6 +1036,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { <-e.undrain } + if n¬ifyKeepaliveChanged != 0 { + e.resetKeepaliveTimer(true) + } + return nil }, }, diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7c73f0d13..60e9daf74 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -51,6 +51,7 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyKeepaliveChanged ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -211,6 +212,12 @@ type endpoint struct { // goroutine what it was notified; this is only accessed atomically. notifyFlags uint32 `state:"nosave"` + // keepalive manages TCP keepalive state. When the connection is idle + // (no data sent or received) for keepaliveIdle, we start sending + // keepalives every keepalive.interval. If we send keepalive.count + // without hearing a response, the connection is closed. + keepalive keepalive + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -236,6 +243,21 @@ type endpoint struct { connectingAddress tcpip.Address } +// keepalive is a synchronization wrapper used to appease stateify. See the +// comment in endpoint, where it is used. +// +// +stateify savable +type keepalive struct { + sync.Mutex `state:"nosave"` + enabled bool + idle time.Duration + interval time.Duration + count int + unacked int + timer timer `state:"nosave"` + waker sleep.Waker `state:"nosave"` +} + func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ stack: stack, @@ -246,6 +268,12 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite sndMTU: int(math.MaxInt32), noDelay: false, reuseAddr: true, + keepalive: keepalive{ + // Linux defaults. + idle: 2 * time.Hour, + interval: 75 * time.Second, + count: 9, + }, } var ss SendBufferSizeOption @@ -696,6 +724,31 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.v6only = v != 0 + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v != 0 + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.KeepaliveIdleOption: + e.keepalive.Lock() + e.keepalive.idle = time.Duration(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.KeepaliveIntervalOption: + e.keepalive.Lock() + e.keepalive.interval = time.Duration(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = int(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + } return nil @@ -799,6 +852,32 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + + case *tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + *o = 0 + if v { + *o = 1 + } + + case *tcpip.KeepaliveIdleOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) + e.keepalive.Unlock() + + case *tcpip.KeepaliveIntervalOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) + e.keepalive.Unlock() + + case *tcpip.KeepaliveCountOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveCountOption(e.keepalive.count) + e.keepalive.Unlock() + } return tcpip.ErrUnknownProtocolOption diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 096ea9cd4..e4fa89912 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -435,6 +435,10 @@ func (s *sender) sendData() { if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { s.resendTimer.enable(s.rto) } + // If we have no more pending data, start the keepalive timer. + if s.sndUna == s.sndNxt { + s.ep.resetKeepaliveTimer(false) + } } func (s *sender) enterFastRecovery() { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 71d70a597..bf26ea24e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3275,3 +3275,115 @@ func enableCUBIC(t *testing.T, c *context.Context) { t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) } } + +func TestKeepalive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // 5 unacked keepalives are sent. ACK each one, and check that the + // connection stays alive after 5. + for i := 0; i < 10; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Acknowledge the keepalive. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS, + RcvWnd: 30000, + }) + } + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Send some data and wait before ACKing it. Keepalives should be disabled + // during this period. + view := buffer.NewView(3) + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for the packet to be retransmitted. Verify that no keepalives + // were sent. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), + ), + ) + c.CheckNoPacket("Keepalive packet received while unACKed data is pending") + + next += uint32(len(view)) + + // Send ACK. Keepalives should start sending again. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Now receive 5 keepalives, but don't ACK them. The connection + // should be reset after 5. + for i := 0; i < 5; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next-1)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // The connection should be terminated after 5 unacked keepalives. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } +} -- cgit v1.2.3