summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- pkg/tcpip/transport/tcp/accept.go 34
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go 90
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint_state.go 38
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_state_autogen.go 37
4 files changed, 103 insertions, 96 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index a485064a1..1c54dc180 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -385,8 +385,8 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state (acceptedChan is nil),
-// the new endpoint is closed instead.
+// listener has transitioned out of the listen state (accepted is the zero
+// value), the new endpoint is reset instead.
func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
@@ -395,23 +395,23 @@ func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.acceptMu.Lock()
for {
- if e.acceptedChan == nil {
- e.acceptMu.Unlock()
+ if e.accepted == (accepted{}) {
n.notifyProtocolGoroutine(notifyReset)
- return
+ break
}
- select {
- case e.acceptedChan <- n:
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- e.acceptMu.Unlock()
- e.waiterQueue.Notify(waiter.ReadableEvents)
- return
- default:
+ if e.accepted.endpoints.Len() == e.accepted.cap {
e.acceptCond.Wait()
+ continue
}
+
+ e.accepted.endpoints.PushBack(n)
+ if !withSynCookie {
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ }
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ break
}
+ e.acceptMu.Unlock()
}
// propagateInheritableOptionsLocked propagates any options set on the listening
@@ -499,7 +499,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
- acceptedChanCap := cap(e.acceptedChan)
+ backlog := e.accepted.cap
e.acceptMu.Unlock()
// The allocated accepted channel size would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
@@ -509,12 +509,12 @@ func (e *endpoint) synRcvdBacklogFull() bool {
// We maintain an equality check here as the synRcvdCount is incremented
// and compared only from a single listener context and the capacity of
// the accepted channel can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedChanCap-1
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == backlog-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan) == cap(e.acceptedChan)
+ full := e.accepted.endpoints.Len() == e.accepted.cap
e.acceptMu.Unlock()
return full
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 5001d222e..9fbaf6f4b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "container/list"
"encoding/binary"
"fmt"
"io"
@@ -322,6 +323,15 @@ type EndpointInfo struct {
// marker interface.
func (*EndpointInfo) IsEndpointInfo() {}
+// +stateify savable
+type accepted struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
+ cap int
+}
+
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -337,7 +347,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects acceptedChan.
+// e.acceptMu -> protects accepted.
// e.rcvListMu -> Protects the rcvList and associated fields.
// e.sndBufMu -> Protects the sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -607,33 +617,26 @@ type endpoint struct {
// listener.
deferAccept time.Duration
- // pendingAccepted is a synchronization primitive used to track number
- // of connections that are queued up to be delivered to the accepted
- // channel. We use this to ensure that all goroutines blocked on writing
- // to the acceptedChan below terminate before we close acceptedChan.
+ // pendingAccepted tracks connections queued to be accepted. It is used to
+ // ensure such queued connections are terminated before the accepted queue is
+ // marked closed (by resetting it to its zero value).
pendingAccepted sync.WaitGroup `state:"nosave"`
- // acceptMu protects acceptedChan.
+ // acceptMu protects accepted.
acceptMu sync.Mutex `state:"nosave"`
// acceptCond is a condition variable that can be used to block on when
- // acceptedChan is full and an endpoint is ready to be delivered.
- //
- // This condition variable is required because just blocking on sending
- // to acceptedChan does not work in cases where endpoint.Listen is
- // called twice with different backlog values. In such cases the channel
- // is closed and a new one created. Any pending goroutines blocking on
- // the write to the channel will panic.
+ // accepted is full and an endpoint is ready to be delivered.
//
// We use this condition variable to block/unblock goroutines which
// tried to deliver an endpoint but couldn't because accept backlog was
// full ( See: endpoint.deliverAccepted ).
acceptCond *sync.Cond `state:"nosave"`
- // acceptedChan is used by a listening endpoint protocol goroutine to
+ // accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- acceptedChan chan *endpoint `state:".([]*endpoint)"`
+ accepted accepted
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -962,7 +965,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check if there's anything in the accepted channel.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if len(e.acceptedChan) > 0 {
+ if e.accepted.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -1145,22 +1148,22 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- if e.acceptedChan == nil {
- e.acceptMu.Unlock()
+ acceptedCopy := e.accepted
+ e.accepted = accepted{}
+ e.acceptMu.Unlock()
+
+ if acceptedCopy == (accepted{}) {
return
}
- close(e.acceptedChan)
- ch := e.acceptedChan
- e.acceptedChan = nil
+
e.acceptCond.Broadcast()
- e.acceptMu.Unlock()
// Reset all connections that are waiting to be accepted.
- for n := range ch {
- n.notifyProtocolGoroutine(notifyReset)
+ for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
// Wait for reset of all endpoints that are still waiting to be delivered to
- // the now closed acceptedChan.
+ // the now closed accepted queue.
e.pendingAccepted.Wait()
}
@@ -2495,28 +2498,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.acceptedChan == nil {
+ if e.accepted == (accepted{}) {
// listen is called after shutdown.
- e.acceptedChan = make(chan *endpoint, backlog)
+ e.accepted.cap = backlog
e.shutdownFlags = 0
e.rcvListMu.Lock()
e.rcvClosed = false
e.rcvListMu.Unlock()
} else {
- // Adjust the size of the channel iff we can fix
+ // Adjust the size of the backlog iff we can fit
// existing pending connections into the new one.
- if len(e.acceptedChan) > backlog {
+ if e.accepted.endpoints.Len() > backlog {
return &tcpip.ErrInvalidEndpointState{}
}
- if cap(e.acceptedChan) == backlog {
- return nil
- }
- origChan := e.acceptedChan
- e.acceptedChan = make(chan *endpoint, backlog)
- close(origChan)
- for ep := range origChan {
- e.acceptedChan <- ep
- }
+ e.accepted.cap = backlog
}
// Notify any blocked goroutines that they can attempt to
@@ -2549,12 +2544,12 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
e.isRegistered = true
e.setEndpointState(StateListen)
- // The channel may be non-nil when we're restoring the endpoint, and it
+ // The queue may be non-zero when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.acceptedChan == nil {
- e.acceptedChan = make(chan *endpoint, backlog)
+ if e.accepted == (accepted{}) {
+ e.accepted.cap = backlog
}
e.acceptMu.Unlock()
@@ -2591,15 +2586,16 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
}
// Get the new accepted endpoint.
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
var n *endpoint
- select {
- case n = <-e.acceptedChan:
- e.acceptCond.Signal()
- default:
+ e.acceptMu.Lock()
+ if element := e.accepted.endpoints.Front(); element != nil {
+ n = e.accepted.endpoints.Remove(element).(*endpoint)
+ }
+ e.acceptMu.Unlock()
+ if n == nil {
return nil, nil, &tcpip.ErrWouldBlock{}
}
+ e.acceptCond.Signal()
if peerAddr != nil {
*peerAddr = n.getRemoteAddress()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index a53d76917..f51b3ad90 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() {
}
}
-// saveAcceptedChan is invoked by stateify.
-func (e *endpoint) saveAcceptedChan() []*endpoint {
- if e.acceptedChan == nil {
- return nil
- }
- acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
- for i := 0; i < len(acceptedEndpoints); i++ {
- select {
- case ep := <-e.acceptedChan:
- acceptedEndpoints[i] = ep
- default:
- panic("endpoint acceptedChan buffer got consumed by background context")
- }
- }
- for i := 0; i < len(acceptedEndpoints); i++ {
- select {
- case e.acceptedChan <- acceptedEndpoints[i]:
- default:
- panic("endpoint acceptedChan buffer got populated by background context")
- }
+// saveEndpoints is invoked by stateify.
+func (a *accepted) saveEndpoints() []*endpoint {
+ acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
+ for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
+ acceptedEndpoints[i] = e.Value.(*endpoint)
}
return acceptedEndpoints
}
-// loadAcceptedChan is invoked by stateify.
-func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
- if cap(acceptedEndpoints) > 0 {
- e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
- for _, ep := range acceptedEndpoints {
- e.acceptedChan <- ep
- }
+// loadEndpoints is invoked by stateify.
+func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+ for _, ep := range acceptedEndpoints {
+ a.endpoints.PushBack(ep)
}
}
@@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := cap(e.acceptedChan)
+ backlog := e.accepted.cap
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index a92bec6c5..632287cd3 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -158,6 +158,35 @@ func (e *EndpointInfo) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.TransportEndpointInfo)
}
+func (a *accepted) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.accepted"
+}
+
+func (a *accepted) StateFields() []string {
+ return []string{
+ "endpoints",
+ "cap",
+ }
+}
+
+func (a *accepted) beforeSave() {}
+
+// +checklocksignore
+func (a *accepted) StateSave(stateSinkObject state.Sink) {
+ a.beforeSave()
+ var endpointsValue []*endpoint = a.saveEndpoints()
+ stateSinkObject.SaveValue(0, endpointsValue)
+ stateSinkObject.Save(1, &a.cap)
+}
+
+func (a *accepted) afterLoad() {}
+
+// +checklocksignore
+func (a *accepted) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(1, &a.cap)
+ stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) })
+}
+
func (e *endpoint) StateTypeName() string {
return "pkg/tcpip/transport/tcp.endpoint"
}
@@ -213,7 +242,7 @@ func (e *endpoint) StateFields() []string {
"keepalive",
"userTimeout",
"deferAccept",
- "acceptedChan",
+ "accepted",
"rcv",
"snd",
"connectingAddress",
@@ -236,8 +265,6 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.SaveValue(13, stateValue)
var recentTSTimeValue unixTime = e.saveRecentTSTime()
stateSinkObject.SaveValue(26, recentTSTimeValue)
- var acceptedChanValue []*endpoint = e.saveAcceptedChan()
- stateSinkObject.SaveValue(49, acceptedChanValue)
var lastOutOfWindowAckTimeValue unixTime = e.saveLastOutOfWindowAckTime()
stateSinkObject.SaveValue(61, lastOutOfWindowAckTimeValue)
stateSinkObject.Save(0, &e.EndpointInfo)
@@ -287,6 +314,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(46, &e.keepalive)
stateSinkObject.Save(47, &e.userTimeout)
stateSinkObject.Save(48, &e.deferAccept)
+ stateSinkObject.Save(49, &e.accepted)
stateSinkObject.Save(50, &e.rcv)
stateSinkObject.Save(51, &e.snd)
stateSinkObject.Save(52, &e.connectingAddress)
@@ -349,6 +377,7 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(46, &e.keepalive)
stateSourceObject.Load(47, &e.userTimeout)
stateSourceObject.Load(48, &e.deferAccept)
+ stateSourceObject.Load(49, &e.accepted)
stateSourceObject.LoadWait(50, &e.rcv)
stateSourceObject.LoadWait(51, &e.snd)
stateSourceObject.Load(52, &e.connectingAddress)
@@ -362,7 +391,6 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(60, &e.ops)
stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) })
- stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) })
stateSourceObject.LoadValue(61, new(unixTime), func(y interface{}) { e.loadLastOutOfWindowAckTime(y.(unixTime)) })
stateSourceObject.AfterLoad(e.afterLoad)
}
@@ -1093,6 +1121,7 @@ func init() {
state.Register((*SACKInfo)(nil))
state.Register((*rcvBufAutoTuneParams)(nil))
state.Register((*EndpointInfo)(nil))
+ state.Register((*accepted)(nil))
state.Register((*endpoint)(nil))
state.Register((*keepalive)(nil))
state.Register((*rackControl)(nil))