diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 90 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 37 |
4 files changed, 103 insertions, 96 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index a485064a1..1c54dc180 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -385,8 +385,8 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state (acceptedChan is nil), -// the new endpoint is closed instead. +// listener has transitioned out of the listen state (accepted is the zero +// value), the new endpoint is reset instead. func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) @@ -395,23 +395,23 @@ func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.acceptMu.Lock() for { - if e.acceptedChan == nil { - e.acceptMu.Unlock() + if e.accepted == (accepted{}) { n.notifyProtocolGoroutine(notifyReset) - return + break } - select { - case e.acceptedChan <- n: - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) - } - e.acceptMu.Unlock() - e.waiterQueue.Notify(waiter.ReadableEvents) - return - default: + if e.accepted.endpoints.Len() == e.accepted.cap { e.acceptCond.Wait() + continue } + + e.accepted.endpoints.PushBack(n) + if !withSynCookie { + atomic.AddInt32(&e.synRcvdCount, -1) + } + e.waiterQueue.Notify(waiter.ReadableEvents) + break } + e.acceptMu.Unlock() } // propagateInheritableOptionsLocked propagates any options set on the listening @@ -499,7 +499,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header func (e *endpoint) synRcvdBacklogFull() bool { e.acceptMu.Lock() - acceptedChanCap := cap(e.acceptedChan) + backlog := e.accepted.cap e.acceptMu.Unlock() // The allocated accepted channel size would always be one greater than the // listen backlog. But, the SYNRCVD connections count is always checked @@ -509,12 +509,12 @@ func (e *endpoint) synRcvdBacklogFull() bool { // We maintain an equality check here as the synRcvdCount is incremented // and compared only from a single listener context and the capacity of // the accepted channel can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedChanCap-1 + return int(atomic.LoadInt32(&e.synRcvdCount)) == backlog-1 } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan) == cap(e.acceptedChan) + full := e.accepted.endpoints.Len() == e.accepted.cap e.acceptMu.Unlock() return full } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 5001d222e..9fbaf6f4b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "container/list" "encoding/binary" "fmt" "io" @@ -322,6 +323,15 @@ type EndpointInfo struct { // marker interface. func (*EndpointInfo) IsEndpointInfo() {} +// +stateify savable +type accepted struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` + cap int +} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -337,7 +347,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects acceptedChan. +// e.acceptMu -> protects accepted. // e.rcvListMu -> Protects the rcvList and associated fields. // e.sndBufMu -> Protects the sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -607,33 +617,26 @@ type endpoint struct { // listener. deferAccept time.Duration - // pendingAccepted is a synchronization primitive used to track number - // of connections that are queued up to be delivered to the accepted - // channel. We use this to ensure that all goroutines blocked on writing - // to the acceptedChan below terminate before we close acceptedChan. + // pendingAccepted tracks connections queued to be accepted. It is used to + // ensure such queued connections are terminated before the accepted queue is + // marked closed (by setting its capacity to zero). pendingAccepted sync.WaitGroup `state:"nosave"` - // acceptMu protects acceptedChan. + // acceptMu protects accepted. acceptMu sync.Mutex `state:"nosave"` // acceptCond is a condition variable that can be used to block on when - // acceptedChan is full and an endpoint is ready to be delivered. - // - // This condition variable is required because just blocking on sending - // to acceptedChan does not work in cases where endpoint.Listen is - // called twice with different backlog values. In such cases the channel - // is closed and a new one created. Any pending goroutines blocking on - // the write to the channel will panic. + // accepted is full and an endpoint is ready to be delivered. // // We use this condition variable to block/unblock goroutines which // tried to deliver an endpoint but couldn't because accept backlog was // full ( See: endpoint.deliverAccepted ). acceptCond *sync.Cond `state:"nosave"` - // acceptedChan is used by a listening endpoint protocol goroutine to + // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - acceptedChan chan *endpoint `state:".([]*endpoint)"` + accepted accepted // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -962,7 +965,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Check if there's anything in the accepted channel. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if len(e.acceptedChan) > 0 { + if e.accepted.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -1145,22 +1148,22 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptMu.Unlock() + acceptedCopy := e.accepted + e.accepted = accepted{} + e.acceptMu.Unlock() + + if acceptedCopy == (accepted{}) { return } - close(e.acceptedChan) - ch := e.acceptedChan - e.acceptedChan = nil + e.acceptCond.Broadcast() - e.acceptMu.Unlock() // Reset all connections that are waiting to be accepted. - for n := range ch { - n.notifyProtocolGoroutine(notifyReset) + for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } // Wait for reset of all endpoints that are still waiting to be delivered to - // the now closed acceptedChan. + // the now closed accepted. e.pendingAccepted.Wait() } @@ -2495,28 +2498,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.acceptedChan == nil { + if e.accepted == (accepted{}) { // listen is called after shutdown. - e.acceptedChan = make(chan *endpoint, backlog) + e.accepted.cap = backlog e.shutdownFlags = 0 e.rcvListMu.Lock() e.rcvClosed = false e.rcvListMu.Unlock() } else { - // Adjust the size of the channel iff we can fix + // Adjust the size of the backlog iff we can fit // existing pending connections into the new one. - if len(e.acceptedChan) > backlog { + if e.accepted.endpoints.Len() > backlog { return &tcpip.ErrInvalidEndpointState{} } - if cap(e.acceptedChan) == backlog { - return nil - } - origChan := e.acceptedChan - e.acceptedChan = make(chan *endpoint, backlog) - close(origChan) - for ep := range origChan { - e.acceptedChan <- ep - } + e.accepted.cap = backlog } // Notify any blocked goroutines that they can attempt to @@ -2549,12 +2544,12 @@ func (e *endpoint) listen(backlog int) tcpip.Error { e.isRegistered = true e.setEndpointState(StateListen) - // The channel may be non-nil when we're restoring the endpoint, and it + // The queue may be non-zero when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptedChan = make(chan *endpoint, backlog) + if e.accepted == (accepted{}) { + e.accepted.cap = backlog } e.acceptMu.Unlock() @@ -2591,15 +2586,16 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. } // Get the new accepted endpoint. - e.acceptMu.Lock() - defer e.acceptMu.Unlock() var n *endpoint - select { - case n = <-e.acceptedChan: - e.acceptCond.Signal() - default: + e.acceptMu.Lock() + if element := e.accepted.endpoints.Front(); element != nil { + n = e.accepted.endpoints.Remove(element).(*endpoint) + } + e.acceptMu.Unlock() + if n == nil { return nil, nil, &tcpip.ErrWouldBlock{} } + e.acceptCond.Signal() if peerAddr != nil { *peerAddr = n.getRemoteAddress() } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index a53d76917..f51b3ad90 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() { } } -// saveAcceptedChan is invoked by stateify. -func (e *endpoint) saveAcceptedChan() []*endpoint { - if e.acceptedChan == nil { - return nil - } - acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan)) - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case ep := <-e.acceptedChan: - acceptedEndpoints[i] = ep - default: - panic("endpoint acceptedChan buffer got consumed by background context") - } - } - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case e.acceptedChan <- acceptedEndpoints[i]: - default: - panic("endpoint acceptedChan buffer got populated by background context") - } +// saveEndpoints is invoked by stateify. +func (a *accepted) saveEndpoints() []*endpoint { + acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) + for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { + acceptedEndpoints[i] = e.Value.(*endpoint) } return acceptedEndpoints } -// loadAcceptedChan is invoked by stateify. -func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { - if cap(acceptedEndpoints) > 0 { - e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints)) - for _, ep := range acceptedEndpoints { - e.acceptedChan <- ep - } +// loadEndpoints is invoked by stateify. +func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { + for _, ep := range acceptedEndpoints { + a.endpoints.PushBack(ep) } } @@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := cap(e.acceptedChan) + backlog := e.accepted.cap if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index a92bec6c5..632287cd3 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -158,6 +158,35 @@ func (e *EndpointInfo) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &e.TransportEndpointInfo) } +func (a *accepted) StateTypeName() string { + return "pkg/tcpip/transport/tcp.accepted" +} + +func (a *accepted) StateFields() []string { + return []string{ + "endpoints", + "cap", + } +} + +func (a *accepted) beforeSave() {} + +// +checklocksignore +func (a *accepted) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + var endpointsValue []*endpoint = a.saveEndpoints() + stateSinkObject.SaveValue(0, endpointsValue) + stateSinkObject.Save(1, &a.cap) +} + +func (a *accepted) afterLoad() {} + +// +checklocksignore +func (a *accepted) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(1, &a.cap) + stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) }) +} + func (e *endpoint) StateTypeName() string { return "pkg/tcpip/transport/tcp.endpoint" } @@ -213,7 +242,7 @@ func (e *endpoint) StateFields() []string { "keepalive", "userTimeout", "deferAccept", - "acceptedChan", + "accepted", "rcv", "snd", "connectingAddress", @@ -236,8 +265,6 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.SaveValue(13, stateValue) var recentTSTimeValue unixTime = e.saveRecentTSTime() stateSinkObject.SaveValue(26, recentTSTimeValue) - var acceptedChanValue []*endpoint = e.saveAcceptedChan() - stateSinkObject.SaveValue(49, acceptedChanValue) var lastOutOfWindowAckTimeValue unixTime = e.saveLastOutOfWindowAckTime() stateSinkObject.SaveValue(61, lastOutOfWindowAckTimeValue) stateSinkObject.Save(0, &e.EndpointInfo) @@ -287,6 +314,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(46, &e.keepalive) stateSinkObject.Save(47, &e.userTimeout) stateSinkObject.Save(48, &e.deferAccept) + stateSinkObject.Save(49, &e.accepted) stateSinkObject.Save(50, &e.rcv) stateSinkObject.Save(51, &e.snd) stateSinkObject.Save(52, &e.connectingAddress) @@ -349,6 +377,7 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(46, &e.keepalive) stateSourceObject.Load(47, &e.userTimeout) stateSourceObject.Load(48, &e.deferAccept) + stateSourceObject.Load(49, &e.accepted) stateSourceObject.LoadWait(50, &e.rcv) stateSourceObject.LoadWait(51, &e.snd) stateSourceObject.Load(52, &e.connectingAddress) @@ -362,7 +391,6 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(60, &e.ops) stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) - stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) stateSourceObject.LoadValue(61, new(unixTime), func(y interface{}) { e.loadLastOutOfWindowAckTime(y.(unixTime)) }) stateSourceObject.AfterLoad(e.afterLoad) } @@ -1093,6 +1121,7 @@ func init() { state.Register((*SACKInfo)(nil)) state.Register((*rcvBufAutoTuneParams)(nil)) state.Register((*EndpointInfo)(nil)) + state.Register((*accepted)(nil)) state.Register((*endpoint)(nil)) state.Register((*keepalive)(nil)) state.Register((*rackControl)(nil)) |