summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go32
-rw-r--r--pkg/tcpip/stack/registration.go23
-rw-r--r--pkg/tcpip/stack/stack.go53
-rw-r--r--pkg/tcpip/stack/stack_test.go6
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go15
6 files changed, 123 insertions, 26 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 862954ab2..46d3a6646 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -298,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ var errs []*tcpip.Error
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) > 0 {
+ return errs[0]
+ }
+
+ return nil
+}
+
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
@@ -1302,6 +1329,11 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index d83adf0ec..f9fd8f18f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
- // Close puts the endpoint in a closed state and frees all resources
- // associated with it. This cleanup may happen asynchronously. Wait can
- // be used to block on this asynchronous cleanup.
- Close()
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop.
//
@@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f0ed76fbe..ebb6c5e3b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -916,6 +916,18 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
@@ -956,6 +968,33 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return nic.enabled()
}
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ if r.NIC != id {
+ // Keep this route.
+ if i > n {
+ s.routeTable[n] = r
+ }
+ n++
+ }
+ }
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
+}
+
// NICAddressRanges returns a map of NICIDs to their associated subnets.
func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
@@ -1407,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() {
- e.Close()
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
}
}
@@ -1425,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() {
e.Wait()
}
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
s.mu.RLock()
defer s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 18016e7db..edf6bec52 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -235,6 +235,12 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index d686e6eb8..778c0a4d6 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// Close implements stack.TransportEndpoint.Close.
-func (ep *multiPortEndpoint) Close() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Close()
- }
-}
-
-// Wait implements stack.TransportEndpoint.Wait.
-func (ep *multiPortEndpoint) Wait() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Wait()
- }
-}
-
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 869c69a6d..5d1da2f8b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
@@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}