summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2018-09-05 11:47:21 -0700
committerShentubot <shentubot@google.com>2018-09-05 11:48:23 -0700
commitbc5e18c9d1004ee324a794446416a6b108999b9c (patch)
treec17a369779942637d7a13f97c84901776f2e0877 /pkg/tcpip/transport
parent2b8dae0bc5594f7088dd028268efaedbb5a72507 (diff)
Implement TCP keepalives
PiperOrigin-RevId: 211670620 Change-Id: Ia8a3d8ae53a7fece1dee08ee9c74964bd7f71bb7
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/tcp/connect.go58
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go79
-rw-r--r--pkg/tcpip/transport/tcp/snd.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go112
4 files changed, 253 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 14282d399..558dbc50a 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -827,9 +827,56 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
+ e.resetKeepaliveTimer(true)
+
+ return nil
+}
+
+// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
+// keepalive packets periodically when the connection is idle. If we don't hear
+// from the other side after a number of tries, we terminate the connection.
+func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.keepalive.Lock()
+ if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ e.keepalive.Unlock()
+ return nil
+ }
+
+ if e.keepalive.unacked >= e.keepalive.count {
+ e.keepalive.Unlock()
+ return tcpip.ErrConnectionReset
+ }
+
+ // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
+ // seg.seq = snd.nxt-1.
+ e.keepalive.unacked++
+ e.keepalive.Unlock()
+ e.snd.sendSegment(nil, flagAck, e.snd.sndNxt-1)
+ e.resetKeepaliveTimer(false)
return nil
}
+// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
+// whether it is enabled for this endpoint.
+func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
+ e.keepalive.Lock()
+ defer e.keepalive.Unlock()
+ if receivedData {
+ e.keepalive.unacked = 0
+ }
+ // Start the keepalive timer IFF it's enabled and there is no pending
+ // data to send.
+ if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ e.keepalive.timer.disable()
+ return
+ }
+ if e.keepalive.unacked > 0 {
+ e.keepalive.timer.enable(e.keepalive.interval)
+ } else {
+ e.keepalive.timer.enable(e.keepalive.idle)
+ }
+}
+
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
@@ -892,6 +939,9 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.rcvListMu.Unlock()
}
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
+
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
e.state = stateConnected
@@ -938,6 +988,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
},
},
{
+ w: &e.keepalive.waker,
+ f: e.keepaliveTimerExpired,
+ },
+ {
w: &e.notificationWaker,
f: func() *tcpip.Error {
n := e.fetchNotifications()
@@ -982,6 +1036,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
<-e.undrain
}
+ if n&notifyKeepaliveChanged != 0 {
+ e.resetKeepaliveTimer(true)
+ }
+
return nil
},
},
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7c73f0d13..60e9daf74 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -51,6 +51,7 @@ const (
notifyMTUChanged
notifyDrain
notifyReset
+ notifyKeepaliveChanged
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -211,6 +212,12 @@ type endpoint struct {
// goroutine what it was notified; this is only accessed atomically.
notifyFlags uint32 `state:"nosave"`
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepaliveIdle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // without hearing a response, the connection is closed.
+ keepalive keepalive
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -236,6 +243,21 @@ type endpoint struct {
connectingAddress tcpip.Address
}
+// keepalive is a synchronization wrapper used to appease stateify. See the
+// comment in endpoint, where it is used.
+//
+// +stateify savable
+type keepalive struct {
+ sync.Mutex `state:"nosave"`
+ enabled bool
+ idle time.Duration
+ interval time.Duration
+ count int
+ unacked int
+ timer timer `state:"nosave"`
+ waker sleep.Waker `state:"nosave"`
+}
+
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
stack: stack,
@@ -246,6 +268,12 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
sndMTU: int(math.MaxInt32),
noDelay: false,
reuseAddr: true,
+ keepalive: keepalive{
+ // Linux defaults.
+ idle: 2 * time.Hour,
+ interval: 75 * time.Second,
+ count: 9,
+ },
}
var ss SendBufferSizeOption
@@ -696,6 +724,31 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
e.v6only = v != 0
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v != 0
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = int(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
}
return nil
@@ -799,6 +852,32 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+
+ case *tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+ e.keepalive.Unlock()
+
+ case *tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+ e.keepalive.Unlock()
+
+ case *tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveCountOption(e.keepalive.count)
+ e.keepalive.Unlock()
+
}
return tcpip.ErrUnknownProtocolOption
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 096ea9cd4..e4fa89912 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -435,6 +435,10 @@ func (s *sender) sendData() {
if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
s.resendTimer.enable(s.rto)
}
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
}
func (s *sender) enterFastRecovery() {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 71d70a597..bf26ea24e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3275,3 +3275,115 @@ func enableCUBIC(t *testing.T, c *context.Context) {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
}
+
+func TestKeepalive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
+ c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+
+ // 5 unacked keepalives are sent. ACK each one, and check that the
+ // connection stays alive after 5.
+ for i := 0; i < 10; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Acknowledge the keepalive.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send some data and wait before ACKing it. Keepalives should be disabled
+ // during this period.
+ view := buffer.NewView(3)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for the packet to be retransmitted. Verify that no keepalives
+ // were sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
+ ),
+ )
+ c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
+
+ next += uint32(len(view))
+
+ // Send ACK. Keepalives should start sending again.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ // Now receive 5 keepalives, but don't ACK them. The connection
+ // should be reset after 5.
+ for i := 0; i < 5; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next-1)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // The connection should be terminated after 5 unacked keepalives.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ }
+}