diff options
author | Zhaozhong Ni <nzz@google.com> | 2018-05-11 16:27:50 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-05-11 16:28:39 -0700 |
commit | 987f7841a6ad8b77fe6a41cb70323517a5d2ccd1 (patch) | |
tree | 8b8b86d6723ee31a8b1ff3df8e5e43d235cd440b /pkg/tcpip/transport/tcp/endpoint_state.go | |
parent | 85fd5d40ff78f7b7fd473e5215daba84a28977f3 (diff) |
netstack: TCP connecting state endpoint save / restore support.
PiperOrigin-RevId: 196325647
Change-Id: I850eb4a29b9c679da4db10eb164bbdf967690663
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint_state.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 177 |
1 files changed, 146 insertions, 31 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index dbb70ff21..ebab7006d 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -6,6 +6,7 @@ package tcp import ( "fmt" + "sync" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" @@ -22,27 +23,43 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported endpoint state: " + e.Err.Error() } +func (e *endpoint) drainSegmentLocked() { + // Drain only up to once. + if e.drainDone != nil { + return + } + + e.drainDone = make(chan struct{}) + e.undrain = make(chan struct{}) + e.mu.Unlock() + + e.notificationWaker.Assert() + <-e.drainDone + + e.mu.Lock() +} + // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. e.segmentQueue.setLimit(0) - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() switch e.state { case stateInitial: case stateBound: case stateListen: if !e.segmentQueue.empty() { - e.mu.RUnlock() - e.drainDone = make(chan struct{}, 1) - e.notificationWaker.Assert() - <-e.drainDone - e.mu.RLock() + e.drainSegmentLocked() } case stateConnecting: - panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + e.drainSegmentLocked() + if e.state != stateConnected { + break + } + fallthrough case stateConnected: // FIXME panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) @@ -56,28 +73,12 @@ func (e *endpoint) beforeSave() { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { e.stack = stack.StackFromEnv + e.segmentQueue.setLimit(2 * e.rcvBufSize) + e.workMu.Init() - if e.state == stateListen { - e.state = stateBound - backlog := cap(e.acceptedChan) - e.acceptedChan = nil - defer func() { - if err := e.Listen(backlog); err != nil { - panic("endpoint listening failed: " + err.String()) - } - }() - } - - if e.state == stateBound { - e.state = stateInitial - defer func() { - if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil { - panic("endpoint binding failed: " + err.String()) - } - }() - } - - if e.state == stateInitial { + state := e.state + switch state { + case stateInitial, stateBound, stateListen, stateConnecting: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { @@ -89,8 +90,29 @@ func (e *endpoint) afterLoad() { } } - e.segmentQueue.setLimit(2 * e.rcvBufSize) - e.workMu.Init() + switch state { + case stateBound, stateListen, stateConnecting: + e.state = stateInitial + if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case stateListen: + backlog := cap(e.acceptedChan) + e.acceptedChan = nil + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + } + + switch state { + case stateConnecting: + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + } } // saveAcceptedChan is invoked by stateify. @@ -126,3 +148,96 @@ type endpointChan struct { buffer []*endpoint cap int } + +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = loadError(s) +} + +// saveHardError is invoked by stateify. +func (e *endpoint) saveHardError() string { + if e.hardError == nil { + return "" + } + + return e.hardError.String() +} + +// loadHardError is invoked by stateify. +func (e *endpoint) loadHardError(s string) { + if s == "" { + return + } + + e.hardError = loadError(s) +} + +var messageToError map[string]*tcpip.Error + +var populate sync.Once + +func loadError(s string) *tcpip.Error { + populate.Do(func() { + var errors = []*tcpip.Error{ + tcpip.ErrUnknownProtocol, + tcpip.ErrUnknownNICID, + tcpip.ErrUnknownProtocolOption, + tcpip.ErrDuplicateNICID, + tcpip.ErrDuplicateAddress, + tcpip.ErrNoRoute, + tcpip.ErrBadLinkEndpoint, + tcpip.ErrAlreadyBound, + tcpip.ErrInvalidEndpointState, + tcpip.ErrAlreadyConnecting, + tcpip.ErrAlreadyConnected, + tcpip.ErrNoPortAvailable, + tcpip.ErrPortInUse, + tcpip.ErrBadLocalAddress, + tcpip.ErrClosedForSend, + tcpip.ErrClosedForReceive, + tcpip.ErrWouldBlock, + tcpip.ErrConnectionRefused, + tcpip.ErrTimeout, + tcpip.ErrAborted, + tcpip.ErrConnectStarted, + tcpip.ErrDestinationRequired, + tcpip.ErrNotSupported, + tcpip.ErrQueueSizeNotSupported, + tcpip.ErrNotConnected, + tcpip.ErrConnectionReset, + tcpip.ErrConnectionAborted, + tcpip.ErrNoSuchFile, + tcpip.ErrInvalidOptionValue, + tcpip.ErrNoLinkAddress, + tcpip.ErrBadAddress, + } + + messageToError = make(map[string]*tcpip.Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} |