summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go17
-rw-r--r--pkg/tcpip/transport/tcp/connect.go322
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go101
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go26
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go43
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go167
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_state_autogen.go4
7 files changed, 618 insertions, 62 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index cb0e13ebc..0e8e0a2b4 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -269,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
- ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ isn := generateSecureISN(s.id, l.stack.Seed())
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
if err != nil {
return nil, err
}
@@ -289,7 +289,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
- h.resetToSynRcvd(cookie, irs, opts)
+ h.resetToSynRcvd(isn, irs, opts)
if err := h.execute(); err != nil {
ep.Close()
if l.listenEP != nil {
@@ -361,6 +361,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer decSynRcvdCount()
defer e.decSynRcvdCount()
defer s.decRef()
+
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -368,6 +369,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
return
}
ctx.removePendingEndpoint(n)
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
e.deliverAccepted(n)
}
@@ -543,6 +549,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// number of goroutines as we do check before
// entering here that there was at least some
// space available in the backlog.
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index ca982c451..a114c06c1 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"sync"
"time"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -139,7 +141,32 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+ h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+ isnHasher := jenkins.Sum32(seed)
+ isnHasher.Write([]byte(id.LocalAddress))
+ isnHasher.Write([]byte(id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+ isnHasher.Write(portBuf)
+ binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+ isnHasher.Write(portBuf)
+ // The time period here is 64ns. This is similar to what linux uses
+ // generate a sequence number that overlaps less than one
+ // time per MSL (2 minutes).
+ //
+ // A 64ns clock ticks 10^9/64 = 15625000) times in a second.
+ // To wrap the whole 32 bit space would require
+ // 2^32/1562500 ~ 274 seconds.
+ //
+ // Which sort of guarantees that we won't reuse the ISN for a new
+ // connection for the same tuple for at least 274s.
+ isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ return seqnum.Value(isn)
}
// effectiveRcvWndScale returns the effective receive window scale to be used.
@@ -809,7 +836,19 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
e.state = StateError
e.HardError = err
if err != tcpip.ErrConnectionReset {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of window being shrunk
+ // which can cause sndNxt to be outside the acceptable window on the
+ // receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
}
}
@@ -823,6 +862,51 @@ func (e *endpoint) completeWorkerLocked() {
}
}
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ e.mu.Lock()
+ switch e.state {
+ // In case of a RST in CLOSE-WAIT linux moves
+ // the socket to closed state with an error set
+ // to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.state = StateClose
+ e.HardError = tcpip.ErrAborted
+ // We need to set this explicitly here because otherwise
+ // the port registrations will not be released till the
+ // endpoint is actively closed by the application.
+ e.workerCleanup = true
+ e.mu.Unlock()
+ return false, nil
+ default:
+ e.mu.Unlock()
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
// handleSegments pulls segments from the queue and processes them. It returns
// no error if the protocol loop should continue, an error otherwise.
func (e *endpoint) handleSegments() *tcpip.Error {
@@ -840,14 +924,34 @@ func (e *endpoint) handleSegments() *tcpip.Error {
}
if s.flagIsSet(header.TCPFlagRst) {
- if e.rcv.acceptable(s.sequenceNumber, 0) {
- // RFC 793, page 37 states that "in all states
- // except SYN-SENT, all reset (RST) segments are
- // validated by checking their SEQ-fields." So
- // we only process it if it's acceptable.
- s.decRef()
- return tcpip.ErrConnectionReset
+ if ok, err := e.handleReset(s); !ok {
+ return err
}
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
} else if s.flagIsSet(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
@@ -856,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// RFC 793, page 41 states that "once in the ESTABLISHED
// state all segments must carry current acknowledgment
// information."
- e.rcv.handleRcvdSegment(s)
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
+ }
+ if drop {
+ s.decRef()
+ continue
+ }
e.snd.handleRcvdSegment(s)
}
s.decRef()
@@ -955,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
-
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1001,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// RTT itself.
e.rcvAutoParams.prevCopied = initialRcvWnd
e.rcvListMu.Unlock()
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.mu.Lock()
+ e.state = StateEstablished
+ e.mu.Unlock()
}
e.keepalive.timer.init(&e.keepalive.waker)
@@ -1008,10 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
- if e.state != StateEstablished {
- e.stack.Stats().TCP.CurrentEstablished.Increment()
- e.state = StateEstablished
- }
drained := e.drainDone != nil
e.mu.Unlock()
if drained {
@@ -1042,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
{
w: &closeWaker,
f: func() *tcpip.Error {
- return tcpip.ErrConnectionAborted
+ // This means the socket is being closed due
+ // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // mark the socket as closed.
+ e.mu.Lock()
+ e.state = StateClose
+ e.mu.Unlock()
+ return nil
},
},
{
@@ -1085,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
}
+
if n&notifyClose != 0 && closeTimer == nil {
- // Reset the connection 3 seconds after
- // the endpoint has been closed.
- //
- // The timer could fire in background
- // when the endpoint is drained. That's
- // OK as the loop here will not honor
- // the firing until the undrain arrives.
- closeTimer = time.AfterFunc(3*time.Second, func() {
- closeWaker.Assert()
- })
+ e.mu.Lock()
+ if e.state == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
+ closeWaker.Assert()
+ })
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+ e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1117,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
}
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
return nil
},
},
@@ -1143,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.rcvListMu.Unlock()
- e.mu.RLock()
+ e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
- e.mu.RUnlock()
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+
+ for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+ e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
@@ -1167,6 +1292,23 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return nil
}
+ e.mu.Lock()
+ }
+
+ state := e.state
+ e.mu.Unlock()
+ var reuseTW func()
+ if state == StateTimeWait {
+ // Disable close timer as we now entering real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
@@ -1176,8 +1318,130 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.stack.Stats().TCP.CurrentEstablished.Decrement()
e.state = StateClose
}
+
// Lock released below.
epilogue()
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // the timewait and redirect the segment to the listener queue
+ if reuseTW != nil {
+ reuseTW()
+ }
+
return nil
}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ tcpEP.enqueueSegment(s)
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than one see on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ <-e.undrain
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 79fec6b77..04c92c04c 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -121,6 +121,11 @@ const (
notifyReset
notifyKeepaliveChanged
notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is
+ // say TIME_WAIT.
+ notifyTickleWorker
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -320,6 +325,11 @@ type endpoint struct {
state EndpointState `state:".(EndpointState)"`
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to it's correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
@@ -503,6 +513,16 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of a time a socket
+ // a socket stays in TIME_WAIT state before being marked
+ // closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called closed on the
+ // endpoint and at this point the endpoint is only around
+ // to complete the TCP shutdown.
+ closed bool
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -599,6 +619,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.SetSockOptInt(tcpip.DelayOption, 1)
}
+ var tcpLT tcpip.TCPLingerTimeoutOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil {
+ e.tcpLingerTimeout = time.Duration(tcpLT)
+ }
+
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -686,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
+ e.mu.Lock()
+ closed := e.closed
+ e.mu.Unlock()
+ if closed {
+ return
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -706,6 +738,8 @@ func (e *endpoint) Close() {
e.isPortReserved = false
}
+ // Mark endpoint as closed.
+ e.closed = true
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
@@ -731,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
go func() {
defer close(done)
for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
n.Close()
}
}()
@@ -1349,6 +1381,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ if v < 0 {
+ // Same as effectively disabling TCPLinger timeout.
+ v = 0
+ }
+ var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
+ // We were unable to retrieve a stack config, just use
+ // the DefaultTCPLingerTimeout.
+ if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
+ stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ }
+ }
+ // Cap it to the stack wide TCPLinger timeout.
+ if v > stkTCPLingerTimeout {
+ v = stkTCPLingerTimeout
+ }
+ e.tcpLingerTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1562,6 +1616,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.RUnlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1696,7 +1756,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.PortSeed())
+ h := jenkins.Sum32(e.stack.Seed())
h.Write([]byte(e.ID.LocalAddress))
h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
@@ -1782,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
- defer e.mu.Unlock()
e.shutdownFlags |= flags
-
+ finQueued := false
switch {
case e.state.connected():
// Close for read.
@@ -1799,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
e.notifyProtocolGoroutine(notifyReset)
+ e.mu.Unlock()
return nil
}
}
@@ -1817,14 +1877,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
-
+ finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
-
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
}
case e.state == StateListen:
@@ -1832,11 +1889,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
}
-
default:
+ e.mu.Unlock()
return tcpip.ErrNotConnected
}
-
+ e.mu.Unlock()
+ if finQueued {
+ if e.workMu.TryLock() {
+ e.handleClose()
+ e.workMu.Unlock()
+ } else {
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+ }
return nil
}
@@ -1928,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrWouldBlock
}
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
-
- return n, wq, nil
+ return n, n.waiterQueue, nil
}
// Bind binds the endpoint to a specific local port and optionally address.
@@ -2058,6 +2119,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.stack.Stats().TCP.ResetsReceived.Increment()
}
+ e.enqueueSegment(s)
+}
+
+func (e *endpoint) enqueueSegment(s *segment) {
// Send packet to worker goroutine.
if e.segmentQueue.enqueue(s) {
e.newSegmentWaker.Assert()
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 19f003b6b..7aa4c3f0e 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() {
}
fallthrough
case StateError, StateClose:
- for e.state == StateError && e.workerRunning {
+ for (e.state == StateError || e.state == StateClose) && e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
@@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
+ // Freeze segment queue before registering to prevent any segments
+ // from being delivered while it is being restored.
+ e.origEndpointState = e.state
+ // Restore the endpoint to InitialState as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
+ state := e.origEndpointState
- state := e.state
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -189,7 +195,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
bind := func() {
- e.state = StateInitial
if len(e.BindAddr) == 0 {
e.BindAddr = e.ID.LocalAddress
}
@@ -219,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
+ e.mu.Lock()
+ e.state = e.origEndpointState
+ closed := e.closed
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ if state == StateFinWait2 && closed {
+ // If the endpoint has been closed then make sure we notify so
+ // that the FIN_WAIT2 timer is started after a restore.
+ e.notifyProtocolGoroutine(notifyClose)
+ }
connectedLoading.Done()
case StateListen:
tcpip.AsyncLoading.Add(1)
@@ -265,8 +280,11 @@ func (e *endpoint) Resume(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
}
- fallthrough
+ e.state = StateClose
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
case StateError:
+ e.state = StateError
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index c8e4a0d7e..89b965c23 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,6 +23,7 @@ package tcp
import (
"strings"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -54,6 +55,14 @@ const (
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
MaxUnprocessedSegments = 300
+
+ // DefaultTCPLingerTimeout is the amount of time that sockets linger in
+ // FIN_WAIT_2 state before being marked closed.
+ DefaultTCPLingerTimeout = 60 * time.Second
+
+ // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+ // in TIME_WAIT state before being marked closed.
+ DefaultTCPTimeWaitTimeout = 60 * time.Second
)
// SACKEnabled option can be used to enable SACK support in the TCP
@@ -93,6 +102,8 @@ type protocol struct {
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
+ tcpLingerTimeout time.Duration
+ tcpTimeWaitTimeout time.Duration
}
// Number returns the tcp protocol number.
@@ -212,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpLingerTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpTimeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -262,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -274,5 +315,7 @@ func NewProtocol() stack.TransportProtocol {
recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
+ tcpLingerTimeout: DefaultTCPLingerTimeout,
+ tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index e90f9a7d9..068b90fb6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -18,6 +18,7 @@ import (
"container/heap"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
switch r.ep.state {
case StateFinWait1:
r.ep.state = StateFinWait2
+ // Notify protocol goroutine that we have received an
+ // ACK to our FIN so that it can start the FIN_WAIT2
+ // timer to abort connection if the other side does
+ // not close within 2MSL.
+ r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
r.ep.state = StateTimeWait
case StateLastAck:
@@ -253,23 +259,105 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-// handleRcvdSegment handles TCP segments directed at the connection managed by
-// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+ r.ep.rcvListMu.Lock()
+ rcvClosed := r.ep.rcvClosed || r.closed
+ r.ep.rcvListMu.Unlock()
+
+ // If we are in one of the shutdown states then we need to do
+ // additional checks before we try and process the segment.
+ switch state {
+ case StateCloseWait, StateClosing, StateLastAck:
+ if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ s.decRef()
+ // Just drop the segment as we have
+ // already received a FIN and this
+ // segment is after the sequence number
+ // for the FIN.
+ return true, nil
+ }
+ fallthrough
+ case StateFinWait1:
+ fallthrough
+ case StateFinWait2:
+ // If we are closed for reads (either due to an
+ // incoming FIN or the user calling shutdown(..,
+ // SHUT_RD) then any data past the rcvNxt should
+ // trigger a RST.
+ endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ if state == StateFinWait1 {
+ break
+ }
+
+ // If it's a retransmission of an old data segment
+ // or a pure ACK then allow it.
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ s.logicalLen() == 0 {
+ break
+ }
+
+ // In FIN-WAIT2 if the socket is fully
+ // closed(not owned by application on our end
+ // then the only acceptable segment is a
+ // FIN. Since FIN can technically also carry
+ // data we verify that the segment carrying a
+ // FIN ends at exactly e.rcvNxt+1.
+ //
+ // From RFC793 page 25.
+ //
+ // For sequence number purposes, the SYN is
+ // considered to occur before the first actual
+ // data octet of the segment in which it occurs,
+ // while the FIN is considered to occur after
+ // the last actual data octet in a segment in
+ // which it occurs.
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ }
+
// We don't care about receive processing anymore if the receive side
// is closed.
- if r.closed {
- return
+ //
+ // NOTE: We still want to permit a FIN as it's possible only our
+ // end has closed and the peer is yet to send a FIN. Hence we
+ // compare only the payload.
+ segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ return true, nil
+ }
+ return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+ r.ep.mu.RLock()
+ state := r.ep.state
+ closed := r.ep.closed
+ r.ep.mu.RUnlock()
+
+ if state != StateEstablished {
+ drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+ if drop || err != nil {
+ return drop, err
+ }
}
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// If the sequence number range is outside the acceptable range, just
- // send an ACK. This is according to RFC 793, page 37.
+ // send an ACK and stop further processing of the segment.
+ // This is according to RFC 793, page 68.
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
- return
+ return true, nil
}
// Defer segment processing if it can't be consumed now.
@@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) {
// have to retransmit.
r.ep.snd.sendAck()
}
- return
+ return false, nil
}
// Since we consumed a segment update the receiver's RTT estimate
@@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) {
r.pendingBufUsed -= s.logicalLen()
s.decRef()
}
+ return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+ segSeq := s.sequenceNumber
+ segLen := seqnum.Size(s.data.Size())
+
+ // Just silently drop any RST packets in TIME_WAIT. We do not support
+ // TIME_WAIT assasination as a result we confirm w/ fix 1 as described
+ // in https://tools.ietf.org/html/rfc1337#section-3.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return false, false
+ }
+
+ // If it's a SYN and the sequence number is higher than any seen before
+ // for this connection then try and redirect it to a listening endpoint
+ // if available.
+ //
+ // RFC 1122:
+ // "When a connection is [...] on TIME-WAIT state [...]
+ // [a TCP] MAY accept a new SYN from the remote TCP to
+ // reopen the connection directly, if it:
+
+ // (1) assigns its initial sequence number for the new
+ // connection to be larger than the largest sequence
+ // number it used on the previous connection incarnation,
+ // and
+
+ // (2) returns to TIME-WAIT state if the SYN turns out
+ // to be an old duplicate".
+ if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+
+ return false, true
+ }
+
+ // Drop the segment if it does not contain an ACK.
+ if !s.flagIsSet(header.TCPFlagAck) {
+ return false, false
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ }
+
+ if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ // If it's a FIN-ACK then resetTimeWait and send an ACK, as it
+ // indicates our final ACK could have been lost.
+ r.ep.snd.sendAck()
+ return true, false
+ }
+
+ // If the sequence number range is outside the acceptable range or
+ // carries data then just send an ACK. This is according to RFC 793,
+ // page 37.
+ //
+ // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+ if segSeq != r.rcvNxt || segLen != 0 {
+ r.ep.snd.sendAck()
+ }
+ return false, false
}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 00347a215..a3c8d2353 100755
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -144,6 +144,8 @@ func (x *endpoint) save(m state.Map) {
m.Save("amss", &x.amss)
m.Save("sendTOS", &x.sendTOS)
m.Save("gso", &x.gso)
+ m.Save("tcpLingerTimeout", &x.tcpLingerTimeout)
+ m.Save("closed", &x.closed)
}
func (x *endpoint) load(m state.Map) {
@@ -194,6 +196,8 @@ func (x *endpoint) load(m state.Map) {
m.Load("amss", &x.amss)
m.Load("sendTOS", &x.sendTOS)
m.Load("gso", &x.gso)
+ m.Load("tcpLingerTimeout", &x.tcpLingerTimeout)
+ m.Load("closed", &x.closed)
m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) })
m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) })
m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })