summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xpkg/tcpip/stack/stack_state_autogen.go6
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go54
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go310
-rwxr-xr-xpkg/tcpip/transport/tcp/dispatcher.go218
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go303
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go30
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go11
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go21
-rw-r--r--pkg/tcpip/transport/tcp/snd.go14
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_endpoint_list.go173
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_state_autogen.go28
12 files changed, 902 insertions, 275 deletions
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
index 2551126f2..c16b2ccf6 100755
--- a/pkg/tcpip/stack/stack_state_autogen.go
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -102,6 +102,9 @@ func (x *TransportEndpointInfo) load(m state.Map) {
func (x *multiPortEndpoint) beforeSave() {}
func (x *multiPortEndpoint) save(m state.Map) {
x.beforeSave()
+ m.Save("demux", &x.demux)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
m.Save("endpointsArr", &x.endpointsArr)
m.Save("endpointsMap", &x.endpointsMap)
m.Save("reuse", &x.reuse)
@@ -109,6 +112,9 @@ func (x *multiPortEndpoint) save(m state.Map) {
func (x *multiPortEndpoint) afterLoad() {}
func (x *multiPortEndpoint) load(m state.Map) {
+ m.Load("demux", &x.demux)
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
m.Load("endpointsArr", &x.endpointsArr)
m.Load("endpointsMap", &x.endpointsMap)
m.Load("reuse", &x.reuse)
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f384a91de..d686e6eb8 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
return
}
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
+ transEP := selectEndpoint(id, mpep, epsByNic.seed)
+ if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
+ queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ epsByNic.mu.RUnlock()
+ return
+ }
+
+ transEP.HandlePacket(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNic.mu.Lock()
defer epsByNic.mu.Unlock()
@@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort
}
// This is a new binding.
- multiPortEp := &multiPortEndpoint{}
+ multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto}
multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
multiPortEp.reuse = reusePort
epsByNic.endpoints[bindToDevice] = multiPortEp
@@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// newTransportDemuxer.
type transportDemuxer struct {
// protocol is immutable.
- protocol map[protocolIDs]*transportEndpoints
+ protocol map[protocolIDs]*transportEndpoints
+ queuedProtocols map[protocolIDs]queuedTransportProtocol
+}
+
+// queuedTransportProtocol if supported by a protocol implementation will cause
+// the dispatcher to delivery packets to the QueuePacket method instead of
+// calling HandlePacket directly on the endpoint.
+type queuedTransportProtocol interface {
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
- d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+ d := &transportDemuxer{
+ protocol: make(map[protocolIDs]*transportEndpoints),
+ queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
+ }
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ protoIDs := protocolIDs{netProto, proto}
+ d.protocol[protoIDs] = &transportEndpoints{
endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
+ qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
+ if isQueued {
+ d.queuedProtocols[protoIDs] = qTransProto
+ }
}
}
@@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ demux *transportDemuxer
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
@@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
ep.mu.RLock()
+ queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
for i, endpoint := range ep.endpointsArr {
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
if i == len(ep.endpointsArr)-1 {
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ break
+ }
endpoint.HandlePacket(r, id, pkt)
break
}
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ continue
+ }
endpoint.HandlePacket(r, id, pkt.Clone())
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
@@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if epsByNic, ok := eps.endpoints[id]; ok {
// There was already a binding.
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// This is a new binding.
@@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
}
eps.endpoints[id] = epsByNic
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 1ea996936..1a2e3efa9 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
if l.listenEP != nil {
l.listenEP.mu.Lock()
- if l.listenEP.state != StateListen {
+ if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
return nil, tcpip.ErrConnectionAborted
}
@@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() {
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.pendingAccepted.Add(1)
defer e.pendingAccepted.Done()
acceptedChan := e.acceptedChan
e.mu.Unlock()
+
if state == StateListen {
acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
@@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// We do not use transitionToStateEstablishedLocked here as there is
// no handshake state available when doing a SYN cookie based accept.
n.stack.Stats().TCP.CurrentEstablished.Increment()
- n.state = StateEstablished
n.isConnectNotified = true
+ n.setEndpointState(StateEstablished)
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
- e.state = StateClose
+ e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 613ec1775..f3896715b 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
}
@@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// SYN-RCVD state.
h.state = handshakeSynRcvd
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
ttl := h.ep.ttl
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
// permits SACK. This is not explicitly defined in the RFC but
@@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
WS: h.rcvWndScale,
TS: h.ep.sendTSOk,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
@@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error {
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
@@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
if e.sackPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error {
}
func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
// Drain the send queue.
e.handleWrite()
@@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
- if e.state == StateEstablished || e.state == StateCloseWait {
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- }
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
@@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
}
// completeWorkerLocked is called by the worker goroutine when it's about to
-// exit. It marks the worker as completed and performs cleanup work if requested
-// by Close().
+// exit.
func (e *endpoint) completeWorkerLocked() {
+ // Worker is terminating(either due to moving to
+ // CLOSED or ERROR state, ensure we release all
+ // registrations port reservations even if the socket
+ // itself is not yet closed by the application.
e.workerRunning = false
if e.workerCleanup {
e.cleanupLocked()
@@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
e.rcvAutoParams.prevCopied = int(h.rcvWnd)
e.rcvListMu.Unlock()
}
- h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
- e.state = StateEstablished
+ e.setEndpointState(StateEstablished)
}
// transitionToStateCloseLocked ensures that the endpoint is
@@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.state == StateClose {
+ if e.EndpointState() == StateClose {
return
}
+ // Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
- e.state = StateClose
+ e.setEndpointState(StateClose)
e.stack.Stats().TCP.EstablishedClosed.Increment()
}
@@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
s.decRef()
return
}
- ep.(*endpoint).enqueueSegment(s)
+ if ep.(*endpoint).enqueueSegment(s) {
+ ep.(*endpoint).newSegmentWaker.Assert()
+ }
}
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
@@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- s.decRef()
e.mu.Lock()
- switch e.state {
+ switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
// to indicate EPIPE.
@@ -981,103 +984,57 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
e.mu.Unlock()
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+
+ // Notify protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
return false, tcpip.ErrConnectionReset
}
}
return true, nil
}
-// handleSegments pulls segments from the queue and processes them. It returns
-// no error if the protocol loop should continue, an error otherwise.
-func (e *endpoint) handleSegments() *tcpip.Error {
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ return nil
+ }
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
break
}
- // Invoke the tcp probe if installed.
- if e.probe != nil {
- e.probe(e.completeState())
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ e.mu.Lock()
+ e.setEndpointState(StateError)
+ e.HardError = err
+ e.mu.Unlock()
+ return err
}
-
- if s.flagIsSet(header.TCPFlagRst) {
- if ok, err := e.handleReset(s); !ok {
- return err
- }
- } else if s.flagIsSet(header.TCPFlagSyn) {
- // See: https://tools.ietf.org/html/rfc5961#section-4.1
- // 1) If the SYN bit is set, irrespective of the sequence number, TCP
- // MUST send an ACK (also referred to as challenge ACK) to the remote
- // peer:
- //
- // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
- //
- // After sending the acknowledgment, TCP MUST drop the unacceptable
- // segment and stop processing further.
- //
- // By sending an ACK, the remote peer is challenged to confirm the loss
- // of the previous connection and the request to start a new connection.
- // A legitimate peer, after restart, would not have a TCB in the
- // synchronized state. Thus, when the ACK arrives, the peer should send
- // a RST segment back with the sequence number derived from the ACK
- // field that caused the RST.
-
- // This RST will confirm that the remote peer has indeed closed the
- // previous connection. Upon receipt of a valid RST, the local TCP
- // endpoint MUST terminate its connection. The local TCP endpoint
- // should then rely on SYN retransmission from the remote end to
- // re-establish the connection.
-
- e.snd.sendAck()
- } else if s.flagIsSet(header.TCPFlagAck) {
- // Patch the window size in the segment according to the
- // send window scale.
- s.window <<= e.snd.sndWndScale
-
- // RFC 793, page 41 states that "once in the ESTABLISHED
- // state all segments must carry current acknowledgment
- // information."
- drop, err := e.rcv.handleRcvdSegment(s)
- if err != nil {
- s.decRef()
- return err
- }
- if drop {
- s.decRef()
- continue
- }
-
- // Now check if the received segment has caused us to transition
- // to a CLOSED state, if yes then terminate processing and do
- // not invoke the sender.
- e.mu.RLock()
- state := e.state
- e.mu.RUnlock()
- if state == StateClose {
- // When we get into StateClose while processing from the queue,
- // return immediately and let the protocolMainloop handle it.
- //
- // We can reach StateClose only while processing a previous segment
- // or a notification from the protocolMainLoop (caller goroutine).
- // This means that with this return, the segment dequeue below can
- // never occur on a closed endpoint.
- s.decRef()
- return nil
- }
- e.snd.handleRcvdSegment(s)
+ if !cont {
+ s.decRef()
+ return nil
}
- s.decRef()
}
- // If the queue is not empty, make sure we'll wake up in the next
- // iteration.
- if checkRequeue && !e.segmentQueue.empty() {
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
@@ -1086,11 +1043,88 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
- e.resetKeepaliveTimer(true)
+ e.resetKeepaliveTimer(true /* receivedData */)
return nil
}
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ return false, err
+ }
+ if drop {
+ return true, nil
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return false, nil
+ }
+
+ e.snd.handleRcvdSegment(s)
+ }
+
+ return true, nil
+}
+
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
@@ -1160,7 +1194,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1182,6 +1216,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
+ e.workMu.Unlock()
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1193,7 +1228,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
e.mu.Lock()
- h.ep.state = StateSynSent
+ h.ep.setEndpointState(StateSynSent)
e.mu.Unlock()
if err := h.execute(); err != nil {
@@ -1202,12 +1237,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
// Lock released below.
epilogue()
-
return err
}
}
@@ -1215,7 +1249,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- // Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
drained := e.drainDone != nil
e.mu.Unlock()
@@ -1224,8 +1257,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
<-e.undrain
}
- e.waiterQueue.Notify(waiter.EventOut)
-
// Set up the functions that will be called when the main protocol loop
// wakes up.
funcs := []struct {
@@ -1241,17 +1272,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
f: e.handleClose,
},
{
- w: &e.newSegmentWaker,
- f: e.handleSegments,
- },
- {
w: &closeWaker,
f: func() *tcpip.Error {
// This means the socket is being closed due
- // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
e.mu.Lock()
e.transitionToStateCloseLocked()
+ e.workerCleanup = true
e.mu.Unlock()
return nil
},
@@ -1267,6 +1295,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
},
},
{
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
w: &e.keepalive.waker,
f: e.keepaliveTimerExpired,
},
@@ -1293,14 +1327,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
if n&notifyReset != 0 {
- e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.mu.Unlock()
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
}
if n&notifyClose != 0 && closeTimer == nil {
e.mu.Lock()
- if e.state == StateFinWait2 && e.closed {
+ if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
@@ -1320,11 +1356,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(); err != nil {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
return err
}
}
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1349,14 +1385,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.AddWaker(funcs[i].w, i)
}
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
// The following assertions and notifications are needed for restored
// endpoints. Fresh newly created endpoints have empty states and should
// not invoke any.
- e.segmentQueue.mu.Lock()
- if !e.segmentQueue.list.Empty() {
+ if !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
- e.segmentQueue.mu.Unlock()
e.rcvListMu.Lock()
if !e.rcvList.Empty() {
@@ -1372,27 +1415,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+ for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
+ // We need to double check here because the notification
+ // maybe stale by the time we got around to processing it.
+ // NOTE: since we now hold the workMu the processors cannot
+ // change the state of the endpoint so it' safe to proceed
+ // after this check.
+ if e.EndpointState() == StateTimeWait || e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ e.mu.Lock()
+ break
+ }
if err := funcs[v].f(); err != nil {
e.mu.Lock()
- // Ensure we release all endpoint registration and route
- // references as the connection is now in an error
- // state.
e.workerCleanup = true
e.resetConnectionLocked(err)
// Lock released below.
epilogue()
-
return nil
}
e.mu.Lock()
}
- state := e.state
+ state := e.EndpointState()
e.mu.Unlock()
var reuseTW func()
if state == StateTimeWait {
@@ -1405,13 +1453,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.mu.Lock()
+ e.workerCleanup = true
+ e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
- if e.state != StateError {
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1468,7 +1518,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
- tcpEP.enqueueSegment(s)
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
}
// We explicitly do not decRef
// the segment as it's still
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100755
index 000000000..a72f0c379
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,218 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// epQueue is a queue of endpoints.
+type epQueue struct {
+ mu sync.Mutex
+ list endpointList
+}
+
+// enqueue adds e to the queue if the endpoint is not already on the queue.
+func (q *epQueue) enqueue(e *endpoint) {
+ q.mu.Lock()
+ if e.pendingProcessing {
+ q.mu.Unlock()
+ return
+ }
+ q.list.PushBack(e)
+ e.pendingProcessing = true
+ q.mu.Unlock()
+}
+
+// dequeue removes and returns the first element from the queue if available,
+// returns nil otherwise.
+func (q *epQueue) dequeue() *endpoint {
+ q.mu.Lock()
+ if e := q.list.Front(); e != nil {
+ q.list.Remove(e)
+ e.pendingProcessing = false
+ q.mu.Unlock()
+ return e
+ }
+ q.mu.Unlock()
+ return nil
+}
+
+// empty returns true if the queue is empty, false otherwise.
+func (q *epQueue) empty() bool {
+ q.mu.Lock()
+ v := q.list.Empty()
+ q.mu.Unlock()
+ return v
+}
+
+// processor is responsible for processing packets queued to a tcp endpoint.
+type processor struct {
+ epQ epQueue
+ newEndpointWaker sleep.Waker
+ id int
+}
+
+func newProcessor(id int) *processor {
+ p := &processor{
+ id: id,
+ }
+ go p.handleSegments()
+ return p
+}
+
+func (p *processor) queueEndpoint(ep *endpoint) {
+ // Queue an endpoint for processing by the processor goroutine.
+ p.epQ.enqueue(ep)
+ p.newEndpointWaker.Assert()
+}
+
+func (p *processor) handleSegments() {
+ const newEndpointWaker = 1
+ s := sleep.Sleeper{}
+ s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ defer s.Done()
+ for {
+ s.Fetch(true)
+ for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ if ep.segmentQueue.empty() {
+ continue
+ }
+
+ // If socket has transitioned out of connected state
+ // then just let the worker handle the packet.
+ //
+ // NOTE: We read this outside of e.mu lock which means
+ // that by the time we get to handleSegments the
+ // endpoint may not be in ESTABLISHED. But this should
+ // be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a
+ // CLOSED/ERROR state then handleSegments is a noop.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+
+ if !ep.workMu.TryLock() {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+ // If the endpoint is in a connected state then we do
+ // direct delivery to ensure low latency and avoid
+ // scheduler interactions.
+ if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ ep.workMu.Unlock()
+ continue
+ }
+
+ if !ep.segmentQueue.empty() {
+ p.epQ.enqueue(ep)
+ }
+
+ ep.workMu.Unlock()
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for the processing of inbound segments. This fixed pool of processor
+// goroutines do full tcp processing. The processor is selected based on the
+// hash of the endpoint id to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []*processor
+ seed uint32
+}
+
+func newDispatcher(nProcessors int) *dispatcher {
+ processors := []*processor{}
+ for i := 0; i < nProcessors; i++ {
+ processors = append(processors, newProcessor(i))
+ }
+ return &dispatcher{
+ processors: processors,
+ seed: generateRandUint32(),
+ }
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in established state let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8)}
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return d.processors[h.Sum32()%uint32(len(d.processors))]
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index cc8b533c8..1799c6e10 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -120,6 +120,7 @@ const (
notifyMTUChanged
notifyDrain
notifyReset
+ notifyResetByPeer
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -127,6 +128,7 @@ const (
// ensures the loop terminates if the final state of the endpoint is
// say TIME_WAIT.
notifyTickleWorker
+ notifyError
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {}
type endpoint struct {
EndpointInfo
+ // endpointEntry is used to queue endpoints for processing to the
+ // a given tcp processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ pendingProcessing bool `state:"nosave"`
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -324,6 +338,7 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
+ // state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -359,7 +374,7 @@ type endpoint struct {
workerRunning bool
// workerCleanup specifies if the worker goroutine must perform cleanup
- // before exitting. This can only be set to true when workerRunning is
+ // before exiting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
workerCleanup bool
@@ -371,6 +386,8 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
+ //
+ // recentTS must be read/written atomically.
recentTS uint32
// tsOffset is a randomized offset added to the value of the
@@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() {
e.workMu.Unlock()
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
+// setRecentTimestamp atomically sets the recentTS field to the
+// provided value.
+func (e *endpoint) setRecentTimestamp(recentTS uint32) {
+ atomic.StoreUint32(&e.recentTS, recentTS)
+}
+
+// recentTimestamp atomically reads and returns the value of the recentTS field.
+func (e *endpoint) recentTimestamp() uint32 {
+ return atomic.LoadUint32(&e.recentTS)
+}
+
// keepalive is a synchronization wrapper used to appease stateify. See the
// comment in endpoint, where it is used.
//
@@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.mu.RLock()
defer e.mu.RUnlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
}
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -733,14 +791,20 @@ func (e *endpoint) Close() {
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdown()
+}
+// closeNoShutdown closes the endpoint without doing a full shutdown. This is
+// used when a connection needs to be aborted with a RST and we want to skip
+// a full 4 way TCP shutdown.
+func (e *endpoint) closeNoShutdown() {
e.mu.Lock()
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == StateListen && e.isPortReserved {
+ if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
@@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
defer close(done)
for n := range e.acceptedChan {
n.notifyProtocolGoroutine(notifyReset)
+ // close all connections that have completed but
+ // not accepted by the application.
n.Close()
}
}()
@@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
+
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
e.closePendingAcceptableConnectionsLocked()
}
+
e.workerCleanup = false
if e.isRegistered {
@@ -920,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
+ if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
e.mu.RUnlock()
@@ -944,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -980,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
- if !e.state.connected() {
- switch e.state {
+ if !e.EndpointState().connected() {
+ switch e.EndpointState() {
case StateError:
return 0, e.HardError
default:
@@ -1039,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, perr
}
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if opts.Atomic {
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
+ if e.workMu.TryLock() {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
}
- }
-
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- if e.workMu.TryLock() {
// Do the work inline.
e.handleWrite()
e.workMu.Unlock()
} else {
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+
+ }
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
@@ -1091,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; !s.connected() && s != StateClose {
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
return 0, tcpip.ControlMessages{}, e.HardError
}
@@ -1103,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
@@ -1187,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1402,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Acquire the work mutex as we may need to
// reinitialize the congestion control state.
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.cc = v
e.mu.Unlock()
switch state {
case StateEstablished:
e.workMu.Lock()
e.mu.Lock()
- if e.state == state {
+ if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
e.mu.Unlock()
@@ -1472,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
defer e.mu.RUnlock()
// The endpoint cannot be in listen state.
- if e.state == StateListen {
+ if e.EndpointState() == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1731,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return err
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// The endpoint is already connected. If caller hasn't been
// notified yet, return success.
if !e.isConnectNotified {
@@ -1743,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
nicID := addr.NIC
- switch e.state {
+ switch e.EndpointState() {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
@@ -1850,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.isRegistered = true
- e.state = StateConnecting
+ e.setEndpointState(StateConnecting)
e.route = r.Clone()
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1871,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = StateEstablished
- e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.setEndpointState(StateEstablished)
}
if run {
e.workerRunning = true
e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
return tcpip.ErrConnectStarted
@@ -1896,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
finQueued := false
switch {
- case e.state.connected():
+ case e.EndpointState().connected():
// Close for read.
if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
// Mark read side as closed.
@@ -1908,8 +2019,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
+ // Move the socket to error state immediately.
+ // This is done redundantly because in case of
+ // save/restore on a Shutdown/Close() the socket
+ // state needs to indicate the error otherwise
+ // save file will show the socket in established
+ // state even though snd/rcv are closed.
e.mu.Unlock()
+ // Try to send an active reset immediately if the
+ // work mutex is available.
+ if e.workMu.TryLock() {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ } else {
+ e.notifyProtocolGoroutine(notifyReset)
+ }
return nil
}
}
@@ -1931,11 +2057,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
-
e.sndBufMu.Unlock()
}
- case e.state == StateListen:
+ case e.EndpointState() == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
@@ -1976,7 +2101,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// When the endpoint shuts down, it sets workerCleanup to true, and from
// that point onward, acceptedChan is the responsibility of the cleanup()
// method (and should not be touched anywhere else, including here).
- if e.state == StateListen && !e.workerCleanup {
+ if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
if len(e.acceptedChan) > backlog {
@@ -1994,7 +2119,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
// The listen is called on an unbound socket, the socket is
// automatically bound to a random free port with the local
// address set to INADDR_ANY.
@@ -2004,7 +2129,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Endpoint must be bound before it can transition to listen mode.
- if e.state != StateBound {
+ if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
@@ -2015,24 +2140,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
e.isRegistered = true
- e.state = StateListen
+ e.setEndpointState(StateListen)
+
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
e.workerRunning = true
-
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
-
return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.mu.Lock()
e.waiterQueue = waiterQueue
e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+ e.mu.Unlock()
+ wakerInitDone := make(chan struct{})
+ go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save.
+ <-wakerInitDone
}
// Accept returns a new endpoint if a peer has established a connection
@@ -2042,7 +2170,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != StateListen {
+ if e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
@@ -2069,7 +2197,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrAlreadyBound
}
@@ -2131,7 +2259,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
return nil
}
@@ -2153,7 +2281,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if !e.state.connected() {
+ if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -2164,45 +2292,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
- s := newSegment(r, id, pkt)
- if !s.parse() {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- s.decRef()
- return
- }
-
- if !s.csumValid {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.ChecksumErrors.Increment()
- e.stats.ReceiveErrors.ChecksumErrors.Increment()
- s.decRef()
- return
- }
-
- e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
- e.stats.SegmentsReceived.Increment()
- if (s.flags & header.TCPFlagRst) != 0 {
- e.stack.Stats().TCP.ResetsReceived.Increment()
- }
-
- e.enqueueSegment(s)
+ // TCP HandlePacket is not required anymore as inbound packets first
+ // land at the Dispatcher which then can either delivery using the
+ // worker go routine or directly do the invoke the tcp processing inline
+ // based on the state of the endpoint.
}
-func (e *endpoint) enqueueSegment(s *segment) {
+func (e *endpoint) enqueueSegment(s *segment) bool {
// Send packet to worker goroutine.
- if e.segmentQueue.enqueue(s) {
- e.newSegmentWaker.Assert()
- } else {
+ if !e.segmentQueue.enqueue(s) {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
- s.decRef()
+ return false
}
+ return true
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
@@ -2319,8 +2424,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
- e.recentTS = tsVal
+ if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.setRecentTimestamp(tsVal)
}
}
@@ -2330,7 +2435,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
e.sendTSOk = true
- e.recentTS = synOpts.TSVal
+ e.setRecentTimestamp(synOpts.TSVal)
}
}
@@ -2419,7 +2524,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Endpoint TCP Option state.
s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTS
+ s.RecentTS = e.recentTimestamp()
s.TSOffset = e.tsOffset
s.SACKPermitted = e.sackPermitted
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
@@ -2526,9 +2631,7 @@ func (e *endpoint) initGSO() {
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 4b8d867bc..4a46f0ec5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -16,6 +16,7 @@ package tcp
import (
"fmt"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound:
// TODO(b/138137272): this enumeration duplicates
// EndpointState.connected. remove it.
@@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() {
fallthrough
case StateListen, StateConnecting:
e.drainSegmentLocked()
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
break
}
- fallthrough
case StateError, StateClose:
- for (e.state == StateError || e.state == StateClose) && e.workerRunning {
+ for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
}
if e.workerRunning {
- panic("endpoint still has worker running in closed or error state")
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
}
default:
- panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
}
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
- if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
+ if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
}
}
@@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
// saveState is invoked by stateify.
func (e *endpoint) saveState() EndpointState {
- return e.state
+ return e.EndpointState()
}
// Endpoint loading must be done in the following ordering by their state, to
@@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup
func (e *endpoint) loadState(state EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
- if state.connected() {
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if state.connected() || state == StateTimeWait {
connectedLoading.Add(1)
}
switch state {
@@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) {
case StateConnecting, StateSynSent, StateSynRecv:
connectingLoading.Add(1)
}
- e.state = state
+ // Directly update the state here rather than using e.setEndpointState
+ // as the endpoint is still being loaded and the stack reference to increment
+ // metrics is not yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- // Freeze segment queue before registering to prevent any segments
- // from being delivered while it is being restored.
e.origEndpointState = e.state
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
@@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
state := e.origEndpointState
-
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Wait()
connectingLoading.Wait()
bind()
- e.state = StateClose
+ e.setEndpointState(StateClose)
tcpip.AsyncLoading.Done()
}()
}
@@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
+
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 9a8f64aa6..958c06fa7 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,6 +21,7 @@
package tcp
import (
+ "runtime"
"strings"
"time"
@@ -104,6 +105,7 @@ type protocol struct {
moderateReceiveBuffer bool
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
+ dispatcher *dispatcher
}
// Number returns the tcp protocol number.
@@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
//
@@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 05c8488f8..958f03ac1 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -169,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateEstablished:
- r.ep.state = StateCloseWait
+ r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
if s.flagIsSet(header.TCPFlagAck) {
// FIN-ACK, transition to TIME-WAIT.
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
} else {
// Simultaneous close, expecting a final ACK.
- r.ep.state = StateClosing
+ r.ep.setEndpointState(StateClosing)
}
case StateFinWait2:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
}
r.ep.mu.Unlock()
@@ -205,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateFinWait1:
- r.ep.state = StateFinWait2
+ r.ep.setEndpointState(StateFinWait2)
// Notify protocol goroutine that we have received an
// ACK to our FIN so that it can start the FIN_WAIT2
// timer to abort connection if the other side does
// not close within 2MSL.
r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
@@ -267,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
switch state {
case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
- s.decRef()
// Just drop the segment as we have
// already received a FIN and this
// segment is after the sequence number
@@ -284,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
if state == StateFinWait1 {
@@ -314,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// the last actual data octet in a segment in
// which it occurs.
if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
}
@@ -336,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
r.ep.mu.RLock()
- state := r.ep.state
+ state := r.ep.EndpointState()
closed := r.ep.closed
r.ep.mu.RUnlock()
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index fdff7ed81..b74b61e7d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -705,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
- // Transition to FIN-WAIT1 state since we're initiating an active close.
- s.ep.mu.Lock()
- switch s.ep.state {
+ // Update the state to reflect that we have now
+ // queued a FIN.
+ switch s.ep.EndpointState() {
case StateCloseWait:
- // We've already received a FIN and are now sending our own. The
- // sender is now awaiting a final ACK for this FIN.
- s.ep.state = StateLastAck
+ s.ep.setEndpointState(StateLastAck)
default:
- s.ep.state = StateFinWait1
+ s.ep.setEndpointState(StateFinWait1)
}
- s.ep.mu.Unlock()
+
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
new file mode 100755
index 000000000..42e190060
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
@@ -0,0 +1,173 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type endpointElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type endpointList struct {
+ head *endpoint
+ tail *endpoint
+}
+
+// Reset resets list l to the empty state.
+func (l *endpointList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *endpointList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *endpointList) Front() *endpoint {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *endpointList) Back() *endpoint {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *endpointList) PushFront(e *endpoint) {
+ endpointElementMapper{}.linkerFor(e).SetNext(l.head)
+ endpointElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ endpointElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *endpointList) PushBack(e *endpoint) {
+ endpointElementMapper{}.linkerFor(e).SetNext(nil)
+ endpointElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *endpointList) PushBackList(m *endpointList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *endpointList) InsertAfter(b, e *endpoint) {
+ a := endpointElementMapper{}.linkerFor(b).Next()
+ endpointElementMapper{}.linkerFor(e).SetNext(a)
+ endpointElementMapper{}.linkerFor(e).SetPrev(b)
+ endpointElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ endpointElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *endpointList) InsertBefore(a, e *endpoint) {
+ b := endpointElementMapper{}.linkerFor(a).Prev()
+ endpointElementMapper{}.linkerFor(e).SetNext(a)
+ endpointElementMapper{}.linkerFor(e).SetPrev(b)
+ endpointElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ endpointElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *endpointList) Remove(e *endpoint) {
+ prev := endpointElementMapper{}.linkerFor(e).Prev()
+ next := endpointElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ endpointElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ endpointElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type endpointEntry struct {
+ next *endpoint
+ prev *endpoint
+}
+
+// Next returns the entry that follows e in the list.
+func (e *endpointEntry) Next() *endpoint {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *endpointEntry) Prev() *endpoint {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *endpointEntry) SetNext(elem *endpoint) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *endpointEntry) SetPrev(elem *endpoint) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index d0d188b3d..700e7958a 100755
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -454,6 +454,32 @@ func (x *unixTime) load(m state.Map) {
m.Load("nano", &x.nano)
}
+func (x *endpointList) beforeSave() {}
+func (x *endpointList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *endpointList) afterLoad() {}
+func (x *endpointList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *endpointEntry) beforeSave() {}
+func (x *endpointEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *endpointEntry) afterLoad() {}
+func (x *endpointEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
func (x *segmentList) beforeSave() {}
func (x *segmentList) save(m state.Map) {
x.beforeSave()
@@ -496,6 +522,8 @@ func init() {
state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load})
state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load})
state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load})
+ state.Register("tcp.endpointList", (*endpointList)(nil), state.Fns{Save: (*endpointList).save, Load: (*endpointList).load})
+ state.Register("tcp.endpointEntry", (*endpointEntry)(nil), state.Fns{Save: (*endpointEntry).save, Load: (*endpointEntry).load})
state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load})
state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load})
}