summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/endpoint_state.go
diff options
context:
space:
mode:
authorZhaozhong Ni <nzz@google.com>2018-05-11 16:27:50 -0700
committerShentubot <shentubot@google.com>2018-05-11 16:28:39 -0700
commit987f7841a6ad8b77fe6a41cb70323517a5d2ccd1 (patch)
tree8b8b86d6723ee31a8b1ff3df8e5e43d235cd440b /pkg/tcpip/transport/tcp/endpoint_state.go
parent85fd5d40ff78f7b7fd473e5215daba84a28977f3 (diff)
netstack: TCP connecting state endpoint save / restore support.
PiperOrigin-RevId: 196325647 Change-Id: I850eb4a29b9c679da4db10eb164bbdf967690663
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint_state.go')
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go177
1 files changed, 146 insertions, 31 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index dbb70ff21..ebab7006d 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -6,6 +6,7 @@ package tcp
import (
"fmt"
+ "sync"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
@@ -22,27 +23,43 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported endpoint state: " + e.Err.Error()
}
+func (e *endpoint) drainSegmentLocked() {
+ // Drain only up to once.
+ if e.drainDone != nil {
+ return
+ }
+
+ e.drainDone = make(chan struct{})
+ e.undrain = make(chan struct{})
+ e.mu.Unlock()
+
+ e.notificationWaker.Assert()
+ <-e.drainDone
+
+ e.mu.Lock()
+}
+
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
e.segmentQueue.setLimit(0)
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
switch e.state {
case stateInitial:
case stateBound:
case stateListen:
if !e.segmentQueue.empty() {
- e.mu.RUnlock()
- e.drainDone = make(chan struct{}, 1)
- e.notificationWaker.Assert()
- <-e.drainDone
- e.mu.RLock()
+ e.drainSegmentLocked()
}
case stateConnecting:
- panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ e.drainSegmentLocked()
+ if e.state != stateConnected {
+ break
+ }
+ fallthrough
case stateConnected:
// FIXME
panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
@@ -56,28 +73,12 @@ func (e *endpoint) beforeSave() {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
e.stack = stack.StackFromEnv
+ e.segmentQueue.setLimit(2 * e.rcvBufSize)
+ e.workMu.Init()
- if e.state == stateListen {
- e.state = stateBound
- backlog := cap(e.acceptedChan)
- e.acceptedChan = nil
- defer func() {
- if err := e.Listen(backlog); err != nil {
- panic("endpoint listening failed: " + err.String())
- }
- }()
- }
-
- if e.state == stateBound {
- e.state = stateInitial
- defer func() {
- if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
- panic("endpoint binding failed: " + err.String())
- }
- }()
- }
-
- if e.state == stateInitial {
+ state := e.state
+ switch state {
+ case stateInitial, stateBound, stateListen, stateConnecting:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
@@ -89,8 +90,29 @@ func (e *endpoint) afterLoad() {
}
}
- e.segmentQueue.setLimit(2 * e.rcvBufSize)
- e.workMu.Init()
+ switch state {
+ case stateBound, stateListen, stateConnecting:
+ e.state = stateInitial
+ if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateListen:
+ backlog := cap(e.acceptedChan)
+ e.acceptedChan = nil
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateConnecting:
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ }
}
// saveAcceptedChan is invoked by stateify.
@@ -126,3 +148,96 @@ type endpointChan struct {
buffer []*endpoint
cap int
}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = loadError(s)
+}
+
+// saveHardError is invoked by stateify.
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
+ return ""
+ }
+
+ return e.hardError.String()
+}
+
+// loadHardError is invoked by stateify.
+func (e *endpoint) loadHardError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.hardError = loadError(s)
+}
+
+var messageToError map[string]*tcpip.Error
+
+var populate sync.Once
+
+func loadError(s string) *tcpip.Error {
+ populate.Do(func() {
+ var errors = []*tcpip.Error{
+ tcpip.ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID,
+ tcpip.ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress,
+ tcpip.ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable,
+ tcpip.ErrPortInUse,
+ tcpip.ErrBadLocalAddress,
+ tcpip.ErrClosedForSend,
+ tcpip.ErrClosedForReceive,
+ tcpip.ErrWouldBlock,
+ tcpip.ErrConnectionRefused,
+ tcpip.ErrTimeout,
+ tcpip.ErrAborted,
+ tcpip.ErrConnectStarted,
+ tcpip.ErrDestinationRequired,
+ tcpip.ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected,
+ tcpip.ErrConnectionReset,
+ tcpip.ErrConnectionAborted,
+ tcpip.ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress,
+ tcpip.ErrBadAddress,
+ }
+
+ messageToError = make(map[string]*tcpip.Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}