diff options
-rw-r--r-- | pkg/lisafs/BUILD | 116 | ||||
-rw-r--r-- | pkg/lisafs/README.md | 3 | ||||
-rw-r--r-- | pkg/lisafs/channel.go | 190 | ||||
-rw-r--r-- | pkg/lisafs/client.go | 377 | ||||
-rw-r--r-- | pkg/lisafs/communicator.go | 80 | ||||
-rw-r--r-- | pkg/lisafs/connection.go | 304 | ||||
-rw-r--r-- | pkg/lisafs/connection_test.go | 194 | ||||
-rw-r--r-- | pkg/lisafs/fd.go | 348 | ||||
-rw-r--r-- | pkg/lisafs/handlers.go | 124 | ||||
-rw-r--r-- | pkg/lisafs/lisafs.go | 18 | ||||
-rw-r--r-- | pkg/lisafs/message.go | 258 | ||||
-rw-r--r-- | pkg/lisafs/sample_message.go | 110 | ||||
-rw-r--r-- | pkg/lisafs/server.go | 113 | ||||
-rw-r--r-- | pkg/lisafs/sock.go | 208 | ||||
-rw-r--r-- | pkg/lisafs/sock_test.go | 217 | ||||
-rw-r--r-- | pkg/p9/client.go | 2 |
16 files changed, 2661 insertions, 1 deletions
diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD new file mode 100644 index 000000000..9914ed2f5 --- /dev/null +++ b/pkg/lisafs/BUILD @@ -0,0 +1,116 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_template_instance( + name = "control_fd_refs", + out = "control_fd_refs.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "ControlFD", + }, +) + +go_template_instance( + name = "open_fd_refs", + out = "open_fd_refs.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "OpenFD", + }, +) + +go_template_instance( + name = "control_fd_list", + out = "control_fd_list.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*ControlFD", + "Linker": "*ControlFD", + }, +) + +go_template_instance( + name = "open_fd_list", + out = "open_fd_list.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*OpenFD", + "Linker": "*OpenFD", + }, +) + +go_library( + name = "lisafs", + srcs = [ + "channel.go", + "client.go", + "communicator.go", + "connection.go", + "control_fd_list.go", + "control_fd_refs.go", + "fd.go", + "handlers.go", + "lisafs.go", + "message.go", + "open_fd_list.go", + "open_fd_refs.go", + "sample_message.go", + "server.go", + "sock.go", + ], + marshal = True, + deps = [ + "//pkg/abi/linux", + "//pkg/cleanup", + "//pkg/context", + "//pkg/fdchannel", + "//pkg/flipcall", + "//pkg/fspath", + "//pkg/hostarch", + "//pkg/log", + "//pkg/marshal/primitive", + "//pkg/p9", + "//pkg/refsvfs2", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "sock_test", + size = "small", + srcs = ["sock_test.go"], + library = ":lisafs", + deps = [ + "//pkg/marshal", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "connection_test", + size = "small", + srcs = ["connection_test.go"], + deps = [ + ":lisafs", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md index 51d0d40e5..6b857321a 100644 --- a/pkg/lisafs/README.md +++ b/pkg/lisafs/README.md @@ -1,5 +1,8 @@ # Replacing 9P +NOTE: LISAFS is **NOT** production ready. There are still some security concerns +that must be resolved first. + ## Background The Linux filesystem model consists of the following key aspects (modulo mounts, diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go new file mode 100644 index 000000000..301212e51 --- /dev/null +++ b/pkg/lisafs/channel.go @@ -0,0 +1,190 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "runtime" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +var ( + chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes()) +) + +// maxChannels returns the number of channels a client can create. +// +// The server will reject channel creation requests beyond this (per client). +// Note that we don't want the number of channels to be too large, because each +// accounts for a large region of shared memory. +// TODO(gvisor.dev/issue/6313): Tune the number of channels. +func maxChannels() int { + maxChans := runtime.GOMAXPROCS(0) + if maxChans < 2 { + maxChans = 2 + } + if maxChans > 4 { + maxChans = 4 + } + return maxChans +} + +// channel implements Communicator and represents the communication endpoint +// for the client and server and is used to perform fast IPC. Apart from +// communicating data, a channel is also capable of donating file descriptors. +type channel struct { + fdTracker + dead bool + data flipcall.Endpoint + fdChan fdchannel.Endpoint +} + +var _ Communicator = (*channel)(nil) + +// PayloadBuf implements Communicator.PayloadBuf. +func (ch *channel) PayloadBuf(size uint32) []byte { + return ch.data.Data()[chanHeaderLen : chanHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + // Write header. Requests can not donate FDs. + ch.marshalHdr(m, 0 /* numFDs */) + + // One-shot communication. RPCs are expected to be quick rather than block. + rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen) + if err != nil { + // This channel is now unusable. + ch.dead = true + // Map the transport errors to EIO, but also log the real error. + log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err) + return 0, 0, unix.EIO + } + + return ch.rcvMsg(rcvDataLen) +} + +func (ch *channel) shutdown() { + ch.data.Shutdown() +} + +func (ch *channel) destroy() { + ch.dead = true + ch.fdChan.Destroy() + ch.data.Destroy() +} + +// createChannel creates a server side channel. It returns a packet window +// descriptor (for the data channel) and an open socket for the FD channel. +func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + // If c.channels is nil, the connection has closed. + if c.channels == nil || len(c.channels) >= maxChannels() { + return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS + } + ch := &channel{} + + // Set up data channel. + desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize)) + if err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + if err := ch.data.Init(flipcall.ServerSide, desc); err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + + // Set up FD channel. + fdSocks, err := fdchannel.NewConnectedSockets() + if err != nil { + ch.data.Destroy() + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + ch.fdChan.Init(fdSocks[0]) + clientFDSock := fdSocks[1] + + c.channels = append(c.channels, ch) + return ch, desc, clientFDSock, nil +} + +// sendFDs sends as many FDs as it can. The failure to send an FD does not +// cause an error and fail the entire RPC. FDs are considered supplementary +// responses that are not critical to the RPC response itself. The failure to +// send the (i)th FD will cause all the following FDs to not be sent as well +// because the order in which FDs are donated is important. +func (ch *channel) sendFDs(fds []int) uint8 { + numFDs := len(fds) + if numFDs == 0 { + return 0 + } + + if numFDs > math.MaxUint8 { + log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs) + return 0 + } + + for i, fd := range fds { + if err := ch.fdChan.SendFD(fd); err != nil { + log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err) + return uint8(i) + } + } + return uint8(numFDs) +} + +// channelHeader is the header present in front of each message received on +// flipcall endpoint when the protocol version being used is 1. +// +// +marshal +type channelHeader struct { + message MID + numFDs uint8 + _ uint8 // Need to make struct packed. +} + +func (ch *channel) marshalHdr(m MID, numFDs uint8) { + header := &channelHeader{ + message: m, + numFDs: numFDs, + } + header.MarshalUnsafe(ch.data.Data()) +} + +func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) { + if dataLen < chanHeaderLen { + log.Warningf("received data has size smaller than header length: %d", dataLen) + return 0, 0, unix.EIO + } + + // Read header first. + var header channelHeader + header.UnmarshalUnsafe(ch.data.Data()) + + // Read any FDs. + for i := 0; i < int(header.numFDs); i++ { + fd, err := ch.fdChan.RecvFDNonblock() + if err != nil { + log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err) + break + } + ch.TrackFD(fd) + } + + return header.message, dataLen - chanHeaderLen, nil +} diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go new file mode 100644 index 000000000..c99f8c73d --- /dev/null +++ b/pkg/lisafs/client.go @@ -0,0 +1,377 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "math" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Client helps manage a connection to the lisafs server and pass messages +// efficiently. There is a 1:1 mapping between a Connection and a Client. +type Client struct { + // sockComm is the main socket by which this connections is established. + // Communication over the socket is synchronized by sockMu. + sockMu sync.Mutex + sockComm *sockCommunicator + + // channelsMu protects channels and availableChannels. + channelsMu sync.Mutex + // channels tracks all the channels. + channels []*channel + // availableChannels is a LIFO (stack) of channels available to be used. + availableChannels []*channel + // activeWg represents active channels. + activeWg sync.WaitGroup + + // watchdogWg only holds the watchdog goroutine. + watchdogWg sync.WaitGroup + + // supported caches information about which messages are supported. It is + // indexed by MID. An MID is supported if supported[MID] is true. + supported []bool + + // maxMessageSize is the maximum payload length (in bytes) that can be sent. + // It is initialized on Mount and is immutable. + maxMessageSize uint32 +} + +// NewClient creates a new client for communication with the server. It mounts +// the server and creates channels for fast IPC. NewClient takes ownership over +// the passed socket. On success, it returns the initialized client along with +// the root Inode. +func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) { + maxChans := maxChannels() + c := &Client{ + sockComm: newSockComm(sock), + channels: make([]*channel, 0, maxChans), + availableChannels: make([]*channel, 0, maxChans), + maxMessageSize: 1 << 20, // 1 MB for now. + } + + // Start a goroutine to check socket health. This goroutine is also + // responsible for client cleanup. + c.watchdogWg.Add(1) + go c.watchdog() + + // Clean everything up if anything fails. + cu := cleanup.Make(func() { + c.Close() + }) + defer cu.Clean() + + // Mount the server first. Assume Mount is supported so that we can make the + // Mount RPC below. + c.supported = make([]bool, Mount+1) + c.supported[Mount] = true + mountMsg := MountReq{ + MountPath: SizedString(mountPath), + } + var mountResp MountResp + if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil { + return nil, nil, err + } + + // Initialize client. + c.maxMessageSize = uint32(mountResp.MaxMessageSize) + var maxSuppMID MID + for _, suppMID := range mountResp.SupportedMs { + if suppMID > maxSuppMID { + maxSuppMID = suppMID + } + } + c.supported = make([]bool, maxSuppMID+1) + for _, suppMID := range mountResp.SupportedMs { + c.supported[suppMID] = true + } + + // Create channels parallely so that channels can be used to create more + // channels and costly initialization like flipcall.Endpoint.Connect can + // proceed parallely. + var channelsWg sync.WaitGroup + channelErrs := make([]error, maxChans) + for i := 0; i < maxChans; i++ { + channelsWg.Add(1) + curChanID := i + go func() { + defer channelsWg.Done() + ch, err := c.createChannel() + if err != nil { + log.Warningf("channel creation failed: %v", err) + channelErrs[curChanID] = err + return + } + c.channelsMu.Lock() + c.channels = append(c.channels, ch) + c.availableChannels = append(c.availableChannels, ch) + c.channelsMu.Unlock() + }() + } + channelsWg.Wait() + + for _, channelErr := range channelErrs { + // Return the first non-nil channel creation error. + if channelErr != nil { + return nil, nil, channelErr + } + } + cu.Release() + + return c, &mountResp.Root, nil +} + +func (c *Client) watchdog() { + defer c.watchdogWg.Done() + + events := []unix.PollFd{ + { + Fd: int32(c.sockComm.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + if err != nil { + log.Warningf("lisafs.Client.watch(): %v", err) + } else if n != 1 { + log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Shutdown all active channels and wait for them to complete. + c.shutdownActiveChans() + c.activeWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.destroy() + } + c.channelsMu.Unlock() + + // Close main socket. + c.sockComm.destroy() +} + +func (c *Client) shutdownActiveChans() { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + availableChans := make(map[*channel]bool) + for _, ch := range c.availableChannels { + availableChans[ch] = true + } + for _, ch := range c.channels { + // A channel that is not available is active. + if _, ok := availableChans[ch]; !ok { + log.Debugf("shutting down active channel@%p...", ch) + ch.shutdown() + } + } + + // Prevent channels from becoming available and serving new requests. + c.availableChannels = nil +} + +// Close shuts down the main socket and waits for the watchdog to clean up. +func (c *Client) Close() { + // This shutdown has no effect if the watchdog has already fired and closed + // the main socket. + c.sockComm.shutdown() + c.watchdogWg.Wait() +} + +func (c *Client) createChannel() (*channel, error) { + var chanResp ChannelResp + var fds [2]int + if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil { + return nil, err + } + if fds[0] < 0 || fds[1] < 0 { + closeFDs(fds[:]) + return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds) + } + + // Lets create the channel. + defer closeFDs(fds[:1]) // The data FD is not needed after this. + desc := flipcall.PacketWindowDescriptor{ + FD: fds[0], + Offset: chanResp.dataOffset, + Length: int(chanResp.dataLength), + } + + ch := &channel{} + if err := ch.data.Init(flipcall.ClientSide, desc); err != nil { + closeFDs(fds[1:]) + return nil, err + } + ch.fdChan.Init(fds[1]) // fdChan now owns this FD. + + // Only a connected channel is usable. + if err := ch.data.Connect(); err != nil { + ch.destroy() + return nil, err + } + return ch, nil +} + +// IsSupported returns true if this connection supports the passed message. +func (c *Client) IsSupported(m MID) bool { + return int(m) < len(c.supported) && c.supported[m] +} + +// SndRcvMessage invokes reqMarshal to marshal the request onto the payload +// buffer, wakes up the server to process the request, waits for the response +// and invokes respUnmarshal with the response payload. respFDs is populated +// with the received FDs, extra fields are set to -1. +// +// Note that the function arguments intentionally accept marshal.Marshallable +// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of +// directly accepting the marshal.Marshallable interface. Even though just +// accepting marshal.Marshallable is cleaner, it leads to a heap allocation +// (even if that interface variable itself does not escape). In other words, +// implicit conversion to an interface leads to an allocation. +// +// Precondition: reqMarshal and respUnmarshal must be non-nil. +func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error { + if !c.IsSupported(m) { + return unix.EOPNOTSUPP + } + if payloadLen > c.maxMessageSize { + log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize) + return unix.EIO + } + wantFDs := len(respFDs) + if wantFDs > math.MaxUint8 { + log.Warningf("want too many FDs: %d", wantFDs) + return unix.EINVAL + } + + // Acquire a communicator. + comm := c.acquireCommunicator() + defer c.releaseCommunicator(comm) + + // Marshal the request into comm's payload buffer and make the RPC. + reqMarshal(comm.PayloadBuf(payloadLen)) + respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs)) + + // Handle FD donation. + rcvFDs := comm.ReleaseFDs() + if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 { + // releasedFDs is memory owned by comm which can not be returned to caller. + // Copy it into the caller's buffer. + numFDCopied := copy(respFDs, rcvFDs) + if numFDCopied < numRcvFDs { + log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs) + closeFDs(rcvFDs[numFDCopied:]) + } + if numFDCopied < wantFDs { + for i := numFDCopied; i < wantFDs; i++ { + respFDs[i] = -1 + } + } + } + + // Error cases. + if err != nil { + closeFDs(respFDs) + return err + } + if respM == Error { + closeFDs(respFDs) + var resp ErrorResp + resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen)) + return unix.Errno(resp.errno) + } + if respM != m { + closeFDs(respFDs) + log.Warningf("sent %d message but got %d in response", m, respM) + return unix.EINVAL + } + + // Success. The payload must be unmarshalled *before* comm is released. + respUnmarshal(comm.PayloadBuf(respPayloadLen)) + return nil +} + +// Postcondition: releaseCommunicator() must be called on the returned value. +func (c *Client) acquireCommunicator() Communicator { + // Prefer using channel over socket because: + // - Channel uses a shared memory region for passing messages. IO from shared + // memory is faster and does not involve making a syscall. + // - No intermediate buffer allocation needed. With a channel, the message + // can be directly pasted into the shared memory region. + if ch := c.getChannel(); ch != nil { + return ch + } + + c.sockMu.Lock() + return c.sockComm +} + +// Precondition: comm must have been acquired via acquireCommunicator(). +func (c *Client) releaseCommunicator(comm Communicator) { + switch t := comm.(type) { + case *sockCommunicator: + c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator(). + case *channel: + c.releaseChannel(t) + default: + panic(fmt.Sprintf("unknown communicator type %T", t)) + } +} + +// getChannel pops a channel from the available channels stack. The caller must +// release the channel after use. +func (c *Client) getChannel() *channel { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + if len(c.availableChannels) == 0 { + return nil + } + + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + c.activeWg.Add(1) + return ch +} + +// releaseChannel pushes the passed channel onto the available channel stack if +// reinsert is true. +func (c *Client) releaseChannel(ch *channel) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + // If availableChannels is nil, then watchdog has fired and the client is + // shutting down. So don't make this channel available again. + if !ch.dead && c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.activeWg.Done() +} diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go new file mode 100644 index 000000000..ec2035158 --- /dev/null +++ b/pkg/lisafs/communicator.go @@ -0,0 +1,80 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import "golang.org/x/sys/unix" + +// Communicator is a server side utility which represents exactly how the +// server is communicating with the client. +type Communicator interface { + // PayloadBuf returns a slice to the payload section of its internal buffer + // where the message can be marshalled. The handlers should use this to + // populate the payload buffer with the message. + // + // The payload buffer contents *should* be preserved across calls with + // different sizes. Note that this is not a guarantee, because a compromised + // owner of a "shared" payload buffer can tamper with its contents anytime, + // even when it's not its turn to do so. + PayloadBuf(size uint32) []byte + + // SndRcvMessage sends message m. The caller must have populated PayloadBuf() + // with payloadLen bytes. The caller expects to receive wantFDs FDs. + // Any received FDs must be accessible via ReleaseFDs(). It returns the + // response message along with the response payload length. + SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) + + // DonateFD makes fd non-blocking and starts tracking it. The next call to + // ReleaseFDs will include fd in the order it was added. Communicator takes + // ownership of fd. Server side should call this. + DonateFD(fd int) error + + // Track starts tracking fd. The next call to ReleaseFDs will include fd in + // the order it was added. Communicator takes ownership of fd. Client side + // should use this for accumulating received FDs. + TrackFD(fd int) + + // ReleaseFDs returns the accumulated FDs and stops tracking them. The + // ownership of the FDs is transferred to the caller. + ReleaseFDs() []int +} + +// fdTracker is a partial implementation of Communicator. It can be embedded in +// Communicator implementations to keep track of FD donations. +type fdTracker struct { + fds []int +} + +// DonateFD implements Communicator.DonateFD. +func (d *fdTracker) DonateFD(fd int) error { + // Make sure the FD is non-blocking. + if err := unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return err + } + d.TrackFD(fd) + return nil +} + +// TrackFD implements Communicator.TrackFD. +func (d *fdTracker) TrackFD(fd int) { + d.fds = append(d.fds, fd) +} + +// ReleaseFDs implements Communicator.ReleaseFDs. +func (d *fdTracker) ReleaseFDs() []int { + ret := d.fds + d.fds = d.fds[:0] + return ret +} diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go new file mode 100644 index 000000000..8dba4805f --- /dev/null +++ b/pkg/lisafs/connection.go @@ -0,0 +1,304 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Connection represents a connection between a mount point in the client and a +// mount point in the server. It is owned by the server on which it was started +// and facilitates communication with the client mount. +// +// Each connection is set up using a unix domain socket. One end is owned by +// the server and the other end is owned by the client. The connection may +// spawn additional comunicational channels for the same mount for increased +// RPC concurrency. +type Connection struct { + // server is the server on which this connection was created. It is immutably + // associated with it for its entire lifetime. + server *Server + + // mounted is a one way flag indicating whether this connection has been + // mounted correctly and the server is initialized properly. + mounted bool + + // readonly indicates if this connection is readonly. All write operations + // will fail with EROFS. + readonly bool + + // sockComm is the main socket by which this connections is established. + sockComm *sockCommunicator + + // channelsMu protects channels. + channelsMu sync.Mutex + // channels keeps track of all open channels. + channels []*channel + + // activeWg represents active channels. + activeWg sync.WaitGroup + + // reqGate counts requests that are still being handled. + reqGate sync.Gate + + // channelAlloc is used to allocate memory for channels. + channelAlloc *flipcall.PacketWindowAllocator + + fdsMu sync.RWMutex + // fds keeps tracks of open FDs on this server. It is protected by fdsMu. + fds map[FDID]genericFD + // nextFDID is the next available FDID. It is protected by fdsMu. + nextFDID FDID +} + +// CreateConnection initializes a new connection - creating a server if +// required. The connection must be started separately. +func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) { + c := &Connection{ + sockComm: newSockComm(sock), + server: s, + readonly: readonly, + channels: make([]*channel, 0, maxChannels()), + fds: make(map[FDID]genericFD), + nextFDID: InvalidFDID + 1, + } + + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return nil, err + } + c.channelAlloc = alloc + return c, nil +} + +// Server returns the associated server. +func (c *Connection) Server() *Server { + return c.server +} + +// ServerImpl returns the associated server implementation. +func (c *Connection) ServerImpl() ServerImpl { + return c.server.impl +} + +// Run defines the lifecycle of a connection. +func (c *Connection) Run() { + defer c.close() + + // Start handling requests on this connection. + for { + m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */) + if err != nil { + log.Debugf("sock read failed, closing connection: %v", err) + return + } + + respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen) + err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs) + closeFDs(respFDs) + if err != nil { + log.Debugf("sock write failed, closing connection: %v", err) + return + } + } +} + +// service starts servicing the passed channel until the channel is shutdown. +// This is a blocking method and hence must be called in a separate goroutine. +func (c *Connection) service(ch *channel) error { + rcvDataLen, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rcvDataLen > 0 { + m, payloadLen, err := ch.rcvMsg(rcvDataLen) + if err != nil { + return err + } + respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen) + numFDs := ch.sendFDs(respFDs) + closeFDs(respFDs) + + ch.marshalHdr(respM, numFDs) + rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen) + if err != nil { + return err + } + } + return nil +} + +func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) { + resp := &ErrorResp{errno: uint32(err)} + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return Error, respLen, nil +} + +func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) { + if !c.reqGate.Enter() { + // c.close() has been called; the connection is shutting down. + return c.respondError(comm, unix.ECONNRESET) + } + defer c.reqGate.Leave() + + if !c.mounted && m != Mount { + log.Warningf("connection must first be mounted") + return c.respondError(comm, unix.EINVAL) + } + + // Check if the message is supported for forward compatibility. + if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil { + log.Warningf("received request which is not supported by the server, MID = %d", m) + return c.respondError(comm, unix.EOPNOTSUPP) + } + + // Try handling the request. + respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen) + fds := comm.ReleaseFDs() + if err != nil { + closeFDs(fds) + return c.respondError(comm, p9.ExtractErrno(err)) + } + + return m, respPayloadLen, fds +} + +func (c *Connection) close() { + // Wait for completion of all inflight requests. This is mostly so that if + // a request is stuck, the sandbox supervisor has the opportunity to kill + // us with SIGABRT to get a stack dump of the offending handler. + c.reqGate.Close() + + // Shutdown and clean up channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.shutdown() + } + c.activeWg.Wait() + for _, ch := range c.channels { + ch.destroy() + } + // This is to prevent additional channels from being created. + c.channels = nil + c.channelsMu.Unlock() + + // Free the channel memory. + if c.channelAlloc != nil { + c.channelAlloc.Destroy() + } + + // Ensure the connection is closed. + c.sockComm.destroy() + + // Cleanup all FDs. + c.fdsMu.Lock() + for fdid := range c.fds { + fd := c.removeFDLocked(fdid) + fd.DecRef(nil) // Drop the ref held by c. + } + c.fdsMu.Unlock() +} + +// The caller gains a ref on the FD on success. +func (c *Connection) lookupFD(id FDID) (genericFD, error) { + c.fdsMu.RLock() + defer c.fdsMu.RUnlock() + + fd, ok := c.fds[id] + if !ok { + return nil, unix.EBADF + } + fd.IncRef() + return fd, nil +} + +// LookupControlFD retrieves the control FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + cfd, ok := fd.(*ControlFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return cfd, nil +} + +// LookupOpenFD retrieves the open FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + ofd, ok := fd.(*OpenFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return ofd, nil +} + +// insertFD inserts the passed fd into the internal datastructure to track FDs. +// The caller must hold a ref on fd which is transferred to the connection. +func (c *Connection) insertFD(fd genericFD) FDID { + c.fdsMu.Lock() + defer c.fdsMu.Unlock() + + res := c.nextFDID + c.nextFDID++ + if c.nextFDID < res { + panic("ran out of FDIDs") + } + c.fds[res] = fd + return res +} + +// RemoveFD makes c stop tracking the passed FDID and drops its ref on it. +func (c *Connection) RemoveFD(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.DecRef(nil) + } +} + +// removeFDLocked makes c stop tracking the passed FDID. Note that the caller +// must drop ref on the returned fd (preferably without holding c.fdsMu). +// +// Precondition: c.fdsMu is locked. +func (c *Connection) removeFDLocked(id FDID) genericFD { + fd := c.fds[id] + if fd == nil { + log.Warningf("removeFDLocked called on non-existent FDID %d", id) + return nil + } + delete(c.fds, id) + return fd +} diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go new file mode 100644 index 000000000..28ba47112 --- /dev/null +++ b/pkg/lisafs/connection_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connection_test + +import ( + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + dynamicMsgID = lisafs.Channel + 1 + versionMsgID = dynamicMsgID + 1 +) + +var handlers = [...]lisafs.RPCHandler{ + lisafs.Error: lisafs.ErrorHandler, + lisafs.Mount: lisafs.MountHandler, + lisafs.Channel: lisafs.ChannelHandler, + dynamicMsgID: dynamicMsgHandler, + versionMsgID: versionHandler, +} + +// testServer implements lisafs.ServerImpl. +type testServer struct { + lisafs.Server +} + +var _ lisafs.ServerImpl = (*testServer)(nil) + +type testControlFD struct { + lisafs.ControlFD + lisafs.ControlFDImpl +} + +func (fd *testControlFD) FD() *lisafs.ControlFD { + return &fd.ControlFD +} + +// Mount implements lisafs.Mount. +func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) { + return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil +} + +// MaxMessageSize implements lisafs.MaxMessageSize. +func (s *testServer) MaxMessageSize() uint32 { + return lisafs.MaxMessageSize() +} + +// SupportedMessages implements lisafs.ServerImpl.SupportedMessages. +func (s *testServer) SupportedMessages() []lisafs.MID { + return []lisafs.MID{ + lisafs.Mount, + lisafs.Channel, + dynamicMsgID, + versionMsgID, + } +} + +func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) { + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + ts := &testServer{} + ts.Server.InitTestOnly(ts, handlers[:]) + conn, err := ts.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + ts.StartConnection(conn) + + c, _, err := lisafs.NewClient(clientSocket, "/") + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + clientFn(c) + + c.Close() // This should trigger client and server shutdown. + ts.Wait() +} + +// TestStartUp tests that the server and client can be started up correctly. +func TestStartUp(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + if c.IsSupported(lisafs.Error) { + t.Errorf("sending error messages should not be supported") + } + }) +} + +func TestUnsupportedMessage(t *testing.T) { + unsupportedM := lisafs.MID(len(handlers)) + runServerClient(t, func(c *lisafs.Client) { + if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP { + t.Errorf("expected EOPNOTSUPP but got err: %v", err) + } + }) +} + +func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + var req lisafs.MsgDynamic + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Just echo back the message. + respPayloadLen := uint32(req.SizeBytes()) + req.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// TestStress stress tests sending many messages from various goroutines. +func TestStress(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + concurrency := 8 + numMsgPerGoroutine := 5000 + var clientWg sync.WaitGroup + for i := 0; i < concurrency; i++ { + clientWg.Add(1) + go func() { + defer clientWg.Done() + + for j := 0; j < numMsgPerGoroutine; j++ { + // Create a massive random message. + var req lisafs.MsgDynamic + req.Randomize(100) + + var resp lisafs.MsgDynamic + if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil { + t.Errorf("SndRcvMessage: received unexpected error %v", err) + return + } + if !reflect.DeepEqual(&req, &resp) { + t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp) + } + } + }() + } + + clientWg.Wait() + }) +} + +func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + // To be fair, usually handlers will create their own objects and return a + // pointer to those. Might be tempting to reuse above variables, but don't. + var rv lisafs.P9Version + rv.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Create a new response. + sv := lisafs.P9Version{ + MSize: rv.MSize, + Version: "9P2000.L.Google.11", + } + respPayloadLen := uint32(sv.SizeBytes()) + sv.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel. +func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + sendV := lisafs.P9Version{ + MSize: 1 << 20, + Version: "9P2000.L.Google.12", + } + + var recvV lisafs.P9Version + runServerClient(b, func(c *lisafs.Client) { + for i := 0; i < b.N; i++ { + if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil { + b.Fatalf("unexpected error occurred: %v", err) + } + } + }) +} diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go new file mode 100644 index 000000000..9dd8ba384 --- /dev/null +++ b/pkg/lisafs/fd.go @@ -0,0 +1,348 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/sync" +) + +// FDID (file descriptor identifier) is used to identify FDs on a connection. +// Each connection has its own FDID namespace. +// +// +marshal slice:FDIDSlice +type FDID uint32 + +// InvalidFDID represents an invalid FDID. +const InvalidFDID FDID = 0 + +// Ok returns true if f is a valid FDID. +func (f FDID) Ok() bool { + return f != InvalidFDID +} + +// genericFD can represent a ControlFD or OpenFD. +type genericFD interface { + refsvfs2.RefCounter +} + +// A ControlFD is the gateway to the backing filesystem tree node. It is an +// unusual concept. This exists to provide a safe way to do path-based +// operations on the file. It performs operations that can modify the +// filesystem tree and synchronizes these operations. See ControlFDImpl for +// supported operations. +// +// It is not an inode, because multiple control FDs are allowed to exist on the +// same file. It is not a file descriptor because it is not tied to any access +// mode, i.e. a control FD can change its access mode based on the operation +// being performed. +// +// Reference Model: +// * When a control FD is created, the connection takes a ref on it which +// represents the client's ref on the FD. +// * The client can drop its ref via the Close RPC which will in turn make the +// connection drop its ref. +// * Each control FD holds a ref on its parent for its entire life time. +type ControlFD struct { + controlFDRefs + controlFDEntry + + // parent is the parent directory FD containing the file this FD represents. + // A ControlFD holds a ref on parent for its entire lifetime. If this FD + // represents the root, then parent is nil. parent may be a control FD from + // another connection (another mount point). parent is protected by the + // backing server's rename mutex. + parent *ControlFD + + // name is the file path's last component name. If this FD represents the + // root directory, then name is the mount path. name is protected by the + // backing server's rename mutex. + name string + + // children is a linked list of all children control FDs. As per reference + // model, all children hold a ref on this FD. + // children is protected by childrenMu and server's rename mutex. To have + // mutual exclusion, it is sufficient to: + // * Hold rename mutex for reading and lock childrenMu. OR + // * Or hold rename mutex for writing. + childrenMu sync.Mutex + children controlFDList + + // openFDs is a linked list of all FDs opened on this FD. As per reference + // model, all open FDs hold a ref on this FD. + openFDsMu sync.RWMutex + openFDs openFDList + + // All the following fields are immutable. + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // conn is the backing connection owning this FD. + conn *Connection + + // ftype is the file type of the backing inode. ftype.FileType() == ftype. + ftype linux.FileMode + + // impl is the control FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl ControlFDImpl +} + +var _ genericFD = (*ControlFD)(nil) + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *ControlFD) DecRef(context.Context) { + fd.controlFDRefs.DecRef(func() { + if fd.parent != nil { + fd.conn.server.RenameMu.RLock() + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.conn.server.RenameMu.RUnlock() + fd.parent.DecRef(nil) // Drop the ref on the parent. + } + fd.impl.Close(fd.conn) + }) +} + +// DecRefLocked is the same as DecRef except the added precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) DecRefLocked() { + fd.controlFDRefs.DecRef(func() { + fd.clearParentLocked() + fd.impl.Close(fd.conn) + }) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) clearParentLocked() { + if fd.parent == nil { + return + } + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.parent.DecRefLocked() // Drop the ref on the parent. +} + +// Init must be called before first use of fd. It inserts fd into the +// filesystem tree. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.controlFDRefs.InitRefs() + fd.conn = c + fd.id = c.insertFD(fd) + fd.name = name + fd.ftype = mode.FileType() + fd.impl = impl + fd.setParentLocked(parent) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) setParentLocked(parent *ControlFD) { + fd.parent = parent + if parent != nil { + parent.IncRef() // Hold a ref on parent. + parent.childrenMu.Lock() + parent.children.PushBack(fd) + parent.childrenMu.Unlock() + } +} + +// FileType returns the file mode only containing the file type bits. +func (fd *ControlFD) FileType() linux.FileMode { + return fd.ftype +} + +// IsDir indicates whether fd represents a directory. +func (fd *ControlFD) IsDir() bool { + return fd.ftype == unix.S_IFDIR +} + +// IsRegular indicates whether fd represents a regular file. +func (fd *ControlFD) IsRegular() bool { + return fd.ftype == unix.S_IFREG +} + +// IsSymlink indicates whether fd represents a symbolic link. +func (fd *ControlFD) IsSymlink() bool { + return fd.ftype == unix.S_IFLNK +} + +// IsSocket indicates whether fd represents a socket. +func (fd *ControlFD) IsSocket() bool { + return fd.ftype == unix.S_IFSOCK +} + +// NameLocked returns the backing file's last component name. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) NameLocked() string { + return fd.name +} + +// ParentLocked returns the parent control FD. Note that parent might be a +// control FD from another connection on this server. So its ID must not +// returned on this connection because FDIDs are local to their connection. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) ParentLocked() ControlFDImpl { + if fd.parent == nil { + return nil + } + return fd.parent.impl +} + +// ID returns fd's ID. +func (fd *ControlFD) ID() FDID { + return fd.id +} + +// FilePath returns the absolute path of the file fd was opened on. This is +// expensive and must not be called on hot paths. FilePath acquires the rename +// mutex for reading so callers should not be holding it. +func (fd *ControlFD) FilePath() string { + // Lock the rename mutex for reading to ensure that the filesystem tree is not + // changed while we traverse it upwards. + fd.conn.server.RenameMu.RLock() + defer fd.conn.server.RenameMu.RUnlock() + return fd.FilePathLocked() +} + +// FilePathLocked is the same as FilePath with the additonal precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) FilePathLocked() string { + // Walk upwards and prepend name to res. + var res fspath.Builder + for fd != nil { + res.PrependComponent(fd.name) + fd = fd.parent + } + return res.String() +} + +// ForEachOpenFD executes fn on each FD opened on fd. +func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) { + fd.openFDsMu.RLock() + defer fd.openFDsMu.RUnlock() + for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() { + fn(ofd.impl) + } +} + +// OpenFD represents an open file descriptor on the protocol. It resonates +// closely with a Linux file descriptor. Its operations are limited to the +// file. Its operations are not allowed to modify or traverse the filesystem +// tree. See OpenFDImpl for the supported operations. +// +// Reference Model: +// * An OpenFD takes a reference on the control FD it was opened on. +type OpenFD struct { + openFDRefs + openFDEntry + + // All the following fields are immutable. + + // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref + // on controlFD for its entire lifetime. + controlFD *ControlFD + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // Access mode for this FD. + readable bool + writable bool + + // impl is the open FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl OpenFDImpl +} + +var _ genericFD = (*OpenFD)(nil) + +// ID returns fd's ID. +func (fd *OpenFD) ID() FDID { + return fd.id +} + +// ControlFD returns the control FD on which this FD was opened. +func (fd *OpenFD) ControlFD() ControlFDImpl { + return fd.controlFD.impl +} + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *OpenFD) DecRef(context.Context) { + fd.openFDRefs.DecRef(func() { + fd.controlFD.openFDsMu.Lock() + fd.controlFD.openFDs.Remove(fd) + fd.controlFD.openFDsMu.Unlock() + fd.controlFD.DecRef(nil) // Drop the ref on the control FD. + fd.impl.Close(fd.controlFD.conn) + }) +} + +// Init must be called before first use of fd. +func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.openFDRefs.InitRefs() + fd.controlFD = cfd + fd.id = cfd.conn.insertFD(fd) + accessMode := flags & unix.O_ACCMODE + fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR + fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR + fd.impl = impl + cfd.IncRef() // Holds a ref on cfd for its lifetime. + cfd.openFDsMu.Lock() + cfd.openFDs.PushBack(fd) + cfd.openFDsMu.Unlock() +} + +// ControlFDImpl contains implementation details for a ControlFD. +// Implementations of ControlFDImpl should contain their associated ControlFD +// by value as their first field. +// +// The operations that perform path traversal or any modification to the +// filesystem tree must synchronize those modifications with the server's +// rename mutex. +type ControlFDImpl interface { + FD() *ControlFD + Close(c *Connection) +} + +// OpenFDImpl contains implementation details for a OpenFD. Implementations of +// OpenFDImpl should contain their associated OpenFD by value as their first +// field. +// +// Since these operations do not perform any path traversal or any modification +// to the filesystem tree, there is no need to synchronize with rename +// operations. +type OpenFDImpl interface { + FD() *OpenFD + Close(c *Connection) +} diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go new file mode 100644 index 000000000..9b8d8164a --- /dev/null +++ b/pkg/lisafs/handlers.go @@ -0,0 +1,124 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "path" + "path/filepath" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// RPCHandler defines a handler that is invoked when the associated message is +// received. The handler is responsible for: +// +// * Unmarshalling the request from the passed payload and interpreting it. +// * Marshalling the response into the communicator's payload buffer. +// * Return the number of payload bytes written. +// * Donate any FDs (if needed) to comm which will in turn donate it to client. +type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) + +var handlers = [...]RPCHandler{ + Error: ErrorHandler, + Mount: MountHandler, + Channel: ChannelHandler, +} + +// ErrorHandler handles Error message. +func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + // Client should never send Error. + return 0, unix.EINVAL +} + +// MountHandler handles the Mount RPC. Note that there can not be concurrent +// executions of MountHandler on a connection because the connection enforces +// that Mount is the first message on the connection. Only after the connection +// has been successfully mounted can other channels be created. +func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req MountReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + mountPath := path.Clean(string(req.MountPath)) + if !filepath.IsAbs(mountPath) { + log.Warningf("mountPath %q is not absolute", mountPath) + return 0, unix.EINVAL + } + + if c.mounted { + log.Warningf("connection has already been mounted at %q", mountPath) + return 0, unix.EBUSY + } + + rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath) + if err != nil { + return 0, err + } + + c.server.addMountPoint(rootFD.FD()) + c.mounted = true + resp := MountResp{ + Root: rootIno, + SupportedMs: c.ServerImpl().SupportedMessages(), + MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()), + } + respPayloadLen := uint32(resp.SizeBytes()) + resp.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// ChannelHandler handles the Channel RPC. +func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize()) + if err != nil { + return 0, err + } + + // Start servicing the channel in a separate goroutine. + c.activeWg.Add(1) + go func() { + if err := c.service(ch); err != nil { + // Don't log shutdown error which is expected during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err) + } + } + c.activeWg.Done() + }() + + clientDataFD, err := unix.Dup(desc.FD) + if err != nil { + unix.Close(fdSock) + ch.shutdown() + return 0, err + } + + // Respond to client with successful channel creation message. + if err := comm.DonateFD(clientDataFD); err != nil { + return 0, err + } + if err := comm.DonateFD(fdSock); err != nil { + return 0, err + } + resp := ChannelResp{ + dataOffset: desc.Offset, + dataLength: uint64(desc.Length), + } + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return respLen, nil +} diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go new file mode 100644 index 000000000..4d8e956ab --- /dev/null +++ b/pkg/lisafs/lisafs.go @@ -0,0 +1,18 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lisafs (LInux SAndbox FileSystem) defines the protocol for +// filesystem RPCs between an untrusted Sandbox (client) and a trusted +// filesystem server. +package lisafs diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go new file mode 100644 index 000000000..55fd2c0b1 --- /dev/null +++ b/pkg/lisafs/message.go @@ -0,0 +1,258 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "os" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// Messages have two parts: +// * A transport header used to decipher received messages. +// * A byte array referred to as "payload" which contains the actual message. +// +// "dataLen" refers to the size of both combined. + +// MID (message ID) is used to identify messages to parse from payload. +// +// +marshal slice:MIDSlice +type MID uint16 + +// These constants are used to identify their corresponding message types. +const ( + // Error is only used in responses to pass errors to client. + Error MID = 0 + + // Mount is used to establish connection between the client and server mount + // point. lisafs requires that the client makes a successful Mount RPC before + // making other RPCs. + Mount MID = 1 + + // Channel requests to start a new communicational channel. + Channel MID = 2 +) + +const ( + // NoUID is a sentinel used to indicate no valid UID. + NoUID UID = math.MaxUint32 + + // NoGID is a sentinel used to indicate no valid GID. + NoGID GID = math.MaxUint32 +) + +// MaxMessageSize is the recommended max message size that can be used by +// connections. Server implementations may choose to use other values. +func MaxMessageSize() uint32 { + // Return HugePageSize - PageSize so that when flipcall packet window is + // created with MaxMessageSize() + flipcall header size + channel header + // size, HugePageSize is allocated and can be backed by a single huge page + // if supported by the underlying memfd. + return uint32(hostarch.HugePageSize - os.Getpagesize()) +} + +// TODO(gvisor.dev/issue/6450): Once this is resolved: +// * Update manual implementations and function signatures. +// * Update RPC handlers and appropriate callers to handle errors correctly. +// * Update manual implementations to get rid of buffer shifting. + +// UID represents a user ID. +// +// +marshal +type UID uint32 + +// Ok returns true if uid is not NoUID. +func (uid UID) Ok() bool { + return uid != NoUID +} + +// GID represents a group ID. +// +// +marshal +type GID uint32 + +// Ok returns true if gid is not NoGID. +func (gid GID) Ok() bool { + return gid != NoGID +} + +// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes. +func NoopMarshal([]byte) {} + +// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes. +func NoopUnmarshal([]byte) {} + +// SizedString represents a string in memory. The marshalled string bytes are +// preceded by a uint32 signifying the string length. +type SizedString string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SizedString) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + len(*s) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SizedString) MarshalBytes(dst []byte) { + strLen := primitive.Uint32(len(*s)) + strLen.MarshalUnsafe(dst) + dst = dst[strLen.SizeBytes():] + // Copy without any allocation. + copy(dst[:strLen], *s) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SizedString) UnmarshalBytes(src []byte) { + var strLen primitive.Uint32 + strLen.UnmarshalUnsafe(src) + src = src[strLen.SizeBytes():] + // Take the hit, this leads to an allocation + memcpy. No way around it. + *s = SizedString(src[:strLen]) +} + +// StringArray represents an array of SizedStrings in memory. The marshalled +// array data is preceded by a uint32 signifying the array length. +type StringArray []string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *StringArray) SizeBytes() int { + size := (*primitive.Uint32)(nil).SizeBytes() + for _, str := range *s { + sstr := SizedString(str) + size += sstr.SizeBytes() + } + return size +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *StringArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*s)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + for _, str := range *s { + sstr := SizedString(str) + sstr.MarshalBytes(dst) + dst = dst[sstr.SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *StringArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + + if cap(*s) < int(arrLen) { + *s = make([]string, arrLen) + } else { + *s = (*s)[:arrLen] + } + + for i := primitive.Uint32(0); i < arrLen; i++ { + var sstr SizedString + sstr.UnmarshalBytes(src) + src = src[sstr.SizeBytes():] + (*s)[i] = string(sstr) + } +} + +// Inode represents an inode on the remote filesystem. +// +// +marshal slice:InodeSlice +type Inode struct { + ControlFD FDID + _ uint32 // Need to make struct packed. + Stat linux.Statx +} + +// MountReq represents a Mount request. +type MountReq struct { + MountPath SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountReq) SizeBytes() int { + return m.MountPath.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountReq) MarshalBytes(dst []byte) { + m.MountPath.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountReq) UnmarshalBytes(src []byte) { + m.MountPath.UnmarshalBytes(src) +} + +// MountResp represents a Mount response. +type MountResp struct { + Root Inode + // MaxMessageSize is the maximum size of messages communicated between the + // client and server in bytes. This includes the communication header. + MaxMessageSize primitive.Uint32 + // SupportedMs holds all the supported messages. + SupportedMs []MID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountResp) SizeBytes() int { + return m.Root.SizeBytes() + + m.MaxMessageSize.SizeBytes() + + (*primitive.Uint16)(nil).SizeBytes() + + (len(m.SupportedMs) * (*MID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountResp) MarshalBytes(dst []byte) { + m.Root.MarshalUnsafe(dst) + dst = dst[m.Root.SizeBytes():] + m.MaxMessageSize.MarshalUnsafe(dst) + dst = dst[m.MaxMessageSize.SizeBytes():] + numSupported := primitive.Uint16(len(m.SupportedMs)) + numSupported.MarshalBytes(dst) + dst = dst[numSupported.SizeBytes():] + MarshalUnsafeMIDSlice(m.SupportedMs, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountResp) UnmarshalBytes(src []byte) { + m.Root.UnmarshalUnsafe(src) + src = src[m.Root.SizeBytes():] + m.MaxMessageSize.UnmarshalUnsafe(src) + src = src[m.MaxMessageSize.SizeBytes():] + var numSupported primitive.Uint16 + numSupported.UnmarshalBytes(src) + src = src[numSupported.SizeBytes():] + m.SupportedMs = make([]MID, numSupported) + UnmarshalUnsafeMIDSlice(m.SupportedMs, src) +} + +// ChannelResp is the response to the create channel request. +// +// +marshal +type ChannelResp struct { + dataOffset int64 + dataLength uint64 +} + +// ErrorResp is returned to represent an error while handling a request. +// +// +marshal +type ErrorResp struct { + errno uint32 +} diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go new file mode 100644 index 000000000..3868dfa08 --- /dev/null +++ b/pkg/lisafs/sample_message.go @@ -0,0 +1,110 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math/rand" + + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// MsgSimple is a sample packed struct which can be used to test message passing. +// +// +marshal slice:Msg1Slice +type MsgSimple struct { + A uint16 + B uint16 + C uint32 + D uint64 +} + +// Randomize randomizes the contents of m. +func (m *MsgSimple) Randomize() { + m.A = uint16(rand.Uint32()) + m.B = uint16(rand.Uint32()) + m.C = rand.Uint32() + m.D = rand.Uint64() +} + +// MsgDynamic is a sample dynamic struct which can be used to test message passing. +// +// +marshal dynamic +type MsgDynamic struct { + N primitive.Uint32 + Arr []MsgSimple +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MsgDynamic) SizeBytes() int { + return m.N.SizeBytes() + + (int(m.N) * (*MsgSimple)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MsgDynamic) MarshalBytes(dst []byte) { + m.N.MarshalUnsafe(dst) + dst = dst[m.N.SizeBytes():] + MarshalUnsafeMsg1Slice(m.Arr, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MsgDynamic) UnmarshalBytes(src []byte) { + m.N.UnmarshalUnsafe(src) + src = src[m.N.SizeBytes():] + m.Arr = make([]MsgSimple, m.N) + UnmarshalUnsafeMsg1Slice(m.Arr, src) +} + +// Randomize randomizes the contents of m. +func (m *MsgDynamic) Randomize(arrLen int) { + m.N = primitive.Uint32(arrLen) + m.Arr = make([]MsgSimple, arrLen) + for i := 0; i < arrLen; i++ { + m.Arr[i].Randomize() + } +} + +// P9Version mimics p9.TVersion and p9.Rversion. +// +// +marshal dynamic +type P9Version struct { + MSize primitive.Uint32 + Version string +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (v *P9Version) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (v *P9Version) MarshalBytes(dst []byte) { + v.MSize.MarshalUnsafe(dst) + dst = dst[v.MSize.SizeBytes():] + versionLen := primitive.Uint16(len(v.Version)) + versionLen.MarshalUnsafe(dst) + dst = dst[versionLen.SizeBytes():] + copy(dst, v.Version) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (v *P9Version) UnmarshalBytes(src []byte) { + v.MSize.UnmarshalUnsafe(src) + src = src[v.MSize.SizeBytes():] + var versionLen primitive.Uint16 + versionLen.UnmarshalUnsafe(src) + src = src[versionLen.SizeBytes():] + v.Version = string(src[:versionLen]) +} diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go new file mode 100644 index 000000000..7515355ec --- /dev/null +++ b/pkg/lisafs/server.go @@ -0,0 +1,113 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// Server serves a filesystem tree. Multiple connections on different mount +// points can be started on a server. The server provides utilities to safely +// modify the filesystem tree across its connections (mount points). Note that +// it does not support synchronizing filesystem tree mutations across other +// servers serving the same filesystem subtree. Server also manages the +// lifecycle of all connections. +type Server struct { + // connWg counts the number of active connections being tracked. + connWg sync.WaitGroup + + // RenameMu synchronizes rename operations within this filesystem tree. + RenameMu sync.RWMutex + + // handlers is a list of RPC handlers which can be indexed by the handler's + // corresponding MID. + handlers []RPCHandler + + // mountPoints keeps track of all the mount points this server serves. + mpMu sync.RWMutex + mountPoints []*ControlFD + + // impl is the server implementation which embeds this server. + impl ServerImpl +} + +// Init must be called before first use of server. +func (s *Server) Init(impl ServerImpl) { + s.impl = impl + s.handlers = handlers[:] +} + +// InitTestOnly is the same as Init except that it allows to swap out the +// underlying handlers with something custom. This is for test only. +func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) { + s.impl = impl + s.handlers = handlers +} + +// WithRenameReadLock invokes fn with the server's rename mutex locked for +// reading. This ensures that no rename operations occur concurrently. +func (s *Server) WithRenameReadLock(fn func() error) error { + s.RenameMu.RLock() + err := fn() + s.RenameMu.RUnlock() + return err +} + +// StartConnection starts the connection on a separate goroutine and tracks it. +func (s *Server) StartConnection(c *Connection) { + s.connWg.Add(1) + go func() { + c.Run() + s.connWg.Done() + }() +} + +// Wait waits for all connections started via StartConnection() to terminate. +func (s *Server) Wait() { + s.connWg.Wait() +} + +func (s *Server) addMountPoint(root *ControlFD) { + s.mpMu.Lock() + defer s.mpMu.Unlock() + s.mountPoints = append(s.mountPoints, root) +} + +func (s *Server) forEachMountPoint(fn func(root *ControlFD)) { + s.mpMu.RLock() + defer s.mpMu.RUnlock() + for _, mp := range s.mountPoints { + fn(mp) + } +} + +// ServerImpl contains the implementation details for a Server. +// Implementations of ServerImpl should contain their associated Server by +// value as their first field. +type ServerImpl interface { + // Mount is called when a Mount RPC is made. It mounts the connection at + // mountPath. + // + // Precondition: mountPath == path.Clean(mountPath). + Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error) + + // SupportedMessages returns a list of messages that the server + // implementation supports. + SupportedMessages() []MID + + // MaxMessageSize is the maximum payload length (in bytes) that can be sent + // to this server implementation. + MaxMessageSize() uint32 +} diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go new file mode 100644 index 000000000..88210242f --- /dev/null +++ b/pkg/lisafs/sock.go @@ -0,0 +1,208 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +var ( + sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes()) +) + +// sockHeader is the header present in front of each message received on a UDS. +// +// +marshal +type sockHeader struct { + payloadLen uint32 + message MID + _ uint16 // Need to make struct packed. +} + +// sockCommunicator implements Communicator. This is not thread safe. +type sockCommunicator struct { + fdTracker + sock *unet.Socket + buf []byte +} + +var _ Communicator = (*sockCommunicator)(nil) + +func newSockComm(sock *unet.Socket) *sockCommunicator { + return &sockCommunicator{ + sock: sock, + buf: make([]byte, sockHeaderLen), + } +} + +func (s *sockCommunicator) FD() int { + return s.sock.FD() +} + +func (s *sockCommunicator) destroy() { + s.sock.Close() +} + +func (s *sockCommunicator) shutdown() { + if err := s.sock.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) + } +} + +func (s *sockCommunicator) resizeBuf(size uint32) { + if cap(s.buf) < int(size) { + s.buf = s.buf[:cap(s.buf)] + s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) + } else { + s.buf = s.buf[:size] + } +} + +// PayloadBuf implements Communicator.PayloadBuf. +func (s *sockCommunicator) PayloadBuf(size uint32) []byte { + s.resizeBuf(sockHeaderLen + size) + return s.buf[sockHeaderLen : sockHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { + return 0, 0, err + } + + return s.rcvMsg(wantFDs) +} + +// sndPrepopulatedMsg assumes that s.buf has already been populated with +// `payloadLen` bytes of data. +func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { + header := sockHeader{payloadLen: payloadLen, message: m} + header.MarshalUnsafe(s.buf) + dataLen := sockHeaderLen + payloadLen + return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) +} + +// writeTo writes the passed iovec to the UDS and donates any passed FDs. +func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { + w := sock.Writer(true) + if len(fds) > 0 { + w.PackFDs(fds...) + } + + fdsUnpacked := false + for n := 0; n < dataLen; { + cur, err := w.WriteVec(iovec) + if err != nil { + return err + } + n += cur + + // Fast common path. + if n >= dataLen { + break + } + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(iovec[0]) <= cur-consumed { + consumed += len(iovec[0]) + iovec = iovec[1:] + } else { + iovec[0] = iovec[0][cur-consumed:] + break + } + } + + if n > 0 && !fdsUnpacked { + // Don't resend any control message. + fdsUnpacked = true + w.UnpackFDs() + } + } + return nil +} + +// rcvMsg reads the message header and payload from the UDS. It also populates +// fds with any donated FDs. +func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { + fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) + if err != nil { + return 0, 0, err + } + for _, fd := range fds { + s.TrackFD(fd) + } + + var header sockHeader + header.UnmarshalUnsafe(s.buf) + + // No payload? We are done. + if header.payloadLen == 0 { + return header.message, 0, nil + } + + if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { + return 0, 0, err + } + + return header.message, header.payloadLen, nil +} + +// readFrom fills the passed buffer with data from the socket. It also returns +// any donated FDs. +func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { + r := sock.Reader(true) + r.EnableFDs(int(wantFDs)) + + var ( + fds []int + fdInit bool + ) + n := len(buf) + for got := 0; got < n; { + cur, err := r.ReadVec([][]byte{buf[got:]}) + + // Ignore EOF if cur > 0. + if err != nil && (err != io.EOF || cur == 0) { + r.CloseFDs() + return nil, err + } + + if !fdInit && cur > 0 { + fds, err = r.ExtractFDs() + if err != nil { + return nil, err + } + + fdInit = true + r.EnableFDs(0) + } + + got += cur + } + return fds, nil +} + +func closeFDs(fds []int) { + for _, fd := range fds { + if fd >= 0 { + unix.Close(fd) + } + } +} diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go new file mode 100644 index 000000000..387f4b7a8 --- /dev/null +++ b/pkg/lisafs/sock_test.go @@ -0,0 +1,217 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) { + sock1, sock2, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer sock1.Close() + defer sock2.Close() + + var testWg sync.WaitGroup + testWg.Add(2) + + go func() { + fun1(newSockComm(sock1)) + testWg.Done() + }() + + go func() { + fun2(newSockComm(sock2)) + testWg.Done() + }() + + testWg.Wait() +} + +func TestReadWrite(t *testing.T) { + // Create random data to send. + n := 10000 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + runSocketTest(t, func(comm *sockCommunicator) { + // Scatter that data into two parts using Iovecs while sending. + mid := n / 2 + if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + if _, err := readFrom(comm.sock, gotData, 0); err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + + // Make sure we got the right data. + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + }) +} + +func TestFDDonation(t *testing.T) { + n := 10 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + // Try donating FDs to these files. + path1 := "/dev/null" + path2 := "/dev" + path3 := "/dev/random" + + runSocketTest(t, func(comm *sockCommunicator) { + devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0) + defer unix.Close(devNullFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path1, err) + } + devFD, err := unix.Open(path2, unix.O_RDONLY, 0) + defer unix.Close(devFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0) + defer unix.Close(devRandomFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + fds, err := readFrom(comm.sock, gotData, 3) + if err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + defer closeFDs(fds[:]) + + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + + if len(fds) != 3 { + t.Fatalf("wanted 3 FD, got %d", len(fds)) + } + + // Check that the FDs actually point to the correct file. + compareFDWithFile(t, fds[0], path1) + compareFDWithFile(t, fds[1], path2) + compareFDWithFile(t, fds[2], path3) + }) +} + +func compareFDWithFile(t *testing.T, fd int, path string) { + var want unix.Stat_t + if err := unix.Stat(path, &want); err != nil { + t.Fatalf("stat(%s) failed: %v", path, err) + } + + var got unix.Stat_t + if err := unix.Fstat(fd, &got); err != nil { + t.Fatalf("fstat on donated FD failed: %v", err) + } + + if got.Ino != want.Ino || got.Dev != want.Dev { + t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got) + } +} + +func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error { + var payloadLen uint32 + if msg != nil { + payloadLen = uint32(msg.SizeBytes()) + msg.MarshalUnsafe(comm.PayloadBuf(payloadLen)) + } + return comm.sndPrepopulatedMsg(m, payloadLen, nil) +} + +func TestSndRcvMessage(t *testing.T) { + req := &MsgSimple{} + req.Randomize() + reqM := MID(1) + + // Create a massive random response. + var resp MsgDynamic + resp.Randomize(100) + respM := MID(2) + + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, req); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, &resp) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, req) + if err := testSndMsg(comm, respM, &resp); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func TestSndRcvMessageNoPayload(t *testing.T) { + reqM := MID(1) + respM := MID(2) + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, nil) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, nil) + if err := testSndMsg(comm, respM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) { + gotM, payloadLen, err := comm.rcvMsg(0) + if err != nil { + t.Fatalf("readMessageFrom failed: %v", err) + } + if gotM != wantM { + t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM) + } + if wantMsg == nil { + if payloadLen != 0 { + t.Errorf("no payload expect but got %d bytes", payloadLen) + } + } else { + gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable) + gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + if !reflect.DeepEqual(wantMsg, gotMsg) { + t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg) + } + } +} diff --git a/pkg/p9/client.go b/pkg/p9/client.go index eb496f02f..d618da820 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -115,7 +115,7 @@ type Client struct { // channels is the set of all initialized channels. channels []*channel - // availableChannels is a FIFO of inactive channels. + // availableChannels is a LIFO of inactive channels. availableChannels []*channel // -- below corresponds to sendRecvLegacy -- |