From a2c51efe3669f0380042b2375eae79e403d3680c Mon Sep 17 00:00:00 2001
From: Ian Gudger <igudger@google.com>
Date: Tue, 29 Oct 2019 16:13:43 -0700
Subject: Add endpoint tracking to the stack.

In the future this will replace DanglingEndpoints. DanglingEndpoints must be
kept for now due to issues with save/restore.

This is arguably a cleaner design and allows the stack to know which transport
endpoints might still be using its link endpoints.

Updates #837

PiperOrigin-RevId: 277386633
---
 pkg/tcpip/stack/stack.go                  | 59 ++++++++++++++++++++++++++--
 pkg/tcpip/stack/transport_demuxer.go      | 65 +++++++++++++++++++++++--------
 pkg/tcpip/stack/transport_test.go         |  3 +-
 pkg/tcpip/transport/tcp/endpoint.go       |  5 ++-
 pkg/tcpip/transport/tcp/endpoint_state.go |  7 +++-
 5 files changed, 114 insertions(+), 25 deletions(-)

(limited to 'pkg/tcpip')

diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 242d2150c..360c54b2d 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -361,9 +361,10 @@ type Stack struct {
 
 	linkAddrCache *linkAddrCache
 
-	mu         sync.RWMutex
-	nics       map[tcpip.NICID]*NIC
-	forwarding bool
+	mu               sync.RWMutex
+	nics             map[tcpip.NICID]*NIC
+	forwarding       bool
+	cleanupEndpoints map[TransportEndpoint]struct{}
 
 	// route is the route table passed in by the user via SetRouteTable(),
 	// it is used by FindRoute() to build a route for a specific
@@ -513,6 +514,7 @@ func New(opts Options) *Stack {
 		networkProtocols:     make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
 		linkAddrResolvers:    make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
 		nics:                 make(map[tcpip.NICID]*NIC),
+		cleanupEndpoints:     make(map[TransportEndpoint]struct{}),
 		linkAddrCache:        newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
 		PortManager:          ports.NewPortManager(),
 		clock:                clock,
@@ -1136,6 +1138,25 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
 	s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
 }
 
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.cleanupEndpoints[ep] = struct{}{}
+
+	s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+	s.mu.Lock()
+	delete(s.cleanupEndpoints, ep)
+	s.mu.Unlock()
+}
+
 // RegisterRawTransportEndpoint registers the given endpoint with the stack
 // transport dispatcher. Received packets that match the provided transport
 // protocol will be delivered to the given endpoint.
@@ -1157,6 +1178,38 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
 	s.mu.Unlock()
 }
 
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	var es []TransportEndpoint
+	for _, e := range s.demux.protocol {
+		es = append(es, e.transportEndpoints()...)
+	}
+	return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+	s.mu.Lock()
+	es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+	for e := range s.cleanupEndpoints {
+		es = append(es, e)
+	}
+	s.mu.Unlock()
+	return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+	s.mu.Lock()
+	for _, e := range es {
+		s.cleanupEndpoints[e] = struct{}{}
+	}
+	s.mu.Unlock()
+}
+
 // Resume restarts the stack after a restore. This must be called after the
 // entire system has been restored.
 func (s *Stack) Resume() {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9aff90a3d..f633632f0 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -41,6 +41,31 @@ type transportEndpoints struct {
 	rawEndpoints []RawTransportEndpoint
 }
 
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+	eps.mu.Lock()
+	defer eps.mu.Unlock()
+	epsByNic, ok := eps.endpoints[id]
+	if !ok {
+		return
+	}
+	if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+		return
+	}
+	delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+	eps.mu.RLock()
+	defer eps.mu.RUnlock()
+	es := make([]TransportEndpoint, 0, len(eps.endpoints))
+	for _, e := range eps.endpoints {
+		es = append(es, e.transportEndpoints()...)
+	}
+	return es
+}
+
 type endpointsByNic struct {
 	mu        sync.RWMutex
 	endpoints map[tcpip.NICID]*multiPortEndpoint
@@ -48,6 +73,16 @@ type endpointsByNic struct {
 	seed uint32
 }
 
+func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
+	epsByNic.mu.RLock()
+	defer epsByNic.mu.RUnlock()
+	var eps []TransportEndpoint
+	for _, ep := range epsByNic.endpoints {
+		eps = append(eps, ep.transportEndpoints()...)
+	}
+	return eps
+}
+
 // HandlePacket is called by the stack when new packets arrive to this transport
 // endpoint.
 func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
@@ -127,21 +162,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
 	return len(epsByNic.endpoints) == 0
 }
 
-// unregisterEndpoint unregisters the endpoint with the given id such that it
-// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
-	eps.mu.Lock()
-	defer eps.mu.Unlock()
-	epsByNic, ok := eps.endpoints[id]
-	if !ok {
-		return
-	}
-	if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
-		return
-	}
-	delete(eps.endpoints, id)
-}
-
 // transportDemuxer demultiplexes packets targeted at a transport endpoint
 // (i.e., after they've been parsed by the network layer). It does two levels
 // of demultiplexing: first based on the network and transport protocols, then
@@ -183,14 +203,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
 // multiPortEndpoint is a container for TransportEndpoints which are bound to
 // the same pair of address and port. endpointsArr always has at least one
 // element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but not not
+// use the restored copy.
+//
+// +stateify savable
 type multiPortEndpoint struct {
-	mu           sync.RWMutex
+	mu           sync.RWMutex `state:"nosave"`
 	endpointsArr []TransportEndpoint
 	endpointsMap map[TransportEndpoint]int
 	// reuse indicates if more than one endpoint is allowed.
 	reuse bool
 }
 
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+	ep.mu.RLock()
+	eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+	ep.mu.RUnlock()
+	return eps
+}
+
 // reciprocalScale scales a value into range [0, n).
 //
 // This is similar to val % n, but faster.
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index db951c9ce..ae6fda3a9 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -218,8 +218,7 @@ func (f *fakeTransportEndpoint) State() uint32 {
 	return 0
 }
 
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
-}
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
 
 func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
 	return iptables.IPTables{}, nil
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ce8307cee..8a3ca0f1b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -686,7 +686,7 @@ func (e *endpoint) Close() {
 	// in Listen() when trying to register.
 	if e.state == StateListen && e.isPortReserved {
 		if e.isRegistered {
-			e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+			e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
 			e.isRegistered = false
 		}
 
@@ -747,7 +747,7 @@ func (e *endpoint) cleanupLocked() {
 	e.workerCleanup = false
 
 	if e.isRegistered {
-		e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+		e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
 		e.isRegistered = false
 	}
 
@@ -757,6 +757,7 @@ func (e *endpoint) cleanupLocked() {
 	}
 
 	e.route.Release()
+	e.stack.CompleteTransportEndpointCleanup(e)
 	tcpip.DeleteDanglingEndpoint(e)
 }
 
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index eae17237e..19f003b6b 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -193,8 +193,10 @@ func (e *endpoint) Resume(s *stack.Stack) {
 		if len(e.BindAddr) == 0 {
 			e.BindAddr = e.ID.LocalAddress
 		}
-		if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
-			panic("endpoint binding failed: " + err.String())
+		addr := e.BindAddr
+		port := e.ID.LocalPort
+		if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
+			panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
 		}
 	}
 
@@ -265,6 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
 		}
 		fallthrough
 	case StateError:
+		e.stack.CompleteTransportEndpointCleanup(e)
 		tcpip.DeleteDanglingEndpoint(e)
 	}
 }
-- 
cgit v1.2.3