summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go23
-rw-r--r--pkg/tcpip/tcpip.go5
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go15
-rw-r--r--pkg/tcpip/transport/tcp/connect.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go19
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go21
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go19
-rw-r--r--pkg/tcpip/transport/tcp/rcv_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/snd.go48
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go194
12 files changed, 366 insertions, 37 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index fe5a46aa3..8a6522eac 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1127,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Millisecond), nil
+
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
if err := ep.GetSockOpt(&v); err != nil {
@@ -1563,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
case linux.TCP_CONGESTION:
v := tcpip.CongestionControlOption(optVal)
if err := ep.SetSockOpt(v); err != nil {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index d5bb5b6ed..f62fd729f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -576,6 +576,11 @@ type KeepaliveIntervalOption time.Duration
// closed.
type KeepaliveCountOption int
+// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user
+// specified timeout for a given TCP connection.
+// See: RFC5482 for details.
+type TCPUserTimeoutOption time.Duration
+
// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
// the current congestion control algorithm.
type CongestionControlOption string
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 455a1c098..3b353d56c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -28,6 +28,7 @@ go_library(
"forwarder.go",
"protocol.go",
"rcv.go",
+ "rcv_state.go",
"reno.go",
"sack.go",
"sack_scoreboard.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 74df3edfb..5422ae80c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -242,6 +242,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
+ // Now inherit any socket options that should be inherited from the
+ // listening endpoint.
+ // In case of Forwarder listenEP will be nil and hence this check.
+ if l.listenEP != nil {
+ l.listenEP.propagateInheritableOptions(n)
+ }
+
// Register new endpoint so that packets are routed to it.
if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
n.Close()
@@ -350,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
}
+// propagateInheritableOptions propagates any options set on the listening
+// endpoint to the newly created endpoint.
+func (e *endpoint) propagateInheritableOptions(n *endpoint) {
+ e.mu.Lock()
+ n.userTimeout = e.userTimeout
+ e.mu.Unlock()
+}
+
// handleSynSegment is called in its own goroutine once the listening endpoint
// receives a SYN segment. It is responsible for completing the handshake and
// queueing the new endpoint for acceptance.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 3d059c302..4c34fc9d2 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -862,7 +862,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
}
e.state = StateError
e.HardError = err
- if err != tcpip.ErrConnectionReset {
+ if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
// which can cause sndNxt to be outside the acceptable window on the
@@ -1087,12 +1087,24 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.mu.RLock()
+ userTimeout := e.userTimeout
+ e.mu.RUnlock()
+
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
+ // If a userTimeout is set then abort the connection if it is
+ // exceeded.
+ if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
@@ -1112,7 +1124,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
// whether it is enabled for this endpoint.
func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
e.keepalive.Lock()
- defer e.keepalive.Unlock()
if receivedData {
e.keepalive.unacked = 0
}
@@ -1120,6 +1131,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
// data to send.
if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
+ e.keepalive.Unlock()
return
}
if e.keepalive.unacked > 0 {
@@ -1127,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
} else {
e.keepalive.timer.enable(e.keepalive.idle)
}
+ e.keepalive.Unlock()
}
// disableKeepaliveTimer stops the keepalive timer.
@@ -1239,6 +1252,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
w: &e.snd.resendWaker,
f: func() *tcpip.Error {
if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
return nil
@@ -1405,6 +1419,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if s == nil {
break
}
+
e.tryDeliverSegmentFromClosedEndpoint(s)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4861ab513..dd8b47cbe 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -341,6 +341,7 @@ type endpoint struct {
// TCP should never broadcast but Linux nevertheless supports enabling/
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
@@ -474,6 +475,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // userTimeout if non-zero specifies a user specified timeout for
+ // a connection w/ pending data to send. A connection that has pending
+ // unacked data will be forcibily aborted if the timeout is reached
+ // without any data being acked.
+ userTimeout time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -1333,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
return nil
+ case tcpip.TCPUserTimeoutOption:
+ e.mu.Lock()
+ e.userTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -1591,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.keepalive.Unlock()
return nil
+ case *tcpip.TCPUserTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPUserTimeoutOption(e.userTimeout)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
*o = 1
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 89b965c23..bc718064c 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -162,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
func replyWithReset(s *segment) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
+ ack := seqnum.Value(0)
+ flags := byte(header.TCPFlagRst)
+ // As per RFC 793 page 35 (Reset Generation)
+ // 1. If the connection does not exist (CLOSED) then a reset is sent
+ // in response to any incoming segment except another reset. In
+ // particular, SYNs addressed to a non-existent connection are rejected
+ // by this means.
+
+ // If the incoming segment has an ACK field, the reset takes its
+ // sequence number from the ACK field of the segment, otherwise the
+ // reset has sequence number zero and the ACK field is set to the sum
+ // of the sequence number and segment length of the incoming segment.
+ // The connection remains in the CLOSED state.
if s.flagIsSet(header.TCPFlagAck) {
seq = s.ackNumber
+ } else {
+ flags |= header.TCPFlagAck
+ ack = s.sequenceNumber.Add(s.logicalLen())
}
-
- ack := s.sequenceNumber.Add(s.logicalLen())
-
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
// SetOption implements TransportProtocol.SetOption.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 5ee499c36..0a5534959 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -50,16 +50,20 @@ type receiver struct {
pendingRcvdSegments segmentHeap
pendingBufUsed seqnum.Size
pendingBufSize seqnum.Size
+
+ // Time when the last ack was received.
+ lastRcvdAckTime time.Time `state:".(unixTime)"`
}
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
return &receiver{
- ep: ep,
- rcvNxt: irs + 1,
- rcvAcc: irs.Add(rcvWnd + 1),
- rcvWnd: rcvWnd,
- rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: pendingBufSize,
+ lastRcvdAckTime: time.Now(),
}
}
@@ -360,6 +364,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
return true, nil
}
+ // Store the time of the last ack.
+ r.lastRcvdAckTime = time.Now()
+
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
new file mode 100644
index 000000000..2bf21a2e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveLastRcvdAckTime is invoked by stateify.
+func (r *receiver) saveLastRcvdAckTime() unixTime {
+ return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
+}
+
+// loadLastRcvdAckTime is invoked by stateify.
+func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
+ r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 8332a0179..8a947dc66 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -28,8 +28,11 @@ import (
)
const (
- // minRTO is the minimum allowed value for the retransmit timeout.
- minRTO = 200 * time.Millisecond
+ // MinRTO is the minimum allowed value for the retransmit timeout.
+ MinRTO = 200 * time.Millisecond
+
+ // MaxRTO is the maximum allowed value for the retransmit timeout.
+ MaxRTO = 120 * time.Second
// InitialCwnd is the initial congestion window.
InitialCwnd = 10
@@ -134,6 +137,10 @@ type sender struct {
// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
rttMeasureTime time.Time `state:".(unixTime)"`
+ // firstRetransmittedSegXmitTime is the original transmit time of
+ // the first segment that was retransmitted due to RTO expiration.
+ firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
closed bool
writeNext *segment
writeList segmentList
@@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
- if s.rto < minRTO {
- s.rto = minRTO
+ if s.rto < MinRTO {
+ s.rto = MinRTO
}
}
@@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool {
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
- // Give up if we've waited more than a minute since the last resend.
- if s.rto >= 60*time.Second {
+ // Give up if we've waited more than a minute since the last resend or
+ // if a user time out is set and we have exceeded the user specified
+ // timeout since the first retransmission.
+ s.ep.mu.RLock()
+ uto := s.ep.userTimeout
+ s.ep.mu.RUnlock()
+
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ // We store the original xmitTime of the segment that we are
+ // about to retransmit as the retransmission time. This is
+ // required as by the time the retransmitTimer has expired the
+ // segment has already been sent and unacked for the RTO at the
+ // time the segment was sent.
+ s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
+ }
+
+ elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ remaining := MaxRTO
+ if uto != 0 {
+ // Cap to the user specified timeout if one is specified.
+ remaining = uto - elapsed
+ }
+
+ if remaining <= 0 || s.rto >= MaxRTO {
return false
}
@@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool {
// below.
s.rto *= 2
+ // Cap RTO to remaining time.
+ if s.rto > remaining {
+ s.rto = remaining
+ }
+
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
//
// Retransmit timeouts:
@@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// RFC 6298 Rule 5.3
if s.sndUna == s.sndNxt {
s.outstanding = 0
+ // Reset firstRetransmittedSegXmitTime to the zero value.
+ s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
}
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index 12eff8afc..8b20c3455 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) {
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
}
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+ return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+ s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index bc5cfcf0e..2a83f7bcc 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -323,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
}
func TestTCPResetsReceivedIncrement(t *testing.T) {
@@ -460,18 +460,17 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
break
}
}
-// TestClosingWithEnqueuedSegments tests handling of
-// still enqueued segments when the endpoint transitions
-// to StateClose. The in-flight segments would be re-enqueued
-// to a any listening endpoint.
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments would be
+// re-enqueued to a any listening endpoint.
func TestClosingWithEnqueuedSegments(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -576,8 +575,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(793),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
}
@@ -914,7 +913,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.TCPFlags(header.TCPFlagRst),
checker.SeqNum(200)))
}
@@ -942,7 +941,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.TCPFlags(header.TCPFlagRst),
checker.SeqNum(200)))
}
@@ -4291,8 +4290,9 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ const keepAliveInterval = 10 * time.Millisecond
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
@@ -4382,13 +4382,29 @@ func TestKeepalive(t *testing.T) {
)
}
+ // Sleep for a litte over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // close the socket.
+ time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
// The connection should be terminated after 5 unacked keepalives.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -6157,8 +6173,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(uint32(ackHeaders.SeqNum)),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
@@ -6336,7 +6352,147 @@ func TestTCPCloseWithData(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(uint32(ackHeaders.SeqNum)),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+ userTimeout := 50 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Send some data and wait before ACKing it.
+ view := buffer.NewView(3)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for a little over the minimum retransmit timeout of 200ms for
+ // the retransmitTimer to fire and close the connection.
+ time.Sleep(tcp.MinRTO + 10*time.Millisecond)
+
+ // No packet should be received as the connection should be silently
+ // closed due to timeout.
+ c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+ next += uint32(len(view))
+
+ // The connection should be terminated after userTimeout has expired.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ }
+
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ }
+}
+
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ const keepAliveInterval = 10 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
+ c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+
+ // Set userTimeout to be the duration for 3 keepalive probes.
+ userTimeout := 30 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ // Now receive 2 keepalives, but don't ACK them. The connection should
+ // be reset when the 3rd one should be sent due to userTimeout being
+ // 30ms and each keepalive probe should be sent 10ms apart as set above after
+ // the connection has been idle for 10ms.
+ for i := 0; i < 2; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Sleep for a litte over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // close the socket.
+ time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
+ // The connection should be terminated after 30ms.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(c.IRS + 1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ }
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ }
}