summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/stack/stack.go59
-rwxr-xr-xpkg/tcpip/stack/stack_state_autogen.go16
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go65
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go7
5 files changed, 129 insertions, 23 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 242d2150c..360c54b2d 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -361,9 +361,10 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -513,6 +514,7 @@ func New(opts Options) *Stack {
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
@@ -1136,6 +1138,25 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.cleanupEndpoints[ep] = struct{}{}
+
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+ s.mu.Lock()
+ delete(s.cleanupEndpoints, ep)
+ s.mu.Unlock()
+}
+
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
@@ -1157,6 +1178,38 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
s.mu.Unlock()
}
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var es []TransportEndpoint
+ for _, e := range s.demux.protocol {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+ for e := range s.cleanupEndpoints {
+ es = append(es, e)
+ }
+ s.mu.Unlock()
+ return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+ s.mu.Lock()
+ for _, e := range es {
+ s.cleanupEndpoints[e] = struct{}{}
+ }
+ s.mu.Unlock()
+}
+
// Resume restarts the stack after a restore. This must be called after the
// entire system has been restored.
func (s *Stack) Resume() {
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
index 2459c3d3a..2551126f2 100755
--- a/pkg/tcpip/stack/stack_state_autogen.go
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -99,6 +99,21 @@ func (x *TransportEndpointInfo) load(m state.Map) {
m.Load("RegisterNICID", &x.RegisterNICID)
}
+func (x *multiPortEndpoint) beforeSave() {}
+func (x *multiPortEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("endpointsArr", &x.endpointsArr)
+ m.Save("endpointsMap", &x.endpointsMap)
+ m.Save("reuse", &x.reuse)
+}
+
+func (x *multiPortEndpoint) afterLoad() {}
+func (x *multiPortEndpoint) load(m state.Map) {
+ m.Load("endpointsArr", &x.endpointsArr)
+ m.Load("endpointsMap", &x.endpointsMap)
+ m.Load("reuse", &x.reuse)
+}
+
func init() {
state.Register("stack.linkAddrEntryList", (*linkAddrEntryList)(nil), state.Fns{Save: (*linkAddrEntryList).save, Load: (*linkAddrEntryList).load})
state.Register("stack.linkAddrEntryEntry", (*linkAddrEntryEntry)(nil), state.Fns{Save: (*linkAddrEntryEntry).save, Load: (*linkAddrEntryEntry).load})
@@ -106,4 +121,5 @@ func init() {
state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load})
state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load})
state.Register("stack.TransportEndpointInfo", (*TransportEndpointInfo)(nil), state.Fns{Save: (*TransportEndpointInfo).save, Load: (*TransportEndpointInfo).load})
+ state.Register("stack.multiPortEndpoint", (*multiPortEndpoint)(nil), state.Fns{Save: (*multiPortEndpoint).save, Load: (*multiPortEndpoint).load})
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9aff90a3d..f633632f0 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -41,6 +41,31 @@ type transportEndpoints struct {
rawEndpoints []RawTransportEndpoint
}
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ epsByNic, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
+ }
+ delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+ es := make([]TransportEndpoint, 0, len(eps.endpoints))
+ for _, e := range eps.endpoints {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
type endpointsByNic struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
@@ -48,6 +73,16 @@ type endpointsByNic struct {
seed uint32
}
+func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+ var eps []TransportEndpoint
+ for _, ep := range epsByNic.endpoints {
+ eps = append(eps, ep.transportEndpoints()...)
+ }
+ return eps
+}
+
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
@@ -127,21 +162,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
return len(epsByNic.endpoints) == 0
}
-// unregisterEndpoint unregisters the endpoint with the given id such that it
-// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- eps.mu.Lock()
- defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
- if !ok {
- return
- }
- if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
- return
- }
- delete(eps.endpoints, id)
-}
-
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
@@ -183,14 +203,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but not not
+// use the restored copy.
+//
+// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
reuse bool
}
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ return eps
+}
+
// reciprocalScale scales a value into range [0, n).
//
// This is similar to val % n, but faster.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ce8307cee..8a3ca0f1b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -686,7 +686,7 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
@@ -747,7 +747,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
@@ -757,6 +757,7 @@ func (e *endpoint) cleanupLocked() {
}
e.route.Release()
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index eae17237e..19f003b6b 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -193,8 +193,10 @@ func (e *endpoint) Resume(s *stack.Stack) {
if len(e.BindAddr) == 0 {
e.BindAddr = e.ID.LocalAddress
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
+ addr := e.BindAddr
+ port := e.ID.LocalPort
+ if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
+ panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
}
}
@@ -265,6 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
fallthrough
case StateError:
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}