summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/inet/inet.go3
-rw-r--r--pkg/sentry/inet/test_stack.go4
-rw-r--r--pkg/sentry/kernel/kernel.go9
-rw-r--r--pkg/sentry/socket/epsocket/stack.go5
-rw-r--r--pkg/sentry/socket/hostinet/stack.go3
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go3
-rw-r--r--pkg/sentry/state/BUILD1
-rw-r--r--pkg/sentry/state/state.go5
-rw-r--r--pkg/tcpip/stack/stack.go35
-rw-r--r--pkg/tcpip/stack/transport_test.go3
-rw-r--r--pkg/tcpip/tcpip.go10
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go28
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go26
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go25
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go24
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go101
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go100
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go47
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go46
19 files changed, 278 insertions, 200 deletions
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 60d6dfb93..80f227dbe 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -55,6 +55,9 @@ type Stack interface {
// RouteTable returns the network stack's route table.
RouteTable() []Route
+
+ // Resume restarts the network stack after restore.
+ Resume()
}
// Interface contains information about a network interface.
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 57d5510f0..b9eed7c3a 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -92,3 +92,7 @@ func (s *TestStack) Statistics(stat interface{}, arg string) error {
func (s *TestStack) RouteTable() []Route {
return s.RouteList
}
+
+// Resume implements Stack.Resume.
+func (s *TestStack) Resume() {
+}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 56a329f83..8c1f79ab5 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -496,7 +496,7 @@ func (ts *TaskSet) unregisterEpollWaiters() {
}
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
+func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
k.networkStack = net
@@ -540,6 +540,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
log.Infof("Overall load took [%s]", time.Since(loadStart))
+ k.Timekeeper().SetClocks(clocks)
+ if net != nil {
+ net.Resume()
+ }
+
// Ensure that all pending asynchronous work is complete:
// - namedpipe opening
// - inode file opening
@@ -549,7 +554,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
tcpip.AsyncLoading.Wait()
- log.Infof("Overall load took [%s]", time.Since(loadStart))
+ log.Infof("Overall load took [%s] after async work", time.Since(loadStart))
// Applications may size per-cpu structures based on k.applicationCores, so
// it can't change across save/restore. When we are virtualizing CPU
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
index 0cf235b31..8f1572bf4 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -201,3 +201,8 @@ func (s *Stack) IPTables() (iptables.IPTables, error) {
func (s *Stack) FillDefaultIPTables() error {
return netfilter.FillDefaultIPTables(s.Stack)
}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {
+ s.Stack.Resume()
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 99b7a1e2b..1902fe155 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -329,3 +329,6 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
func (s *Stack) RouteTable() []inet.Route {
return append([]inet.Route(nil), s.routes...)
}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
index d18305589..5dcb6b455 100644
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -162,3 +162,6 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
func (s *Stack) RouteTable() []inet.Route {
return append([]inet.Route(nil), s.routes...)
}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index f297ef3b7..88765f4d6 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/log",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/time",
"//pkg/sentry/watchdog",
"//pkg/state/statefile",
"//pkg/syserror",
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 026549756..9eb626b76 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/state/statefile"
"gvisor.dev/gvisor/pkg/syserror"
@@ -104,7 +105,7 @@ type LoadOpts struct {
}
// Load loads the given kernel, setting the provided platform and stack.
-func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error {
+func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error {
// Open the file.
r, m, err := statefile.NewReader(opts.Source, opts.Key)
if err != nil {
@@ -114,5 +115,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error {
previousMetadata = m
// Restore the Kernel object graph.
- return k.LoadFrom(r, n)
+ return k.LoadFrom(r, n, clocks)
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 78beb0dae..d45e547ee 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -334,6 +334,15 @@ type TCPEndpointState struct {
Sender TCPSenderState
}
+// ResumableEndpoint is an endpoint that needs to be resumed after restore.
+type ResumableEndpoint interface {
+ // Resume resumes an endpoint after restore. This can be used to restart
+ // background workers such as protocol goroutines. This must be called after
+ // all indirect dependencies of the endpoint has been restored, which
+ // generally implies at the end of the restore process.
+ Resume(*Stack)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -376,6 +385,10 @@ type Stack struct {
// tables are the iptables packet filtering and manipulation rules.
tables iptables.IPTables
+
+ // resumableEndpoints is a list of endpoints that need to be resumed if the
+ // stack is being restored.
+ resumableEndpoints []ResumableEndpoint
}
// Options contains optional Stack configuration.
@@ -1091,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip
}
}
+// RegisterRestoredEndpoint records e as an endpoint that has been restored on
+// this stack.
+func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
+ s.mu.Lock()
+ s.resumableEndpoints = append(s.resumableEndpoints, e)
+ s.mu.Unlock()
+}
+
+// Resume restarts the stack after a restore. This must be called after the
+// entire system has been restored.
+func (s *Stack) Resume() {
+ // ResumableEndpoint.Resume() may call other methods on s, so we can't hold
+ // s.mu while resuming the endpoints.
+ s.mu.Lock()
+ eps := s.resumableEndpoints
+ s.resumableEndpoints = nil
+ s.mu.Unlock()
+ for _, e := range eps {
+ e.Resume(s)
+ }
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8652d7814..eee3144cd 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -205,6 +205,9 @@ func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
return iptables.IPTables{}, nil
}
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
+}
+
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0df9f6d93..119712d2f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1082,11 +1082,13 @@ type ProtocolAddress struct {
AddressWithPrefix AddressWithPrefix
}
-// danglingEndpointsMu protects access to danglingEndpoints.
-var danglingEndpointsMu sync.Mutex
+var (
+ // danglingEndpointsMu protects access to danglingEndpoints.
+ danglingEndpointsMu sync.Mutex
-// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
-var danglingEndpoints = make(map[Endpoint]struct{})
+ // danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+ danglingEndpoints = make(map[Endpoint]struct{})
+)
// GetDanglingEndpoints returns all dangling endpoints.
func GetDanglingEndpoints() []Endpoint {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index a4527c041..9a4306011 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -136,6 +136,34 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
return e.stack.IPTables(), nil
}
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ if err != nil {
+ panic(*err)
+ }
+}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 99b8c4093..43551d642 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,7 +15,6 @@
package icmp
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -63,28 +62,5 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
-
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err *tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
- if err != nil {
- panic(*err)
- }
-
- e.id.LocalAddress = e.route.LocalAddress
- } else if len(e.id.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
- }
- }
-
- e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
- if err != nil {
- panic(*err)
- }
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b4be855c1..eab3dcbd2 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -174,6 +174,31 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) {
return ep.stack.IPTables(), nil
}
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (ep *endpoint) Resume(s *stack.Stack) {
+ ep.stack = s
+
+ // If the endpoint is connected, re-connect.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ if err != nil {
+ panic(*err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ panic(*err)
+ }
+}
+
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
if !ep.associated {
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index cb5534d90..44abddb2b 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,7 +15,6 @@
package raw
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -63,26 +62,5 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- // StackFromEnv is a stack used specifically for save/restore.
- ep.stack = stack.StackFromEnv
-
- // If the endpoint is connected, re-connect via the save/restore stack.
- if ep.connected {
- var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
- if err != nil {
- panic(*err)
- }
- }
-
- // If the endpoint is bound, re-bind via the save/restore stack.
- if ep.bound {
- if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
- panic(tcpip.ErrBadLocalAddress)
- }
- }
-
- if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
- panic(*err)
- }
+ stack.StackFromEnv.RegisterRestoredEndpoint(ep)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 0e16877e7..e67169111 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -720,6 +720,107 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
return e.stack.IPTables(), nil
}
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+
+ state := e.state
+ switch state {
+ case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+ }
+ }
+
+ bind := func() {
+ e.state = StateInitial
+ if len(e.bindAddress) == 0 {
+ e.bindAddress = e.id.LocalAddress
+ }
+ if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ bind()
+ if len(e.connectingAddress) == 0 {
+ e.connectingAddress = e.id.RemoteAddress
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ }
+ }
+ // Reset the scoreboard to reinitialize the sack information as
+ // we do not restore SACK information.
+ e.scoreboard.Reset()
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectedLoading.Done()
+ case StateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case StateConnecting, StateSynSent, StateSynRecv:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case StateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case StateClose:
+ if e.isPortReserved {
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ e.state = StateClose
+ tcpip.AsyncLoading.Done()
+ }()
+ }
+ fallthrough
+ case StateError:
+ tcpip.DeleteDanglingEndpoint(e)
+ }
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b3f0f6c5d..ef88dc618 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -20,7 +20,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -165,104 +164,7 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
-
- state := e.state
- switch state {
- case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
- var ss SendBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
- }
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
- }
- }
- }
-
- bind := func() {
- e.state = StateInitial
- if len(e.bindAddress) == 0 {
- e.bindAddress = e.id.LocalAddress
- }
- if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
- }
- }
-
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
- bind()
- if len(e.connectingAddress) == 0 {
- // This endpoint is accepted by netstack but not yet by
- // the app. If the endpoint is IPv6 but the remote
- // address is IPv4, we need to connect as IPv6 so that
- // dual-stack mode can be properly activated.
- if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
- } else {
- e.connectingAddress = e.id.RemoteAddress
- }
- }
- // Reset the scoreboard to reinitialize the sack information as
- // we do not restore SACK information.
- e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
- panic("endpoint connecting failed: " + err.String())
- }
- connectedLoading.Done()
- case StateListen:
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- bind()
- backlog := cap(e.acceptedChan)
- if err := e.Listen(backlog); err != nil {
- panic("endpoint listening failed: " + err.String())
- }
- listenLoading.Done()
- tcpip.AsyncLoading.Done()
- }()
- case StateConnecting, StateSynSent, StateSynRecv:
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
- panic("endpoint connecting failed: " + err.String())
- }
- connectingLoading.Done()
- tcpip.AsyncLoading.Done()
- }()
- case StateBound:
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- tcpip.AsyncLoading.Done()
- }()
- case StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.state = StateClose
- tcpip.AsyncLoading.Done()
- }()
- }
- fallthrough
- case StateError:
- tcpip.DeleteDanglingEndpoint(e)
- }
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 7210b3a9f..7c12a6092 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -178,6 +178,53 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
return e.stack.IPTables(), nil
}
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if err != nil {
+ panic(*err)
+ }
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ id := e.id
+ e.id.LocalPort = 0
+ e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ panic(*err)
+ }
+}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 18e786397..86db36260 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,9 +15,7 @@
package udp
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -64,47 +62,5 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
-
- for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err *tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
- if err != nil {
- panic(*err)
- }
- } else if len(e.id.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.id
- e.id.LocalPort = 0
- e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
- if err != nil {
- panic(*err)
- }
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}