summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/endpoint_state.go
blob: dbb70ff2106a8aff061bb8d06851be12d23369ec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Copyright 2017 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tcp

import (
	"fmt"

	"gvisor.googlesource.com/gvisor/pkg/tcpip"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
)

// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint
// state.
//
// beforeSave raises it via panic (rather than returning it) when an endpoint
// is in a state that cannot be saved.
type ErrSaveRejection struct {
	// Err describes why the save was rejected. It may be nil.
	Err error
}

// Error returns a sensible description of the save rejection error.
//
// A nil Err is tolerated so that a zero-value ErrSaveRejection can still be
// formatted instead of triggering a nil-pointer panic while reporting a
// failure.
func (e ErrSaveRejection) Error() string {
	const prefix = "save rejected due to unsupported endpoint state"
	if e.Err == nil {
		return prefix
	}
	return prefix + ": " + e.Err.Error()
}

// Unwrap returns the underlying cause of the rejection (possibly nil), so the
// cause can be inspected with errors.Is / errors.As.
func (e ErrSaveRejection) Unwrap() error {
	return e.Err
}

// beforeSave is invoked by stateify.
//
// It caps the segment queue at zero so no further incoming packets are
// queued, then checks that the endpoint's state can be saved:
//   - initial, bound, closed and error endpoints need no extra work;
//   - listening endpoints first wait for any already-queued segments to be
//     drained;
//   - connecting and connected endpoints panic with ErrSaveRejection;
//   - any other state panics with a plain string.
//
// e.mu is read-locked for the state check; the deferred RUnlock also runs if
// one of the panics above fires.
func (e *endpoint) beforeSave() {
	// Stop incoming packets.
	e.segmentQueue.setLimit(0)

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch e.state {
	case stateInitial:
		// Nothing to do.
	case stateBound:
		// Nothing to do.
	case stateListen:
		// Segments may still be queued from before the limit was set to 0.
		// Drop the read lock, install a one-slot drainDone channel, wake the
		// worker via notificationWaker, and block until it signals on
		// drainDone. The read lock is then reacquired so the deferred RUnlock
		// stays balanced.
		// NOTE(review): presumably the woken worker needs e.mu to drain the
		// queue (hence the unlock) and sends on drainDone when finished —
		// confirm against the listen loop.
		if !e.segmentQueue.empty() {
			e.mu.RUnlock()
			e.drainDone = make(chan struct{}, 1)
			e.notificationWaker.Assert()
			<-e.drainDone
			e.mu.RLock()
		}
	case stateConnecting:
		panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
	case stateConnected:
		// FIXME: saving connected endpoints is not supported yet, so the save
		// is rejected.
		panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
	case stateClosed:
		// Nothing to do.
	case stateError:
		// Nothing to do.
	default:
		panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
	}
}

// afterLoad is invoked by stateify.
//
// It reattaches the endpoint to stack.StackFromEnv and restores runtime
// state. Bound and listening endpoints are rewound to stateInitial and then
// re-registered through deferred calls to Bind (and Listen, for listeners).
// Either call panics if it fails.
//
// Because deferred calls run last-in-first-out after the function body, the
// effective order for a listening endpoint is:
//  1. state rewound listen -> bound -> initial;
//  2. send/receive buffer sizes validated;
//  3. segment queue limit restored and workMu re-initialized;
//  4. Bind to the saved local address/port;
//  5. Listen with the saved backlog.
func (e *endpoint) afterLoad() {
	e.stack = stack.StackFromEnv

	if e.state == stateListen {
		e.state = stateBound
		// The backlog is recovered from the restored accept queue's capacity.
		// NOTE(review): acceptedChan is cleared here, so if stateify runs
		// loadAcceptedChan before afterLoad, any endpoints it restored into
		// the queue are discarded — confirm this is intended.
		backlog := cap(e.acceptedChan)
		e.acceptedChan = nil
		// Deferred first, so it runs after the deferred Bind below.
		defer func() {
			if err := e.Listen(backlog); err != nil {
				panic("endpoint listening failed: " + err.String())
			}
		}()
	}

	if e.state == stateBound {
		e.state = stateInitial
		// Rebind to the local address/port recorded in e.id.
		defer func() {
			if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
				panic("endpoint binding failed: " + err.String())
			}
		}()
	}

	if e.state == stateInitial {
		// Only checked when the stack reports the option; a restored buffer
		// size outside [Min, Max] is treated as fatal.
		// NOTE(review): rcvBufSize is validated against the
		// SendBufferSizeOption limits as well — confirm whether a receive
		// buffer option should be used for that check.
		var ss SendBufferSizeOption
		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
			if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
			}
			if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
			}
		}
	}

	// Re-open the segment queue, which beforeSave capped at zero.
	e.segmentQueue.setLimit(2 * e.rcvBufSize)
	e.workMu.Init()
}

// saveAcceptedChan is invoked by stateify.
//
// It converts the accept queue into a serializable endpointChan. This is
// destructive: the channel is closed and drained, and e.acceptedChan is reset
// to nil. A nil queue is saved as the zero endpointChan.
//
// It panics if fewer endpoints are drained than were queued at the moment the
// channel was closed, which means something else received from it
// concurrently.
func (e *endpoint) saveAcceptedChan() endpointChan {
	q := e.acceptedChan
	if q == nil {
		return endpointChan{}
	}
	// A channel's capacity is fixed at creation, so it can be read up front.
	size := cap(q)

	// Close before sampling len so no further sends can land, and so the
	// receive loop below terminates once the queue is empty.
	close(q)
	pending := make([]*endpoint, 0, len(q))
	for {
		ep, ok := <-q
		if !ok {
			break
		}
		pending = append(pending, ep)
	}
	if cap(pending) != len(pending) {
		panic("endpoint.acceptedChan buffer got consumed by background context")
	}

	e.acceptedChan = nil
	return endpointChan{cap: size, buffer: pending}
}

// loadAcceptedChan is invoked by stateify.
//
// It rebuilds the accept queue from a saved endpointChan: a new buffered
// channel of the saved capacity is installed and refilled in the saved order.
// A zero capacity (what saveAcceptedChan records for a nil queue) leaves
// e.acceptedChan untouched.
func (e *endpoint) loadAcceptedChan(c endpointChan) {
	if c.cap != 0 {
		e.acceptedChan = make(chan *endpoint, c.cap)
		// Each send must fit in the channel's buffer; more saved endpoints
		// than c.cap would block here.
		for i := range c.buffer {
			e.acceptedChan <- c.buffer[i]
		}
	}
}

// endpointChan is the saved form of endpoint.acceptedChan, produced by
// saveAcceptedChan and consumed by loadAcceptedChan.
type endpointChan struct {
	buffer []*endpoint // endpoints that were queued, in receive order
	cap    int         // capacity of the saved channel; 0 restores nothing
}