summaryrefslogtreecommitdiffhomepage
path: root/pkg/lisafs/channel.go
blob: 301212e510964ad78f843e6f8abe35f61d7beb32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package lisafs

import (
	"math"
	"runtime"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/fdchannel"
	"gvisor.dev/gvisor/pkg/flipcall"
	"gvisor.dev/gvisor/pkg/log"
)

var (
	chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes())
)

// maxChannels returns the number of channels a client can create.
//
// The server will reject channel creation requests beyond this (per client).
// Note that we don't want the number of channels to be too large, because each
// accounts for a large region of shared memory.
// TODO(gvisor.dev/issue/6313): Tune the number of channels.
func maxChannels() int {
	maxChans := runtime.GOMAXPROCS(0)
	if maxChans < 2 {
		maxChans = 2
	}
	if maxChans > 4 {
		maxChans = 4
	}
	return maxChans
}

// channel implements Communicator and represents the communication endpoint
// for the client and server and is used to perform fast IPC. Apart from
// communicating data, a channel is also capable of donating file descriptors.
type channel struct {
	fdTracker
	dead   bool
	data   flipcall.Endpoint
	fdChan fdchannel.Endpoint
}

var _ Communicator = (*channel)(nil)

// PayloadBuf implements Communicator.PayloadBuf.
func (ch *channel) PayloadBuf(size uint32) []byte {
	return ch.data.Data()[chanHeaderLen : chanHeaderLen+size]
}

// SndRcvMessage implements Communicator.SndRcvMessage.
func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
	// Write header. Requests can not donate FDs.
	ch.marshalHdr(m, 0 /* numFDs */)

	// One-shot communication. RPCs are expected to be quick rather than block.
	rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen)
	if err != nil {
		// This channel is now unusable.
		ch.dead = true
		// Map the transport errors to EIO, but also log the real error.
		log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err)
		return 0, 0, unix.EIO
	}

	return ch.rcvMsg(rcvDataLen)
}

func (ch *channel) shutdown() {
	ch.data.Shutdown()
}

func (ch *channel) destroy() {
	ch.dead = true
	ch.fdChan.Destroy()
	ch.data.Destroy()
}

// createChannel creates a server side channel. It returns a packet window
// descriptor (for the data channel) and an open socket for the FD channel.
func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) {
	c.channelsMu.Lock()
	defer c.channelsMu.Unlock()
	// If c.channels is nil, the connection has closed.
	if c.channels == nil || len(c.channels) >= maxChannels() {
		return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS
	}
	ch := &channel{}

	// Set up data channel.
	desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize))
	if err != nil {
		return nil, flipcall.PacketWindowDescriptor{}, -1, err
	}
	if err := ch.data.Init(flipcall.ServerSide, desc); err != nil {
		return nil, flipcall.PacketWindowDescriptor{}, -1, err
	}

	// Set up FD channel.
	fdSocks, err := fdchannel.NewConnectedSockets()
	if err != nil {
		ch.data.Destroy()
		return nil, flipcall.PacketWindowDescriptor{}, -1, err
	}
	ch.fdChan.Init(fdSocks[0])
	clientFDSock := fdSocks[1]

	c.channels = append(c.channels, ch)
	return ch, desc, clientFDSock, nil
}

// sendFDs sends as many FDs as it can. The failure to send an FD does not
// cause an error and fail the entire RPC. FDs are considered supplementary
// responses that are not critical to the RPC response itself. The failure to
// send the (i)th FD will cause all the following FDs to not be sent as well
// because the order in which FDs are donated is important.
func (ch *channel) sendFDs(fds []int) uint8 {
	numFDs := len(fds)
	if numFDs == 0 {
		return 0
	}

	if numFDs > math.MaxUint8 {
		log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs)
		return 0
	}

	for i, fd := range fds {
		if err := ch.fdChan.SendFD(fd); err != nil {
			log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err)
			return uint8(i)
		}
	}
	return uint8(numFDs)
}

// channelHeader is the header present in front of each message received on
// flipcall endpoint when the protocol version being used is 1.
//
// +marshal
type channelHeader struct {
	message MID
	numFDs  uint8
	_       uint8 // Need to make struct packed.
}

func (ch *channel) marshalHdr(m MID, numFDs uint8) {
	header := &channelHeader{
		message: m,
		numFDs:  numFDs,
	}
	header.MarshalUnsafe(ch.data.Data())
}

func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) {
	if dataLen < chanHeaderLen {
		log.Warningf("received data has size smaller than header length: %d", dataLen)
		return 0, 0, unix.EIO
	}

	// Read header first.
	var header channelHeader
	header.UnmarshalUnsafe(ch.data.Data())

	// Read any FDs.
	for i := 0; i < int(header.numFDs); i++ {
		fd, err := ch.fdChan.RecvFDNonblock()
		if err != nil {
			log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err)
			break
		}
		ch.TrackFD(fd)
	}

	return header.message, dataLen - chanHeaderLen, nil
}