diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 177 |
4 files changed, 207 insertions, 51 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 9a5b13066..a71cb444f 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -379,8 +379,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.handleListenSegment(ctx, s) s.decRef() } - e.drainDone <- struct{}{} - return nil + close(e.drainDone) + <-e.undrain } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4d20f4d3f..698e2b440 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -296,6 +296,21 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } +func (h *handshake) handleSegment(s *segment) *tcpip.Error { + h.sndWnd = s.window + if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { + h.sndWnd <<= uint8(h.sndWndScale) + } + + switch h.state { + case handshakeSynRcvd: + return h.synRcvdState(s) + case handshakeSynSent: + return h.synSentState(s) + } + return nil +} + // processSegments goes through the segment queue and processes up to // maxSegmentsPerWake (if they're available). func (h *handshake) processSegments() *tcpip.Error { @@ -305,18 +320,7 @@ func (h *handshake) processSegments() *tcpip.Error { return nil } - h.sndWnd = s.window - if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { - h.sndWnd <<= uint8(h.sndWndScale) - } - - var err *tcpip.Error - switch h.state { - case handshakeSynRcvd: - err = h.synRcvdState(s) - case handshakeSynSent: - err = h.synSentState(s) - } + err := h.handleSegment(s) s.decRef() if err != nil { return err @@ -364,6 +368,10 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } + if n¬ifyDrain != 0 { + close(h.ep.drainDone) + <-h.ep.undrain + } } // Wait for notification. @@ -434,6 +442,20 @@ func (h *handshake) execute() *tcpip.Error { if n¬ifyClose != 0 { return tcpip.ErrAborted } + if n¬ifyDrain != 0 { + for s := h.ep.segmentQueue.dequeue(); s != nil; s = h.ep.segmentQueue.dequeue() { + err := h.handleSegment(s) + s.decRef() + if err != nil { + return err + } + if h.state == handshakeCompleted { + return nil + } + } + close(h.ep.drainDone) + <-h.ep.undrain + } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -833,7 +855,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.mu.Lock() e.state = stateError e.hardError = err + drained := e.drainDone != nil e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } return err } @@ -851,7 +878,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() e.state = stateConnected + drained := e.drainDone != nil e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } e.waiterQueue.Notify(waiter.EventOut) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d84171b0c..f26b28632 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -73,8 +73,8 @@ type endpoint struct { // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` // The following fields are used to manage the receive queue. The // protocol goroutine adds ready-for-delivery segments to rcvList, @@ -92,7 +92,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` id stack.TransportEndpointID state endpointState - isPortReserved bool + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID route stack.Route `state:"manual"` @@ -105,12 +105,12 @@ type endpoint struct { // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber + effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` // hardError is meaningful only when state is stateError, it stores the // error to be returned when read/write syscalls are called and the // endpoint is in this state. - hardError *tcpip.Error + hardError *tcpip.Error `state:".(string)"` // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -203,9 +203,15 @@ type endpoint struct { // The goroutine drain completion notification channel. drainDone chan struct{} `state:"nosave"` + // The goroutine undrain notification channel. + undrain chan struct{} `state:"nosave"` + // probe if not nil is invoked on every received segment. It is passed // a copy of the current state of the endpoint. probe stack.TCPProbeFunc `state:"nosave"` + + // The following are only used to assist the restore run to re-connect. + connectingAddress tcpip.Address } func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { @@ -786,6 +792,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() + connectingAddr := addr.Addr + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err @@ -891,9 +899,10 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.route = r.Clone() e.boundNICID = nicid e.effectiveNetProtos = netProtos + e.connectingAddress = connectingAddr e.workerRunning = true - go e.protocolMainLoop(false) // S/R-FIXME + go e.protocolMainLoop(false) // S/R-SAFE: will be drained before save. return tcpip.ErrConnectStarted } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index dbb70ff21..ebab7006d 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -6,6 +6,7 @@ package tcp import ( "fmt" + "sync" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" @@ -22,27 +23,43 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported endpoint state: " + e.Err.Error() } +func (e *endpoint) drainSegmentLocked() { + // Drain only up to once. + if e.drainDone != nil { + return + } + + e.drainDone = make(chan struct{}) + e.undrain = make(chan struct{}) + e.mu.Unlock() + + e.notificationWaker.Assert() + <-e.drainDone + + e.mu.Lock() +} + // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. e.segmentQueue.setLimit(0) - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() switch e.state { case stateInitial: case stateBound: case stateListen: if !e.segmentQueue.empty() { - e.mu.RUnlock() - e.drainDone = make(chan struct{}, 1) - e.notificationWaker.Assert() - <-e.drainDone - e.mu.RLock() + e.drainSegmentLocked() } case stateConnecting: - panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + e.drainSegmentLocked() + if e.state != stateConnected { + break + } + fallthrough case stateConnected: // FIXME panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) @@ -56,28 +73,12 @@ func (e *endpoint) beforeSave() { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { e.stack = stack.StackFromEnv + e.segmentQueue.setLimit(2 * e.rcvBufSize) + e.workMu.Init() - if e.state == stateListen { - e.state = stateBound - backlog := cap(e.acceptedChan) - e.acceptedChan = nil - defer func() { - if err := e.Listen(backlog); err != nil { - panic("endpoint listening failed: " + err.String()) - } - }() - } - - if e.state == stateBound { - e.state = stateInitial - defer func() { - if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil { - panic("endpoint binding failed: " + err.String()) - } - }() - } - - if e.state == stateInitial { + state := e.state + switch state { + case stateInitial, stateBound, stateListen, stateConnecting: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { @@ -89,8 +90,29 @@ func (e *endpoint) afterLoad() { } } - e.segmentQueue.setLimit(2 * e.rcvBufSize) - e.workMu.Init() + switch state { + case stateBound, stateListen, stateConnecting: + e.state = stateInitial + if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case stateListen: + backlog := cap(e.acceptedChan) + e.acceptedChan = nil + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + } + + switch state { + case stateConnecting: + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + } } // saveAcceptedChan is invoked by stateify. @@ -126,3 +148,96 @@ type endpointChan struct { buffer []*endpoint cap int } + +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = loadError(s) +} + +// saveHardError is invoked by stateify. +func (e *endpoint) saveHardError() string { + if e.hardError == nil { + return "" + } + + return e.hardError.String() +} + +// loadHardError is invoked by stateify. +func (e *endpoint) loadHardError(s string) { + if s == "" { + return + } + + e.hardError = loadError(s) +} + +var messageToError map[string]*tcpip.Error + +var populate sync.Once + +func loadError(s string) *tcpip.Error { + populate.Do(func() { + var errors = []*tcpip.Error{ + tcpip.ErrUnknownProtocol, + tcpip.ErrUnknownNICID, + tcpip.ErrUnknownProtocolOption, + tcpip.ErrDuplicateNICID, + tcpip.ErrDuplicateAddress, + tcpip.ErrNoRoute, + tcpip.ErrBadLinkEndpoint, + tcpip.ErrAlreadyBound, + tcpip.ErrInvalidEndpointState, + tcpip.ErrAlreadyConnecting, + tcpip.ErrAlreadyConnected, + tcpip.ErrNoPortAvailable, + tcpip.ErrPortInUse, + tcpip.ErrBadLocalAddress, + tcpip.ErrClosedForSend, + tcpip.ErrClosedForReceive, + tcpip.ErrWouldBlock, + tcpip.ErrConnectionRefused, + tcpip.ErrTimeout, + tcpip.ErrAborted, + tcpip.ErrConnectStarted, + tcpip.ErrDestinationRequired, + tcpip.ErrNotSupported, + tcpip.ErrQueueSizeNotSupported, + tcpip.ErrNotConnected, + tcpip.ErrConnectionReset, + tcpip.ErrConnectionAborted, + tcpip.ErrNoSuchFile, + tcpip.ErrInvalidOptionValue, + tcpip.ErrNoLinkAddress, + tcpip.ErrBadAddress, + } + + messageToError = make(map[string]*tcpip.Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} |