From 7d80e85835fbe47b2395eedf287cf902ed78599a Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 29 Oct 2019 11:19:04 -0700 Subject: Allow waiting for Endpoint worker goroutines to finish. Updates #837 PiperOrigin-RevId: 277325162 --- pkg/tcpip/stack/registration.go | 14 ++++++++++++++ pkg/tcpip/stack/transport_demuxer.go | 20 ++++++++++++++++++++ pkg/tcpip/stack/transport_test.go | 5 +++-- pkg/tcpip/transport/icmp/endpoint.go | 3 +++ pkg/tcpip/transport/raw/endpoint.go | 3 +++ pkg/tcpip/transport/tcp/endpoint.go | 16 ++++++++++++++++ pkg/tcpip/transport/udp/endpoint.go | 3 +++ 7 files changed, 62 insertions(+), 2 deletions(-) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0869fb084..0360187b8 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,6 +67,20 @@ type TransportEndpoint interface { // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 97a1aec4b..9aff90a3d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -240,6 +240,26 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, v ep.mu.RUnlock() // Don't use defer for performance reasons. } +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } +} + +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } +} + // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..db951c9ce 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -225,8 +225,9 @@ func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 043467519..d0dd383fd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -798,3 +798,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 308f10d24..951d317ed 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -641,3 +641,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8234a8b53..ce8307cee 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2399,6 +2399,22 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91c8487f3..cda302bb7 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1234,6 +1234,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } -- cgit v1.2.3