summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorZhaozhong Ni <nzz@google.com>2018-05-11 16:27:50 -0700
committerShentubot <shentubot@google.com>2018-05-11 16:28:39 -0700
commit987f7841a6ad8b77fe6a41cb70323517a5d2ccd1 (patch)
tree8b8b86d6723ee31a8b1ff3df8e5e43d235cd440b /pkg
parent85fd5d40ff78f7b7fd473e5215daba84a28977f3 (diff)
netstack: TCP connecting state endpoint save / restore support.
PiperOrigin-RevId: 196325647 Change-Id: I850eb4a29b9c679da4db10eb164bbdf967690663
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/tcpip.go3
-rw-r--r--pkg/tcpip/transport/tcp/accept.go4
-rw-r--r--pkg/tcpip/transport/tcp/connect.go56
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go21
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go177
5 files changed, 210 insertions, 51 deletions
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f9df1d989..c27c0dd89 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,9 @@ import (
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
+//
+// Note: to support save / restore, it is important that all tcpip errors have
+// distinct error messages.
type Error struct {
string
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 9a5b13066..a71cb444f 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -379,8 +379,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.handleListenSegment(ctx, s)
s.decRef()
}
- e.drainDone <- struct{}{}
- return nil
+ close(e.drainDone)
+ <-e.undrain
}
case wakerForNewSegment:
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4d20f4d3f..698e2b440 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -296,6 +296,21 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
// processSegments goes through the segment queue and processes up to
// maxSegmentsPerWake (if they're available).
func (h *handshake) processSegments() *tcpip.Error {
@@ -305,18 +320,7 @@ func (h *handshake) processSegments() *tcpip.Error {
return nil
}
- h.sndWnd = s.window
- if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
- h.sndWnd <<= uint8(h.sndWndScale)
- }
-
- var err *tcpip.Error
- switch h.state {
- case handshakeSynRcvd:
- err = h.synRcvdState(s)
- case handshakeSynSent:
- err = h.synSentState(s)
- }
+ err := h.handleSegment(s)
s.decRef()
if err != nil {
return err
@@ -364,6 +368,10 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
}
// Wait for notification.
@@ -434,6 +442,20 @@ func (h *handshake) execute() *tcpip.Error {
if n&notifyClose != 0 {
return tcpip.ErrAborted
}
+ if n&notifyDrain != 0 {
+ for s := h.ep.segmentQueue.dequeue(); s != nil; s = h.ep.segmentQueue.dequeue() {
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -833,7 +855,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.mu.Lock()
e.state = stateError
e.hardError = err
+ drained := e.drainDone != nil
e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
return err
}
@@ -851,7 +878,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
e.state = stateConnected
+ drained := e.drainDone != nil
e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
e.waiterQueue.Notify(waiter.EventOut)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d84171b0c..f26b28632 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -73,8 +73,8 @@ type endpoint struct {
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
// The following fields are used to manage the receive queue. The
// protocol goroutine adds ready-for-delivery segments to rcvList,
@@ -92,7 +92,7 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
id stack.TransportEndpointID
state endpointState
- isPortReserved bool
+ isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID
route stack.Route `state:"manual"`
@@ -105,12 +105,12 @@ type endpoint struct {
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
- effectiveNetProtos []tcpip.NetworkProtocolNumber
+ effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
// hardError is meaningful only when state is stateError, it stores the
// error to be returned when read/write syscalls are called and the
// endpoint is in this state.
- hardError *tcpip.Error
+ hardError *tcpip.Error `state:".(string)"`
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -203,9 +203,15 @@ type endpoint struct {
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
+ // The goroutine undrain notification channel.
+ undrain chan struct{} `state:"nosave"`
+
// probe if not nil is invoked on every received segment. It is passed
// a copy of the current state of the endpoint.
probe stack.TCPProbeFunc `state:"nosave"`
+
+ // The following are only used to assist the restore run to re-connect.
+ connectingAddress tcpip.Address
}
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
@@ -786,6 +792,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
+ connectingAddr := addr.Addr
+
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
@@ -891,9 +899,10 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.route = r.Clone()
e.boundNICID = nicid
e.effectiveNetProtos = netProtos
+ e.connectingAddress = connectingAddr
e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-FIXME
+ go e.protocolMainLoop(false) // S/R-SAFE: will be drained before save.
return tcpip.ErrConnectStarted
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index dbb70ff21..ebab7006d 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -6,6 +6,7 @@ package tcp
import (
"fmt"
+ "sync"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
@@ -22,27 +23,43 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported endpoint state: " + e.Err.Error()
}
+func (e *endpoint) drainSegmentLocked() {
+ // Drain only up to once.
+ if e.drainDone != nil {
+ return
+ }
+
+ e.drainDone = make(chan struct{})
+ e.undrain = make(chan struct{})
+ e.mu.Unlock()
+
+ e.notificationWaker.Assert()
+ <-e.drainDone
+
+ e.mu.Lock()
+}
+
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
e.segmentQueue.setLimit(0)
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
switch e.state {
case stateInitial:
case stateBound:
case stateListen:
if !e.segmentQueue.empty() {
- e.mu.RUnlock()
- e.drainDone = make(chan struct{}, 1)
- e.notificationWaker.Assert()
- <-e.drainDone
- e.mu.RLock()
+ e.drainSegmentLocked()
}
case stateConnecting:
- panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ e.drainSegmentLocked()
+ if e.state != stateConnected {
+ break
+ }
+ fallthrough
case stateConnected:
// FIXME
panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
@@ -56,28 +73,12 @@ func (e *endpoint) beforeSave() {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
e.stack = stack.StackFromEnv
+ e.segmentQueue.setLimit(2 * e.rcvBufSize)
+ e.workMu.Init()
- if e.state == stateListen {
- e.state = stateBound
- backlog := cap(e.acceptedChan)
- e.acceptedChan = nil
- defer func() {
- if err := e.Listen(backlog); err != nil {
- panic("endpoint listening failed: " + err.String())
- }
- }()
- }
-
- if e.state == stateBound {
- e.state = stateInitial
- defer func() {
- if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
- panic("endpoint binding failed: " + err.String())
- }
- }()
- }
-
- if e.state == stateInitial {
+ state := e.state
+ switch state {
+ case stateInitial, stateBound, stateListen, stateConnecting:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
@@ -89,8 +90,29 @@ func (e *endpoint) afterLoad() {
}
}
- e.segmentQueue.setLimit(2 * e.rcvBufSize)
- e.workMu.Init()
+ switch state {
+ case stateBound, stateListen, stateConnecting:
+ e.state = stateInitial
+ if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateListen:
+ backlog := cap(e.acceptedChan)
+ e.acceptedChan = nil
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateConnecting:
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ }
}
// saveAcceptedChan is invoked by stateify.
@@ -126,3 +148,96 @@ type endpointChan struct {
buffer []*endpoint
cap int
}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = loadError(s)
+}
+
+// saveHardError is invoked by stateify.
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
+ return ""
+ }
+
+ return e.hardError.String()
+}
+
+// loadHardError is invoked by stateify.
+func (e *endpoint) loadHardError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.hardError = loadError(s)
+}
+
+var messageToError map[string]*tcpip.Error
+
+var populate sync.Once
+
+func loadError(s string) *tcpip.Error {
+ populate.Do(func() {
+ var errors = []*tcpip.Error{
+ tcpip.ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID,
+ tcpip.ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress,
+ tcpip.ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable,
+ tcpip.ErrPortInUse,
+ tcpip.ErrBadLocalAddress,
+ tcpip.ErrClosedForSend,
+ tcpip.ErrClosedForReceive,
+ tcpip.ErrWouldBlock,
+ tcpip.ErrConnectionRefused,
+ tcpip.ErrTimeout,
+ tcpip.ErrAborted,
+ tcpip.ErrConnectStarted,
+ tcpip.ErrDestinationRequired,
+ tcpip.ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected,
+ tcpip.ErrConnectionReset,
+ tcpip.ErrConnectionAborted,
+ tcpip.ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress,
+ tcpip.ErrBadAddress,
+ }
+
+ messageToError = make(map[string]*tcpip.Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}