diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 59 | ||||
-rwxr-xr-x | pkg/tcpip/stack/stack_state_autogen.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 7 |
5 files changed, 129 insertions, 23 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 242d2150c..360c54b2d 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -361,9 +361,10 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -513,6 +514,7 @@ func New(opts Options) *Stack { networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, @@ -1136,6 +1138,25 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() + + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1157,6 +1178,38 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index 2459c3d3a..2551126f2 100755 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -99,6 +99,21 @@ func (x *TransportEndpointInfo) load(m state.Map) { m.Load("RegisterNICID", &x.RegisterNICID) } +func (x *multiPortEndpoint) beforeSave() {} +func (x *multiPortEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("endpointsArr", &x.endpointsArr) + m.Save("endpointsMap", &x.endpointsMap) + m.Save("reuse", &x.reuse) +} + +func (x *multiPortEndpoint) afterLoad() {} +func (x *multiPortEndpoint) load(m state.Map) { + m.Load("endpointsArr", &x.endpointsArr) + m.Load("endpointsMap", &x.endpointsMap) + m.Load("reuse", &x.reuse) +} + func init() { state.Register("stack.linkAddrEntryList", (*linkAddrEntryList)(nil), state.Fns{Save: (*linkAddrEntryList).save, Load: (*linkAddrEntryList).load}) state.Register("stack.linkAddrEntryEntry", (*linkAddrEntryEntry)(nil), state.Fns{Save: (*linkAddrEntryEntry).save, Load: (*linkAddrEntryEntry).load}) @@ -106,4 +121,5 @@ func init() { state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load}) state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load}) state.Register("stack.TransportEndpointInfo", (*TransportEndpointInfo)(nil), state.Fns{Save: (*TransportEndpointInfo).save, Load: (*TransportEndpointInfo).load}) + state.Register("stack.multiPortEndpoint", (*multiPortEndpoint)(nil), state.Fns{Save: (*multiPortEndpoint).save, Load: (*multiPortEndpoint).load}) } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9aff90a3d..f633632f0 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -41,6 +41,31 @@ type transportEndpoints struct { rawEndpoints []RawTransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + eps.mu.Lock() + defer eps.mu.Unlock() + epsByNic, ok := eps.endpoints[id] + if !ok { + return + } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + type endpointsByNic struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint @@ -48,6 +73,16 @@ type endpointsByNic struct { seed uint32 } +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { @@ -127,21 +162,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T return len(epsByNic.endpoints) == 0 } -// unregisterEndpoint unregisters the endpoint with the given id such that it -// won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - eps.mu.Lock() - defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] - if !ok { - return - } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { - return - } - delete(eps.endpoints, id) -} - // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -183,14 +203,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. reuse bool } +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps +} + // reciprocalScale scales a value into range [0, n). // // This is similar to val % n, but faster. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ce8307cee..8a3ca0f1b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -686,7 +686,7 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } @@ -747,7 +747,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } @@ -757,6 +757,7 @@ func (e *endpoint) cleanupLocked() { } e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index eae17237e..19f003b6b 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -193,8 +193,10 @@ func (e *endpoint) Resume(s *stack.Stack) { if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -265,6 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } fallthrough case StateError: + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } |