summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorZhaozhong Ni <nzz@google.com>2018-07-10 09:22:37 -0700
committerShentubot <shentubot@google.com>2018-07-10 09:23:35 -0700
commitb1683df90bf81974e9e309ed66edaff30537c1be (patch)
tree728061e78466951d1f069e5a73358f84aa16d6c0
parentafd655a5d8b9d9bc747ee99b1ec2475cc526c996 (diff)
netstack: tcp socket connected state S/R support.
PiperOrigin-RevId: 203958972 Change-Id: Ia6fe16547539296d48e2c6731edacdd96bd6e93c
-rw-r--r--pkg/sentry/kernel/BUILD5
-rw-r--r--pkg/sentry/kernel/kernel.go6
-rw-r--r--pkg/sentry/kernel/kernel_state.go31
-rw-r--r--pkg/tcpip/stack/stack_global_state.go2
-rw-r--r--pkg/tcpip/tcpip.go36
-rw-r--r--pkg/tcpip/transport/tcp/BUILD7
-rw-r--r--pkg/tcpip/transport/tcp/accept.go11
-rw-r--r--pkg/tcpip/transport/tcp/connect.go43
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go72
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go193
-rw-r--r--pkg/tcpip/transport/tcp/segment.go8
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go4
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go51
-rw-r--r--pkg/tcpip/transport/tcp/snd.go4
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go49
15 files changed, 430 insertions, 92 deletions
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index b2a55ddff..07568b47c 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -12,6 +12,7 @@ go_stateify(
"fs_context.go",
"ipc_namespace.go",
"kernel.go",
+ "kernel_state.go",
"pending_signals.go",
"pending_signals_state.go",
"process_group_list.go",
@@ -45,10 +46,11 @@ go_stateify(
"vdso.go",
"version.go",
],
- out = "kernel_state.go",
+ out = "kernel_autogen_state.go",
imports = [
"gvisor.googlesource.com/gvisor/pkg/sentry/arch",
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs",
+ "gvisor.googlesource.com/gvisor/pkg/tcpip",
],
package = "kernel",
)
@@ -117,6 +119,7 @@ go_library(
"fs_context.go",
"ipc_namespace.go",
"kernel.go",
+ "kernel_autogen_state.go",
"kernel_state.go",
"pending_signals.go",
"pending_signals_list.go",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 5662b8f08..64439cd9d 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -57,6 +57,7 @@ import (
sentrytime "gvisor.googlesource.com/gvisor/pkg/sentry/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
"gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
)
// Kernel represents an emulated Linux kernel. It must be initialized by calling
@@ -158,6 +159,9 @@ type Kernel struct {
// exitErr is the error causing the sandbox to exit, if any. It is
// protected by extMu.
exitErr error
+
+ // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
+ danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
}
// InitKernelArgs holds arguments to Init.
@@ -422,6 +426,8 @@ func (k *Kernel) LoadFrom(r io.Reader, p platform.Platform, net inet.Stack) erro
return err
}
+ tcpip.AsyncLoading.Wait()
+
log.Infof("Overall load took [%s]", time.Since(loadStart))
// Applications may size per-cpu structures based on k.applicationCores, so
diff --git a/pkg/sentry/kernel/kernel_state.go b/pkg/sentry/kernel/kernel_state.go
new file mode 100644
index 000000000..bb2d5102d
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_state.go
@@ -0,0 +1,31 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// saveDanglingEndpoints is invoked by stateify.
+func (k *Kernel) saveDanglingEndpoints() []tcpip.Endpoint {
+ return tcpip.GetDanglingEndpoints()
+}
+
+// loadDanglingEndpoints is invoked by stateify.
+func (k *Kernel) loadDanglingEndpoints(es []tcpip.Endpoint) {
+ for _, e := range es {
+ tcpip.AddDanglingEndpoint(e)
+ }
+}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
index 6d261ce96..b6c095efb 100644
--- a/pkg/tcpip/stack/stack_global_state.go
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -15,5 +15,5 @@
package stack
// StackFromEnv is the global stack created in restore run.
-// FIXME: remove this variable once tcpip S/R is fully supported.
+// FIXME
var StackFromEnv *Stack
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4107c0f78..eb1e4645d 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -33,6 +33,7 @@ import (
"fmt"
"strconv"
"strings"
+ "sync"
"time"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
@@ -562,3 +563,38 @@ type ProtocolAddress struct {
// Address is a network address.
Address Address
}
+
+// danglingEndpointsMu protects access to danglingEndpoints.
+var danglingEndpointsMu sync.Mutex
+
+// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+var danglingEndpoints = make(map[Endpoint]struct{})
+
+// GetDanglingEndpoints returns all dangling endpoints.
+func GetDanglingEndpoints() []Endpoint {
+ es := make([]Endpoint, 0, len(danglingEndpoints))
+ danglingEndpointsMu.Lock()
+ for e, _ := range danglingEndpoints {
+ es = append(es, e)
+ }
+ danglingEndpointsMu.Unlock()
+ return es
+}
+
+// AddDanglingEndpoint adds a dangling endpoint.
+func AddDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ danglingEndpoints[e] = struct{}{}
+ danglingEndpointsMu.Unlock()
+}
+
+// DeleteDanglingEndpoint removes a dangling endpoint.
+func DeleteDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ delete(danglingEndpoints, e)
+ danglingEndpointsMu.Unlock()
+}
+
+// AsyncLoading is the global barrier for asynchronous endpoint loading
+// activities.
+var AsyncLoading sync.WaitGroup
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 6cb0ebab2..6a2f42a12 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -10,11 +10,16 @@ go_stateify(
"endpoint.go",
"endpoint_state.go",
"rcv.go",
+ "segment.go",
"segment_heap.go",
+ "segment_queue.go",
+ "segment_state.go",
"snd.go",
+ "snd_state.go",
"tcp_segment_list.go",
],
out = "tcp_state.go",
+ imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
package = "tcp",
)
@@ -43,7 +48,9 @@ go_library(
"segment.go",
"segment_heap.go",
"segment_queue.go",
+ "segment_state.go",
"snd.go",
+ "snd_state.go",
"tcp_segment_list.go",
"tcp_state.go",
"timer.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index ae4359ff4..d6d2b4555 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -78,7 +78,8 @@ func encodeMSS(mss uint16) uint32 {
// to go above a threshold.
var synRcvdCount struct {
sync.Mutex
- value uint64
+ value uint64
+ pending sync.WaitGroup
}
// listenContext is used by a listening endpoint to store state used while
@@ -112,6 +113,7 @@ func incSynRcvdCount() bool {
return false
}
+ synRcvdCount.pending.Add(1)
synRcvdCount.value++
return true
@@ -125,6 +127,7 @@ func decSynRcvdCount() {
defer synRcvdCount.Unlock()
synRcvdCount.value--
+ synRcvdCount.pending.Done()
}
// newListenContext creates a new listen context.
@@ -302,7 +305,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
opts := parseSynSegmentOptions(s)
if incSynRcvdCount() {
s.incRef()
- go e.handleSynSegment(ctx, s, &opts) // S/R-FIXME
+ go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
} else {
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN with window scaling because we currently
@@ -391,10 +394,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
return nil
}
if n&notifyDrain != 0 {
- for s := e.segmentQueue.dequeue(); s != nil; s = e.segmentQueue.dequeue() {
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
e.handleListenSegment(ctx, s)
s.decRef()
}
+ synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index afdea2b53..33bf4fc0b 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -453,7 +453,8 @@ func (h *handshake) execute() *tcpip.Error {
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
- for s := h.ep.segmentQueue.dequeue(); s != nil; s = h.ep.segmentQueue.dequeue() {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
err := h.handleSegment(s)
s.decRef()
if err != nil {
@@ -823,15 +824,13 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
defer func() {
// e.mu is expected to be hold upon entering this section.
- e.completeWorkerLocked()
-
if e.snd != nil {
e.snd.resendTimer.cleanup()
}
@@ -840,6 +839,8 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
closeTimer.Stop()
}
+ e.completeWorkerLocked()
+
if e.drainDone != nil {
close(e.drainDone)
}
@@ -850,7 +851,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}()
- if !passive {
+ if handshake {
// This is an active connection, so we must initiate the 3-way
// handshake, and then inform potential waiters about its
// completion.
@@ -960,6 +961,17 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
closeWaker.Assert()
})
}
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(); err != nil {
+ return err
+ }
+ }
+ close(e.drainDone)
+ <-e.undrain
+ }
+
return nil
},
},
@@ -971,6 +983,27 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
s.AddWaker(funcs[i].w, i)
}
+ // The following assertions and notifications are needed for restored
+ // endpoints. Fresh newly created endpoints have empty states and should
+ // not invoke any.
+ e.segmentQueue.mu.Lock()
+ if !e.segmentQueue.list.Empty() {
+ e.newSegmentWaker.Assert()
+ }
+ e.segmentQueue.mu.Unlock()
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ e.mu.RLock()
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+ e.mu.RUnlock()
+
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index cb105b863..8b9a81f6a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -80,7 +80,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
netProto tcpip.NetworkProtocolNumber
- waiterQueue *waiter.Queue
+ waiterQueue *waiter.Queue `state:"wait"`
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -95,8 +95,8 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
rcvClosed bool
rcvBufSize int
rcvBufUsed int
@@ -104,8 +104,8 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
id stack.TransportEndpointID
- state endpointState
- isPortReserved bool `state:"manual"`
+ state endpointState `state:".(endpointState)"`
+ isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
route stack.Route `state:"manual"`
@@ -131,7 +131,7 @@ type endpoint struct {
// workerCleanup specifies if the worker goroutine must perform cleanup
// before exitting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
- workerCleanup bool `state:"zerovalue"`
+ workerCleanup bool
// sendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per
@@ -166,7 +166,7 @@ type endpoint struct {
// segmentQueue is used to hand received segments to the protocol
// goroutine. Segments are queued as long as the queue is not full,
// and dropped when it is.
- segmentQueue segmentQueue `state:"zerovalue"`
+ segmentQueue segmentQueue `state:"wait"`
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
@@ -179,7 +179,7 @@ type endpoint struct {
sndBufUsed int
sndClosed bool
sndBufInQueue seqnum.Size
- sndQueue segmentList
+ sndQueue segmentList `state:"wait"`
sndWaker sleep.Waker `state:"manual"`
sndCloseWaker sleep.Waker `state:"manual"`
@@ -201,17 +201,21 @@ type endpoint struct {
// notifyFlags is a bitmask of flags used to indicate to the protocol
// goroutine what it was notified; this is only accessed atomically.
- notifyFlags uint32 `state:"zerovalue"`
+ notifyFlags uint32 `state:"nosave"`
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- acceptedChan chan *endpoint `state:".(endpointChan)"`
+ acceptedChan chan *endpoint `state:"manual"`
+
+ // acceptedEndpoints is only used to save / restore the channel buffer.
+ // FIXME
+ acceptedEndpoints []*endpoint
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
- rcv *receiver
- snd *sender
+ rcv *receiver `state:"wait"`
+ snd *sender `state:"wait"`
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
@@ -224,6 +228,7 @@ type endpoint struct {
probe stack.TCPProbeFunc `state:"nosave"`
// The following are only used to assist the restore run to re-connect.
+ bindAddress tcpip.Address
connectingAddress tcpip.Address
}
@@ -357,6 +362,7 @@ func (e *endpoint) Close() {
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
+ tcpip.AddDanglingEndpoint(e)
if !e.workerRunning {
e.cleanupLocked()
} else {
@@ -376,9 +382,12 @@ func (e *endpoint) cleanupLocked() {
if e.acceptedChan != nil {
close(e.acceptedChan)
for n := range e.acceptedChan {
+ n.mu.Lock()
n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
n.Close()
}
+ e.acceptedChan = nil
}
e.workerCleanup = false
@@ -387,6 +396,7 @@ func (e *endpoint) cleanupLocked() {
}
e.route.Release()
+ tcpip.DeleteDanglingEndpoint(e)
}
// Read reads data from the endpoint.
@@ -801,6 +811,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return e.connect(addr, true, true)
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform handshake.
+// In restore of previously connected endpoints, both ends will be passively
+// created (so no new handshaking is done); for stack-accepted connections not
+// yet accepted by the app, they are restored without running the main goroutine
+// here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -912,9 +932,27 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.boundNICID = nicid
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
- e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-SAFE: will be drained before save.
+ // Connect in the restore phase does not perform handshake. Restore its
+ // connection setting here.
+ if !handshake {
+ e.segmentQueue.mu.Lock()
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for s := l.Front(); s != nil; s = s.Next() {
+ s.id = e.id
+ s.route = r.Clone()
+ e.sndWaker.Assert()
+ }
+ }
+ e.segmentQueue.mu.Unlock()
+ e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+ e.state = stateConnected
+ }
+
+ if run {
+ e.workerRunning = true
+ go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ }
return tcpip.ErrConnectStarted
}
@@ -999,6 +1037,9 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error {
if len(e.acceptedChan) > backlog {
return tcpip.ErrInvalidEndpointState
}
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
origChan := e.acceptedChan
e.acceptedChan = make(chan *endpoint, backlog)
close(origChan)
@@ -1036,7 +1077,7 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error {
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
e.waiterQueue = waiterQueue
e.workerRunning = true
- go e.protocolMainLoop(true) // S/R-FIXME
+ go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
}
// Accept returns a new endpoint if a peer has established a connection
@@ -1077,6 +1118,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret
return tcpip.ErrAlreadyBound
}
+ e.bindAddress = addr.Addr
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index aa4ccea75..43765d425 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -17,8 +17,10 @@ package tcp
import (
"fmt"
"sync"
+ "time"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
)
@@ -32,7 +34,7 @@ func (e *endpoint) drainSegmentLocked() {
e.undrain = make(chan struct{})
e.mu.Unlock()
- e.notificationWaker.Assert()
+ e.notifyProtocolGoroutine(notifyDrain)
<-e.drainDone
e.mu.Lock()
@@ -48,37 +50,103 @@ func (e *endpoint) beforeSave() {
switch e.state {
case stateInitial, stateBound:
- case stateListen:
- if !e.segmentQueue.empty() {
- e.drainSegmentLocked()
+ case stateListen, stateConnecting, stateConnected:
+ if e.state == stateConnected && !e.workerRunning {
+ // The endpoint must be in acceptedChan.
+ break
}
- case stateConnecting:
e.drainSegmentLocked()
- if e.state != stateConnected {
+ if e.state != stateClosed && e.state != stateError {
+ if !e.workerRunning {
+ panic("endpoint has no worker running in listen, connecting, or connected state")
+ }
break
}
fallthrough
- case stateConnected:
- // FIXME
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
- case stateClosed, stateError:
+ case stateError, stateClosed:
+ for e.state == stateError && e.workerRunning {
+ e.mu.Unlock()
+ time.Sleep(100 * time.Millisecond)
+ e.mu.Lock()
+ }
if e.workerRunning {
- panic(fmt.Sprintf("endpoint still has worker running in closed or error state"))
+ panic("endpoint still has worker running in closed or error state")
}
default:
panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
}
+
+ if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
+ panic("endpoint still has waiters upon save")
+ }
+
+ if !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) {
+ panic("endpoint port must and must only be reserved in bound or listen state")
+ }
+
+ if e.acceptedChan != nil {
+ close(e.acceptedChan)
+ e.acceptedEndpoints = make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
+ i := 0
+ for ep := range e.acceptedChan {
+ e.acceptedEndpoints[i] = ep
+ i++
+ }
+ if i != len(e.acceptedEndpoints) {
+ panic("endpoint acceptedChan buffer got consumed by background context")
+ }
+ }
+}
+
+// saveState is invoked by stateify.
+func (e *endpoint) saveState() endpointState {
+ return e.state
+}
+
+// Endpoint loading must be done in the following ordering by their state, to
+// avoid dangling connecting w/o listening peer, and to avoid conflicts in port
+// reservation.
+var connectedLoading sync.WaitGroup
+var listenLoading sync.WaitGroup
+var connectingLoading sync.WaitGroup
+
+// Bound endpoint loading happens last.
+
+// loadState is invoked by stateify.
+func (e *endpoint) loadState(state endpointState) {
+ // This is to ensure that the loading wait groups include all applicable
+ // endpoints before any asynchronous calls to the Wait() methods.
+ switch state {
+ case stateConnected:
+ connectedLoading.Add(1)
+ case stateListen:
+ listenLoading.Add(1)
+ case stateConnecting:
+ connectingLoading.Add(1)
+ }
+ e.state = state
}
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
+ // We load acceptedChan buffer indirectly here. Note that closed
+ // endpoints might not need to allocate the channel.
+ // FIXME
+ if cap(e.acceptedEndpoints) > 0 {
+ e.acceptedChan = make(chan *endpoint, cap(e.acceptedEndpoints))
+ for _, ep := range e.acceptedEndpoints {
+ e.acceptedChan <- ep
+ }
+ e.acceptedEndpoints = nil
+ }
+
e.stack = stack.StackFromEnv
e.segmentQueue.setLimit(2 * e.rcvBufSize)
e.workMu.Init()
state := e.state
switch state {
- case stateInitial, stateBound, stateListen, stateConnecting:
+ case stateInitial, stateBound, stateListen, stateConnecting, stateConnected:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
@@ -90,65 +158,72 @@ func (e *endpoint) afterLoad() {
}
}
- switch state {
- case stateBound, stateListen, stateConnecting:
+ bind := func() {
e.state = stateInitial
- if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
+ if len(e.bindAddress) == 0 {
+ e.bindAddress = e.id.LocalAddress
+ }
+ if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}, nil); err != nil {
panic("endpoint binding failed: " + err.String())
}
}
switch state {
- case stateListen:
- backlog := cap(e.acceptedChan)
- e.acceptedChan = nil
- if err := e.Listen(backlog); err != nil {
- panic("endpoint listening failed: " + err.String())
+ case stateConnected:
+ bind()
+ if len(e.connectingAddress) == 0 {
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ } else {
+ e.connectingAddress = e.id.RemoteAddress
+ }
}
- }
-
- switch state {
- case stateConnecting:
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
+ connectedLoading.Done()
+ case stateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateConnecting:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateClosed, stateError:
+ tcpip.DeleteDanglingEndpoint(e)
}
}
-// saveAcceptedChan is invoked by stateify.
-func (e *endpoint) saveAcceptedChan() endpointChan {
- if e.acceptedChan == nil {
- return endpointChan{}
- }
- close(e.acceptedChan)
- buffer := make([]*endpoint, 0, len(e.acceptedChan))
- for ep := range e.acceptedChan {
- buffer = append(buffer, ep)
- }
- if len(buffer) != cap(buffer) {
- panic("endpoint.acceptedChan buffer got consumed by background context")
- }
- c := cap(e.acceptedChan)
- e.acceptedChan = nil
- return endpointChan{buffer: buffer, cap: c}
-}
-
-// loadAcceptedChan is invoked by stateify.
-func (e *endpoint) loadAcceptedChan(c endpointChan) {
- if c.cap == 0 {
- return
- }
- e.acceptedChan = make(chan *endpoint, c.cap)
- for _, ep := range c.buffer {
- e.acceptedChan <- ep
- }
-}
-
-type endpointChan struct {
- buffer []*endpoint
- cap int
-}
-
// saveLastError is invoked by stateify.
func (e *endpoint) saveLastError() string {
if e.lastError == nil {
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index a90f6661d..40928ba2c 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -39,9 +39,9 @@ const (
type segment struct {
segmentEntry
refCnt int32
- id stack.TransportEndpointID
- route stack.Route `state:"manual"`
- data buffer.VectorisedView
+ id stack.TransportEndpointID `state:"manual"`
+ route stack.Route `state:"manual"`
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
@@ -55,7 +55,7 @@ type segment struct {
// parsedOptions stores the parsed values from the options in the segment.
parsedOptions header.TCPOptions
- options []byte
+ options []byte `state:".([]byte)"`
}
func newSegment(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) *segment {
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 83f554ebd..2ddcf5f10 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -22,8 +22,8 @@ import (
// segmentQueue is a bounded, thread-safe queue of TCP segments.
type segmentQueue struct {
- mu sync.Mutex
- list segmentList
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
limit int
used int
}
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
new file mode 100644
index 000000000..22f0bbf18
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -0,0 +1,51 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// saveData is invoked by stateify.
+func (s *segment) saveData() buffer.VectorisedView {
+ // We cannot save s.data directly as s.data.views may alias to s.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return s.data.Clone(nil)
+}
+
+// loadData is invoked by stateify.
+func (s *segment) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing s.views for data.views.
+ s.data = data
+}
+
+// saveOptions is invoked by stateify.
+func (s *segment) saveOptions() []byte {
+ // We cannot save s.options directly as it may point to s.data's trimmed
+ // tail, which is not allowed by state framework (in-struct pointer).
+ b := make([]byte, 0, cap(s.options))
+ return append(b, s.options...)
+}
+
+// loadOptions is invoked by stateify.
+func (s *segment) loadOptions(options []byte) {
+ // NOTE: We cannot point s.options back into s.data's trimmed tail. But
+ // it is OK as they do not need to aliased. Plus, options is already
+ // allocated so there is no cost here.
+ s.options = options
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index a9892eb64..7dfbf6384 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -38,7 +38,7 @@ type sender struct {
ep *endpoint
// lastSendTime is the timestamp when the last packet was sent.
- lastSendTime time.Time
+ lastSendTime time.Time `state:".(unixTime)"`
// dupAckCount is the number of duplicated acks received. It is used for
// fast retransmit.
@@ -81,7 +81,7 @@ type sender struct {
rttMeasureSeqNum seqnum.Value
// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
- rttMeasureTime time.Time
+ rttMeasureTime time.Time `state:".(unixTime)"`
closed bool
writeNext *segment
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
new file mode 100644
index 000000000..33c8867f4
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -0,0 +1,49 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastSendTime is invoked by stateify.
+func (s *sender) saveLastSendTime() unixTime {
+ return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (s *sender) loadLastSendTime(unix unixTime) {
+ s.lastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (s *sender) saveRttMeasureTime() unixTime {
+ return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (s *sender) loadRttMeasureTime(unix unixTime) {
+ s.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// afterLoad is invoked by stateify.
+func (s *sender) afterLoad() {
+ s.resendTimer.init(&s.resendWaker)
+}