summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/registration.go14
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go5
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go3
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go3
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go16
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go3
7 files changed, 62 insertions, 2 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0869fb084..0360187b8 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -67,6 +67,20 @@ type TransportEndpoint interface {
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it. This cleanup may happen asynchronously. Wait can
+ // be used to block on this asynchronous cleanup.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // An endpoint can be requested to stop its worker goroutines by calling
+ // its Close method.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// RawTransportEndpoint is the interface that needs to be implemented by raw
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 97a1aec4b..9aff90a3d 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -240,6 +240,26 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, v
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
+// Close implements stack.TransportEndpoint.Close.
+func (ep *multiPortEndpoint) Close() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Close()
+ }
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (ep *multiPortEndpoint) Wait() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Wait()
+ }
+}
+
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 86c62be25..db951c9ce 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -225,8 +225,9 @@ func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
return iptables.IPTables{}, nil
}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
-}
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+
+func (f *fakeTransportEndpoint) Wait() {}
type fakeTransportGoodOption bool
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 043467519..d0dd383fd 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -798,3 +798,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 308f10d24..951d317ed 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -641,3 +641,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 8234a8b53..ce8307cee 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2399,6 +2399,22 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements stack.TransportEndpoint.Wait.
+func (e *endpoint) Wait() {
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
+ defer e.waiterQueue.EventUnregister(&waitEntry)
+ for {
+ e.mu.Lock()
+ running := e.workerRunning
+ e.mu.Unlock()
+ if !running {
+ break
+ }
+ <-notifyCh
+ }
+}
+
func mssForRoute(r *stack.Route) uint16 {
// TODO(b/143359391): Respect TCP Min and Max size.
return uint16(r.MTU() - header.TCPMinimumSize)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 91c8487f3..cda302bb7 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1234,6 +1234,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements tcpip.Endpoint.Wait.
+func (*endpoint) Wait() {}
+
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}