summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/lisafs/BUILD116
-rw-r--r--pkg/lisafs/README.md3
-rw-r--r--pkg/lisafs/channel.go190
-rw-r--r--pkg/lisafs/client.go377
-rw-r--r--pkg/lisafs/communicator.go80
-rw-r--r--pkg/lisafs/connection.go304
-rw-r--r--pkg/lisafs/connection_test.go194
-rw-r--r--pkg/lisafs/fd.go348
-rw-r--r--pkg/lisafs/handlers.go124
-rw-r--r--pkg/lisafs/lisafs.go18
-rw-r--r--pkg/lisafs/message.go258
-rw-r--r--pkg/lisafs/sample_message.go110
-rw-r--r--pkg/lisafs/server.go113
-rw-r--r--pkg/lisafs/sock.go208
-rw-r--r--pkg/lisafs/sock_test.go217
-rw-r--r--pkg/p9/client.go2
16 files changed, 2661 insertions, 1 deletions
diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD
new file mode 100644
index 000000000..9914ed2f5
--- /dev/null
+++ b/pkg/lisafs/BUILD
@@ -0,0 +1,116 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "control_fd_refs",
+ out = "control_fd_refs.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_refs",
+ out = "open_fd_refs.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "OpenFD",
+ },
+)
+
+go_template_instance(
+ name = "control_fd_list",
+ out = "control_fd_list.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*ControlFD",
+ "Linker": "*ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_list",
+ out = "open_fd_list.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*OpenFD",
+ "Linker": "*OpenFD",
+ },
+)
+
+go_library(
+ name = "lisafs",
+ srcs = [
+ "channel.go",
+ "client.go",
+ "communicator.go",
+ "connection.go",
+ "control_fd_list.go",
+ "control_fd_refs.go",
+ "fd.go",
+ "handlers.go",
+ "lisafs.go",
+ "message.go",
+ "open_fd_list.go",
+ "open_fd_refs.go",
+ "sample_message.go",
+ "server.go",
+ "sock.go",
+ ],
+ marshal = True,
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
+ "//pkg/context",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
+ "//pkg/fspath",
+ "//pkg/hostarch",
+ "//pkg/log",
+ "//pkg/marshal/primitive",
+ "//pkg/p9",
+ "//pkg/refsvfs2",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "sock_test",
+ size = "small",
+ srcs = ["sock_test.go"],
+ library = ":lisafs",
+ deps = [
+ "//pkg/marshal",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "connection_test",
+ size = "small",
+ srcs = ["connection_test.go"],
+ deps = [
+ ":lisafs",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md
index 51d0d40e5..6b857321a 100644
--- a/pkg/lisafs/README.md
+++ b/pkg/lisafs/README.md
@@ -1,5 +1,8 @@
# Replacing 9P
+NOTE: LISAFS is **NOT** production ready. There are still some security concerns
+that must be resolved first.
+
## Background
The Linux filesystem model consists of the following key aspects (modulo mounts,
diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go
new file mode 100644
index 000000000..301212e51
--- /dev/null
+++ b/pkg/lisafs/channel.go
@@ -0,0 +1,190 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "runtime"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes())
+)
+
+// maxChannels returns the number of channels a client can create.
+//
+// The server will reject channel creation requests beyond this (per client).
+// Note that we don't want the number of channels to be too large, because each
+// accounts for a large region of shared memory.
+// TODO(gvisor.dev/issue/6313): Tune the number of channels.
+func maxChannels() int {
+ maxChans := runtime.GOMAXPROCS(0)
+ if maxChans < 2 {
+ maxChans = 2
+ }
+ if maxChans > 4 {
+ maxChans = 4
+ }
+ return maxChans
+}
+
+// channel implements Communicator and represents the communication endpoint
+// for the client and server and is used to perform fast IPC. Apart from
+// communicating data, a channel is also capable of donating file descriptors.
+type channel struct {
+ fdTracker
+ dead bool
+ data flipcall.Endpoint
+ fdChan fdchannel.Endpoint
+}
+
+var _ Communicator = (*channel)(nil)
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (ch *channel) PayloadBuf(size uint32) []byte {
+ return ch.data.Data()[chanHeaderLen : chanHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ // Write header. Requests can not donate FDs.
+ ch.marshalHdr(m, 0 /* numFDs */)
+
+ // One-shot communication. RPCs are expected to be quick rather than block.
+ rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen)
+ if err != nil {
+ // This channel is now unusable.
+ ch.dead = true
+ // Map the transport errors to EIO, but also log the real error.
+ log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err)
+ return 0, 0, unix.EIO
+ }
+
+ return ch.rcvMsg(rcvDataLen)
+}
+
+func (ch *channel) shutdown() {
+ ch.data.Shutdown()
+}
+
+func (ch *channel) destroy() {
+ ch.dead = true
+ ch.fdChan.Destroy()
+ ch.data.Destroy()
+}
+
+// createChannel creates a server side channel. It returns a packet window
+// descriptor (for the data channel) and an open socket for the FD channel.
+func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ // If c.channels is nil, the connection has closed.
+ if c.channels == nil || len(c.channels) >= maxChannels() {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS
+ }
+ ch := &channel{}
+
+ // Set up data channel.
+ desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize))
+ if err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ if err := ch.data.Init(flipcall.ServerSide, desc); err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+
+ // Set up FD channel.
+ fdSocks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ ch.data.Destroy()
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ ch.fdChan.Init(fdSocks[0])
+ clientFDSock := fdSocks[1]
+
+ c.channels = append(c.channels, ch)
+ return ch, desc, clientFDSock, nil
+}
+
+// sendFDs sends as many FDs as it can. The failure to send an FD does not
+// cause an error and fail the entire RPC. FDs are considered supplementary
+// responses that are not critical to the RPC response itself. The failure to
+// send the (i)th FD will cause all the following FDs to not be sent as well
+// because the order in which FDs are donated is important.
+func (ch *channel) sendFDs(fds []int) uint8 {
+ numFDs := len(fds)
+ if numFDs == 0 {
+ return 0
+ }
+
+ if numFDs > math.MaxUint8 {
+ log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs)
+ return 0
+ }
+
+ for i, fd := range fds {
+ if err := ch.fdChan.SendFD(fd); err != nil {
+ log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err)
+ return uint8(i)
+ }
+ }
+ return uint8(numFDs)
+}
+
+// channelHeader is the header present in front of each message received on
+// flipcall endpoint when the protocol version being used is 1.
+//
+// +marshal
+type channelHeader struct {
+ message MID
+ numFDs uint8
+ _ uint8 // Need to make struct packed.
+}
+
+func (ch *channel) marshalHdr(m MID, numFDs uint8) {
+ header := &channelHeader{
+ message: m,
+ numFDs: numFDs,
+ }
+ header.MarshalUnsafe(ch.data.Data())
+}
+
+func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) {
+ if dataLen < chanHeaderLen {
+ log.Warningf("received data has size smaller than header length: %d", dataLen)
+ return 0, 0, unix.EIO
+ }
+
+ // Read header first.
+ var header channelHeader
+ header.UnmarshalUnsafe(ch.data.Data())
+
+ // Read any FDs.
+ for i := 0; i < int(header.numFDs); i++ {
+ fd, err := ch.fdChan.RecvFDNonblock()
+ if err != nil {
+ log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err)
+ break
+ }
+ ch.TrackFD(fd)
+ }
+
+ return header.message, dataLen - chanHeaderLen, nil
+}
diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go
new file mode 100644
index 000000000..c99f8c73d
--- /dev/null
+++ b/pkg/lisafs/client.go
@@ -0,0 +1,377 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "math"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Client helps manage a connection to the lisafs server and pass messages
+// efficiently. There is a 1:1 mapping between a Connection and a Client.
+type Client struct {
+ // sockComm is the main socket by which this connections is established.
+ // Communication over the socket is synchronized by sockMu.
+ sockMu sync.Mutex
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels and availableChannels.
+ channelsMu sync.Mutex
+ // channels tracks all the channels.
+ channels []*channel
+ // availableChannels is a LIFO (stack) of channels available to be used.
+ availableChannels []*channel
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // watchdogWg only holds the watchdog goroutine.
+ watchdogWg sync.WaitGroup
+
+ // supported caches information about which messages are supported. It is
+ // indexed by MID. An MID is supported if supported[MID] is true.
+ supported []bool
+
+ // maxMessageSize is the maximum payload length (in bytes) that can be sent.
+ // It is initialized on Mount and is immutable.
+ maxMessageSize uint32
+}
+
+// NewClient creates a new client for communication with the server. It mounts
+// the server and creates channels for fast IPC. NewClient takes ownership over
+// the passed socket. On success, it returns the initialized client along with
+// the root Inode.
+func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) {
+ maxChans := maxChannels()
+ c := &Client{
+ sockComm: newSockComm(sock),
+ channels: make([]*channel, 0, maxChans),
+ availableChannels: make([]*channel, 0, maxChans),
+ maxMessageSize: 1 << 20, // 1 MB for now.
+ }
+
+ // Start a goroutine to check socket health. This goroutine is also
+ // responsible for client cleanup.
+ c.watchdogWg.Add(1)
+ go c.watchdog()
+
+ // Clean everything up if anything fails.
+ cu := cleanup.Make(func() {
+ c.Close()
+ })
+ defer cu.Clean()
+
+ // Mount the server first. Assume Mount is supported so that we can make the
+ // Mount RPC below.
+ c.supported = make([]bool, Mount+1)
+ c.supported[Mount] = true
+ mountMsg := MountReq{
+ MountPath: SizedString(mountPath),
+ }
+ var mountResp MountResp
+ if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil {
+ return nil, nil, err
+ }
+
+ // Initialize client.
+ c.maxMessageSize = uint32(mountResp.MaxMessageSize)
+ var maxSuppMID MID
+ for _, suppMID := range mountResp.SupportedMs {
+ if suppMID > maxSuppMID {
+ maxSuppMID = suppMID
+ }
+ }
+ c.supported = make([]bool, maxSuppMID+1)
+ for _, suppMID := range mountResp.SupportedMs {
+ c.supported[suppMID] = true
+ }
+
+ // Create channels parallely so that channels can be used to create more
+ // channels and costly initialization like flipcall.Endpoint.Connect can
+ // proceed parallely.
+ var channelsWg sync.WaitGroup
+ channelErrs := make([]error, maxChans)
+ for i := 0; i < maxChans; i++ {
+ channelsWg.Add(1)
+ curChanID := i
+ go func() {
+ defer channelsWg.Done()
+ ch, err := c.createChannel()
+ if err != nil {
+ log.Warningf("channel creation failed: %v", err)
+ channelErrs[curChanID] = err
+ return
+ }
+ c.channelsMu.Lock()
+ c.channels = append(c.channels, ch)
+ c.availableChannels = append(c.availableChannels, ch)
+ c.channelsMu.Unlock()
+ }()
+ }
+ channelsWg.Wait()
+
+ for _, channelErr := range channelErrs {
+ // Return the first non-nil channel creation error.
+ if channelErr != nil {
+ return nil, nil, channelErr
+ }
+ }
+ cu.Release()
+
+ return c, &mountResp.Root, nil
+}
+
+func (c *Client) watchdog() {
+ defer c.watchdogWg.Done()
+
+ events := []unix.PollFd{
+ {
+ Fd: int32(c.sockComm.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("lisafs.Client.watch(): %v", err)
+ } else if n != 1 {
+ log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Shutdown all active channels and wait for them to complete.
+ c.shutdownActiveChans()
+ c.activeWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ c.channelsMu.Unlock()
+
+ // Close main socket.
+ c.sockComm.destroy()
+}
+
+func (c *Client) shutdownActiveChans() {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ availableChans := make(map[*channel]bool)
+ for _, ch := range c.availableChannels {
+ availableChans[ch] = true
+ }
+ for _, ch := range c.channels {
+ // A channel that is not available is active.
+ if _, ok := availableChans[ch]; !ok {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.shutdown()
+ }
+ }
+
+ // Prevent channels from becoming available and serving new requests.
+ c.availableChannels = nil
+}
+
+// Close shuts down the main socket and waits for the watchdog to clean up.
+func (c *Client) Close() {
+ // This shutdown has no effect if the watchdog has already fired and closed
+ // the main socket.
+ c.sockComm.shutdown()
+ c.watchdogWg.Wait()
+}
+
+func (c *Client) createChannel() (*channel, error) {
+ var chanResp ChannelResp
+ var fds [2]int
+ if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil {
+ return nil, err
+ }
+ if fds[0] < 0 || fds[1] < 0 {
+ closeFDs(fds[:])
+ return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds)
+ }
+
+ // Lets create the channel.
+ defer closeFDs(fds[:1]) // The data FD is not needed after this.
+ desc := flipcall.PacketWindowDescriptor{
+ FD: fds[0],
+ Offset: chanResp.dataOffset,
+ Length: int(chanResp.dataLength),
+ }
+
+ ch := &channel{}
+ if err := ch.data.Init(flipcall.ClientSide, desc); err != nil {
+ closeFDs(fds[1:])
+ return nil, err
+ }
+ ch.fdChan.Init(fds[1]) // fdChan now owns this FD.
+
+ // Only a connected channel is usable.
+ if err := ch.data.Connect(); err != nil {
+ ch.destroy()
+ return nil, err
+ }
+ return ch, nil
+}
+
+// IsSupported returns true if this connection supports the passed message.
+func (c *Client) IsSupported(m MID) bool {
+ return int(m) < len(c.supported) && c.supported[m]
+}
+
+// SndRcvMessage invokes reqMarshal to marshal the request onto the payload
+// buffer, wakes up the server to process the request, waits for the response
+// and invokes respUnmarshal with the response payload. respFDs is populated
+// with the received FDs, extra fields are set to -1.
+//
+// Note that the function arguments intentionally accept marshal.Marshallable
+// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of
+// directly accepting the marshal.Marshallable interface. Even though just
+// accepting marshal.Marshallable is cleaner, it leads to a heap allocation
+// (even if that interface variable itself does not escape). In other words,
+// implicit conversion to an interface leads to an allocation.
+//
+// Precondition: reqMarshal and respUnmarshal must be non-nil.
+func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error {
+ if !c.IsSupported(m) {
+ return unix.EOPNOTSUPP
+ }
+ if payloadLen > c.maxMessageSize {
+ log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize)
+ return unix.EIO
+ }
+ wantFDs := len(respFDs)
+ if wantFDs > math.MaxUint8 {
+ log.Warningf("want too many FDs: %d", wantFDs)
+ return unix.EINVAL
+ }
+
+ // Acquire a communicator.
+ comm := c.acquireCommunicator()
+ defer c.releaseCommunicator(comm)
+
+ // Marshal the request into comm's payload buffer and make the RPC.
+ reqMarshal(comm.PayloadBuf(payloadLen))
+ respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs))
+
+ // Handle FD donation.
+ rcvFDs := comm.ReleaseFDs()
+ if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 {
+ // releasedFDs is memory owned by comm which can not be returned to caller.
+ // Copy it into the caller's buffer.
+ numFDCopied := copy(respFDs, rcvFDs)
+ if numFDCopied < numRcvFDs {
+ log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs)
+ closeFDs(rcvFDs[numFDCopied:])
+ }
+ if numFDCopied < wantFDs {
+ for i := numFDCopied; i < wantFDs; i++ {
+ respFDs[i] = -1
+ }
+ }
+ }
+
+ // Error cases.
+ if err != nil {
+ closeFDs(respFDs)
+ return err
+ }
+ if respM == Error {
+ closeFDs(respFDs)
+ var resp ErrorResp
+ resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen))
+ return unix.Errno(resp.errno)
+ }
+ if respM != m {
+ closeFDs(respFDs)
+ log.Warningf("sent %d message but got %d in response", m, respM)
+ return unix.EINVAL
+ }
+
+ // Success. The payload must be unmarshalled *before* comm is released.
+ respUnmarshal(comm.PayloadBuf(respPayloadLen))
+ return nil
+}
+
+// Postcondition: releaseCommunicator() must be called on the returned value.
+func (c *Client) acquireCommunicator() Communicator {
+ // Prefer using channel over socket because:
+ // - Channel uses a shared memory region for passing messages. IO from shared
+ // memory is faster and does not involve making a syscall.
+ // - No intermediate buffer allocation needed. With a channel, the message
+ // can be directly pasted into the shared memory region.
+ if ch := c.getChannel(); ch != nil {
+ return ch
+ }
+
+ c.sockMu.Lock()
+ return c.sockComm
+}
+
+// Precondition: comm must have been acquired via acquireCommunicator().
+func (c *Client) releaseCommunicator(comm Communicator) {
+ switch t := comm.(type) {
+ case *sockCommunicator:
+ c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator().
+ case *channel:
+ c.releaseChannel(t)
+ default:
+ panic(fmt.Sprintf("unknown communicator type %T", t))
+ }
+}
+
+// getChannel pops a channel from the available channels stack. The caller must
+// release the channel after use.
+func (c *Client) getChannel() *channel {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ if len(c.availableChannels) == 0 {
+ return nil
+ }
+
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ c.activeWg.Add(1)
+ return ch
+}
+
+// releaseChannel pushes the passed channel onto the available channel stack if
+// reinsert is true.
+func (c *Client) releaseChannel(ch *channel) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ // If availableChannels is nil, then watchdog has fired and the client is
+ // shutting down. So don't make this channel available again.
+ if !ch.dead && c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.activeWg.Done()
+}
diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go
new file mode 100644
index 000000000..ec2035158
--- /dev/null
+++ b/pkg/lisafs/communicator.go
@@ -0,0 +1,80 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import "golang.org/x/sys/unix"
+
+// Communicator is a server side utility which represents exactly how the
+// server is communicating with the client.
+type Communicator interface {
+ // PayloadBuf returns a slice to the payload section of its internal buffer
+ // where the message can be marshalled. The handlers should use this to
+ // populate the payload buffer with the message.
+ //
+ // The payload buffer contents *should* be preserved across calls with
+ // different sizes. Note that this is not a guarantee, because a compromised
+ // owner of a "shared" payload buffer can tamper with its contents anytime,
+ // even when it's not its turn to do so.
+ PayloadBuf(size uint32) []byte
+
+ // SndRcvMessage sends message m. The caller must have populated PayloadBuf()
+ // with payloadLen bytes. The caller expects to receive wantFDs FDs.
+ // Any received FDs must be accessible via ReleaseFDs(). It returns the
+ // response message along with the response payload length.
+ SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error)
+
+ // DonateFD makes fd non-blocking and starts tracking it. The next call to
+ // ReleaseFDs will include fd in the order it was added. Communicator takes
+ // ownership of fd. Server side should call this.
+ DonateFD(fd int) error
+
+ // Track starts tracking fd. The next call to ReleaseFDs will include fd in
+ // the order it was added. Communicator takes ownership of fd. Client side
+ // should use this for accumulating received FDs.
+ TrackFD(fd int)
+
+ // ReleaseFDs returns the accumulated FDs and stops tracking them. The
+ // ownership of the FDs is transferred to the caller.
+ ReleaseFDs() []int
+}
+
+// fdTracker is a partial implementation of Communicator. It can be embedded in
+// Communicator implementations to keep track of FD donations.
+type fdTracker struct {
+ fds []int
+}
+
+// DonateFD implements Communicator.DonateFD.
+func (d *fdTracker) DonateFD(fd int) error {
+ // Make sure the FD is non-blocking.
+ if err := unix.SetNonblock(fd, true); err != nil {
+ unix.Close(fd)
+ return err
+ }
+ d.TrackFD(fd)
+ return nil
+}
+
+// TrackFD implements Communicator.TrackFD.
+func (d *fdTracker) TrackFD(fd int) {
+ d.fds = append(d.fds, fd)
+}
+
+// ReleaseFDs implements Communicator.ReleaseFDs.
+func (d *fdTracker) ReleaseFDs() []int {
+ ret := d.fds
+ d.fds = d.fds[:0]
+ return ret
+}
diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go
new file mode 100644
index 000000000..8dba4805f
--- /dev/null
+++ b/pkg/lisafs/connection.go
@@ -0,0 +1,304 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Connection represents a connection between a mount point in the client and a
+// mount point in the server. It is owned by the server on which it was started
+// and facilitates communication with the client mount.
+//
+// Each connection is set up using a unix domain socket. One end is owned by
+// the server and the other end is owned by the client. The connection may
+// spawn additional comunicational channels for the same mount for increased
+// RPC concurrency.
+type Connection struct {
+ // server is the server on which this connection was created. It is immutably
+ // associated with it for its entire lifetime.
+ server *Server
+
+ // mounted is a one way flag indicating whether this connection has been
+ // mounted correctly and the server is initialized properly.
+ mounted bool
+
+ // readonly indicates if this connection is readonly. All write operations
+ // will fail with EROFS.
+ readonly bool
+
+ // sockComm is the main socket by which this connections is established.
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+ // channels keeps track of all open channels.
+ channels []*channel
+
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // reqGate counts requests that are still being handled.
+ reqGate sync.Gate
+
+ // channelAlloc is used to allocate memory for channels.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ fdsMu sync.RWMutex
+ // fds keeps tracks of open FDs on this server. It is protected by fdsMu.
+ fds map[FDID]genericFD
+ // nextFDID is the next available FDID. It is protected by fdsMu.
+ nextFDID FDID
+}
+
+// CreateConnection initializes a new connection - creating a server if
+// required. The connection must be started separately.
+func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) {
+ c := &Connection{
+ sockComm: newSockComm(sock),
+ server: s,
+ readonly: readonly,
+ channels: make([]*channel, 0, maxChannels()),
+ fds: make(map[FDID]genericFD),
+ nextFDID: InvalidFDID + 1,
+ }
+
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return nil, err
+ }
+ c.channelAlloc = alloc
+ return c, nil
+}
+
+// Server returns the associated server.
+func (c *Connection) Server() *Server {
+ return c.server
+}
+
+// ServerImpl returns the associated server implementation.
+func (c *Connection) ServerImpl() ServerImpl {
+ return c.server.impl
+}
+
+// Run defines the lifecycle of a connection.
+func (c *Connection) Run() {
+ defer c.close()
+
+ // Start handling requests on this connection.
+ for {
+ m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */)
+ if err != nil {
+ log.Debugf("sock read failed, closing connection: %v", err)
+ return
+ }
+
+ respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen)
+ err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs)
+ closeFDs(respFDs)
+ if err != nil {
+ log.Debugf("sock write failed, closing connection: %v", err)
+ return
+ }
+ }
+}
+
+// service starts servicing the passed channel until the channel is shutdown.
+// This is a blocking method and hence must be called in a separate goroutine.
+func (c *Connection) service(ch *channel) error {
+ rcvDataLen, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rcvDataLen > 0 {
+ m, payloadLen, err := ch.rcvMsg(rcvDataLen)
+ if err != nil {
+ return err
+ }
+ respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen)
+ numFDs := ch.sendFDs(respFDs)
+ closeFDs(respFDs)
+
+ ch.marshalHdr(respM, numFDs)
+ rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) {
+ resp := &ErrorResp{errno: uint32(err)}
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return Error, respLen, nil
+}
+
+func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) {
+ if !c.reqGate.Enter() {
+ // c.close() has been called; the connection is shutting down.
+ return c.respondError(comm, unix.ECONNRESET)
+ }
+ defer c.reqGate.Leave()
+
+ if !c.mounted && m != Mount {
+ log.Warningf("connection must first be mounted")
+ return c.respondError(comm, unix.EINVAL)
+ }
+
+ // Check if the message is supported for forward compatibility.
+ if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil {
+ log.Warningf("received request which is not supported by the server, MID = %d", m)
+ return c.respondError(comm, unix.EOPNOTSUPP)
+ }
+
+ // Try handling the request.
+ respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen)
+ fds := comm.ReleaseFDs()
+ if err != nil {
+ closeFDs(fds)
+ return c.respondError(comm, p9.ExtractErrno(err))
+ }
+
+ return m, respPayloadLen, fds
+}
+
+func (c *Connection) close() {
+ // Wait for completion of all inflight requests. This is mostly so that if
+ // a request is stuck, the sandbox supervisor has the opportunity to kill
+ // us with SIGABRT to get a stack dump of the offending handler.
+ c.reqGate.Close()
+
+ // Shutdown and clean up channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.shutdown()
+ }
+ c.activeWg.Wait()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ // This is to prevent additional channels from being created.
+ c.channels = nil
+ c.channelsMu.Unlock()
+
+ // Free the channel memory.
+ if c.channelAlloc != nil {
+ c.channelAlloc.Destroy()
+ }
+
+ // Ensure the connection is closed.
+ c.sockComm.destroy()
+
+ // Cleanup all FDs.
+ c.fdsMu.Lock()
+ for fdid := range c.fds {
+ fd := c.removeFDLocked(fdid)
+ fd.DecRef(nil) // Drop the ref held by c.
+ }
+ c.fdsMu.Unlock()
+}
+
+// The caller gains a ref on the FD on success.
+func (c *Connection) lookupFD(id FDID) (genericFD, error) {
+ c.fdsMu.RLock()
+ defer c.fdsMu.RUnlock()
+
+ fd, ok := c.fds[id]
+ if !ok {
+ return nil, unix.EBADF
+ }
+ fd.IncRef()
+ return fd, nil
+}
+
+// LookupControlFD retrieves the control FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ cfd, ok := fd.(*ControlFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return cfd, nil
+}
+
+// LookupOpenFD retrieves the open FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ ofd, ok := fd.(*OpenFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return ofd, nil
+}
+
+// insertFD inserts the passed fd into the internal datastructure to track FDs.
+// The caller must hold a ref on fd which is transferred to the connection.
+func (c *Connection) insertFD(fd genericFD) FDID {
+ c.fdsMu.Lock()
+ defer c.fdsMu.Unlock()
+
+ res := c.nextFDID
+ c.nextFDID++
+ if c.nextFDID < res {
+ panic("ran out of FDIDs")
+ }
+ c.fds[res] = fd
+ return res
+}
+
+// RemoveFD makes c stop tracking the passed FDID and drops its ref on it.
+func (c *Connection) RemoveFD(id FDID) {
+ c.fdsMu.Lock()
+ fd := c.removeFDLocked(id)
+ c.fdsMu.Unlock()
+ if fd != nil {
+ // Drop the ref held by c. This can take arbitrarily long. So do not hold
+ // c.fdsMu while calling it.
+ fd.DecRef(nil)
+ }
+}
+
+// removeFDLocked makes c stop tracking the passed FDID. Note that the caller
+// must drop ref on the returned fd (preferably without holding c.fdsMu).
+//
+// Precondition: c.fdsMu is locked.
+func (c *Connection) removeFDLocked(id FDID) genericFD {
+ fd := c.fds[id]
+ if fd == nil {
+ log.Warningf("removeFDLocked called on non-existent FDID %d", id)
+ return nil
+ }
+ delete(c.fds, id)
+ return fd
+}
diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go
new file mode 100644
index 000000000..28ba47112
--- /dev/null
+++ b/pkg/lisafs/connection_test.go
@@ -0,0 +1,194 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package connection_test
+
+import (
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ dynamicMsgID = lisafs.Channel + 1
+ versionMsgID = dynamicMsgID + 1
+)
+
+var handlers = [...]lisafs.RPCHandler{
+ lisafs.Error: lisafs.ErrorHandler,
+ lisafs.Mount: lisafs.MountHandler,
+ lisafs.Channel: lisafs.ChannelHandler,
+ dynamicMsgID: dynamicMsgHandler,
+ versionMsgID: versionHandler,
+}
+
+// testServer implements lisafs.ServerImpl.
+type testServer struct {
+ lisafs.Server
+}
+
+var _ lisafs.ServerImpl = (*testServer)(nil)
+
+type testControlFD struct {
+ lisafs.ControlFD
+ lisafs.ControlFDImpl
+}
+
+func (fd *testControlFD) FD() *lisafs.ControlFD {
+ return &fd.ControlFD
+}
+
+// Mount implements lisafs.Mount.
+func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) {
+ return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil
+}
+
+// MaxMessageSize implements lisafs.MaxMessageSize.
+func (s *testServer) MaxMessageSize() uint32 {
+ return lisafs.MaxMessageSize()
+}
+
+// SupportedMessages implements lisafs.ServerImpl.SupportedMessages.
+func (s *testServer) SupportedMessages() []lisafs.MID {
+ return []lisafs.MID{
+ lisafs.Mount,
+ lisafs.Channel,
+ dynamicMsgID,
+ versionMsgID,
+ }
+}
+
+func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) {
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+
+ ts := &testServer{}
+ ts.Server.InitTestOnly(ts, handlers[:])
+ conn, err := ts.CreateConnection(serverSocket, false /* readonly */)
+ if err != nil {
+ t.Fatalf("starting connection failed: %v", err)
+ return
+ }
+ ts.StartConnection(conn)
+
+ c, _, err := lisafs.NewClient(clientSocket, "/")
+ if err != nil {
+ t.Fatalf("client creation failed: %v", err)
+ }
+
+ clientFn(c)
+
+ c.Close() // This should trigger client and server shutdown.
+ ts.Wait()
+}
+
+// TestStartUp tests that the server and client can be started up correctly.
+func TestStartUp(t *testing.T) {
+ runServerClient(t, func(c *lisafs.Client) {
+ if c.IsSupported(lisafs.Error) {
+ t.Errorf("sending error messages should not be supported")
+ }
+ })
+}
+
+func TestUnsupportedMessage(t *testing.T) {
+ unsupportedM := lisafs.MID(len(handlers))
+ runServerClient(t, func(c *lisafs.Client) {
+ if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP {
+ t.Errorf("expected EOPNOTSUPP but got err: %v", err)
+ }
+ })
+}
+
+func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+ var req lisafs.MsgDynamic
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Just echo back the message.
+ respPayloadLen := uint32(req.SizeBytes())
+ req.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// TestStress stress tests sending many messages from various goroutines.
+func TestStress(t *testing.T) {
+ runServerClient(t, func(c *lisafs.Client) {
+ concurrency := 8
+ numMsgPerGoroutine := 5000
+ var clientWg sync.WaitGroup
+ for i := 0; i < concurrency; i++ {
+ clientWg.Add(1)
+ go func() {
+ defer clientWg.Done()
+
+ for j := 0; j < numMsgPerGoroutine; j++ {
+ // Create a massive random message.
+ var req lisafs.MsgDynamic
+ req.Randomize(100)
+
+ var resp lisafs.MsgDynamic
+ if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil {
+ t.Errorf("SndRcvMessage: received unexpected error %v", err)
+ return
+ }
+ if !reflect.DeepEqual(&req, &resp) {
+ t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp)
+ }
+ }
+ }()
+ }
+
+ clientWg.Wait()
+ })
+}
+
+func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+ // To be fair, usually handlers will create their own objects and return a
+ // pointer to those. Might be tempting to reuse above variables, but don't.
+ var rv lisafs.P9Version
+ rv.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Create a new response.
+ sv := lisafs.P9Version{
+ MSize: rv.MSize,
+ Version: "9P2000.L.Google.11",
+ }
+ respPayloadLen := uint32(sv.SizeBytes())
+ sv.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel.
+func BenchmarkSendRecv(b *testing.B) {
+ b.ReportAllocs()
+ sendV := lisafs.P9Version{
+ MSize: 1 << 20,
+ Version: "9P2000.L.Google.12",
+ }
+
+ var recvV lisafs.P9Version
+ runServerClient(b, func(c *lisafs.Client) {
+ for i := 0; i < b.N; i++ {
+ if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil {
+ b.Fatalf("unexpected error occurred: %v", err)
+ }
+ }
+ })
+}
diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go
new file mode 100644
index 000000000..9dd8ba384
--- /dev/null
+++ b/pkg/lisafs/fd.go
@@ -0,0 +1,348 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// FDID (file descriptor identifier) is used to identify FDs on a connection.
+// Each connection has its own FDID namespace.
+//
+// +marshal slice:FDIDSlice
+type FDID uint32
+
+// InvalidFDID represents an invalid FDID.
+const InvalidFDID FDID = 0
+
+// Ok returns true if f is a valid FDID.
+func (f FDID) Ok() bool {
+ return f != InvalidFDID
+}
+
+// genericFD can represent a ControlFD or OpenFD.
+type genericFD interface {
+ refsvfs2.RefCounter
+}
+
+// A ControlFD is the gateway to the backing filesystem tree node. It is an
+// unusual concept. This exists to provide a safe way to do path-based
+// operations on the file. It performs operations that can modify the
+// filesystem tree and synchronizes these operations. See ControlFDImpl for
+// supported operations.
+//
+// It is not an inode, because multiple control FDs are allowed to exist on the
+// same file. It is not a file descriptor because it is not tied to any access
+// mode, i.e. a control FD can change its access mode based on the operation
+// being performed.
+//
+// Reference Model:
+// * When a control FD is created, the connection takes a ref on it which
+// represents the client's ref on the FD.
+// * The client can drop its ref via the Close RPC which will in turn make the
+// connection drop its ref.
+// * Each control FD holds a ref on its parent for its entire life time.
+type ControlFD struct {
+ controlFDRefs
+ controlFDEntry
+
+ // parent is the parent directory FD containing the file this FD represents.
+ // A ControlFD holds a ref on parent for its entire lifetime. If this FD
+ // represents the root, then parent is nil. parent may be a control FD from
+ // another connection (another mount point). parent is protected by the
+ // backing server's rename mutex.
+ parent *ControlFD
+
+ // name is the file path's last component name. If this FD represents the
+ // root directory, then name is the mount path. name is protected by the
+ // backing server's rename mutex.
+ name string
+
+ // children is a linked list of all children control FDs. As per reference
+ // model, all children hold a ref on this FD.
+ // children is protected by childrenMu and server's rename mutex. To have
+ // mutual exclusion, it is sufficient to:
+ // * Hold rename mutex for reading and lock childrenMu. OR
+ // * Or hold rename mutex for writing.
+ childrenMu sync.Mutex
+ children controlFDList
+
+ // openFDs is a linked list of all FDs opened on this FD. As per reference
+ // model, all open FDs hold a ref on this FD.
+ openFDsMu sync.RWMutex
+ openFDs openFDList
+
+ // All the following fields are immutable.
+
+ // id is the unique FD identifier which identifies this FD on its connection.
+ id FDID
+
+ // conn is the backing connection owning this FD.
+ conn *Connection
+
+ // ftype is the file type of the backing inode. ftype.FileType() == ftype.
+ ftype linux.FileMode
+
+ // impl is the control FD implementation which embeds this struct. It
+ // contains all the implementation specific details.
+ impl ControlFDImpl
+}
+
+var _ genericFD = (*ControlFD)(nil)
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+func (fd *ControlFD) DecRef(context.Context) {
+ fd.controlFDRefs.DecRef(func() {
+ if fd.parent != nil {
+ fd.conn.server.RenameMu.RLock()
+ fd.parent.childrenMu.Lock()
+ fd.parent.children.Remove(fd)
+ fd.parent.childrenMu.Unlock()
+ fd.conn.server.RenameMu.RUnlock()
+ fd.parent.DecRef(nil) // Drop the ref on the parent.
+ }
+ fd.impl.Close(fd.conn)
+ })
+}
+
+// DecRefLocked is the same as DecRef except the added precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) DecRefLocked() {
+ fd.controlFDRefs.DecRef(func() {
+ fd.clearParentLocked()
+ fd.impl.Close(fd.conn)
+ })
+}
+
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) clearParentLocked() {
+ if fd.parent == nil {
+ return
+ }
+ fd.parent.childrenMu.Lock()
+ fd.parent.children.Remove(fd)
+ fd.parent.childrenMu.Unlock()
+ fd.parent.DecRefLocked() // Drop the ref on the parent.
+}
+
+// Init must be called before first use of fd. It inserts fd into the
+// filesystem tree.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) {
+ // Initialize fd with 1 ref which is transferred to c via c.insertFD().
+ fd.controlFDRefs.InitRefs()
+ fd.conn = c
+ fd.id = c.insertFD(fd)
+ fd.name = name
+ fd.ftype = mode.FileType()
+ fd.impl = impl
+ fd.setParentLocked(parent)
+}
+
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) setParentLocked(parent *ControlFD) {
+ fd.parent = parent
+ if parent != nil {
+ parent.IncRef() // Hold a ref on parent.
+ parent.childrenMu.Lock()
+ parent.children.PushBack(fd)
+ parent.childrenMu.Unlock()
+ }
+}
+
+// FileType returns the file mode only containing the file type bits.
+func (fd *ControlFD) FileType() linux.FileMode {
+ return fd.ftype
+}
+
+// IsDir indicates whether fd represents a directory.
+func (fd *ControlFD) IsDir() bool {
+ return fd.ftype == unix.S_IFDIR
+}
+
+// IsRegular indicates whether fd represents a regular file.
+func (fd *ControlFD) IsRegular() bool {
+ return fd.ftype == unix.S_IFREG
+}
+
+// IsSymlink indicates whether fd represents a symbolic link.
+func (fd *ControlFD) IsSymlink() bool {
+ return fd.ftype == unix.S_IFLNK
+}
+
+// IsSocket indicates whether fd represents a socket.
+func (fd *ControlFD) IsSocket() bool {
+ return fd.ftype == unix.S_IFSOCK
+}
+
+// NameLocked returns the backing file's last component name.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) NameLocked() string {
+ return fd.name
+}
+
+// ParentLocked returns the parent control FD. Note that parent might be a
+// control FD from another connection on this server. So its ID must not
+// returned on this connection because FDIDs are local to their connection.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) ParentLocked() ControlFDImpl {
+ if fd.parent == nil {
+ return nil
+ }
+ return fd.parent.impl
+}
+
+// ID returns fd's ID.
+func (fd *ControlFD) ID() FDID {
+ return fd.id
+}
+
+// FilePath returns the absolute path of the file fd was opened on. This is
+// expensive and must not be called on hot paths. FilePath acquires the rename
+// mutex for reading so callers should not be holding it.
+func (fd *ControlFD) FilePath() string {
+ // Lock the rename mutex for reading to ensure that the filesystem tree is not
+ // changed while we traverse it upwards.
+ fd.conn.server.RenameMu.RLock()
+ defer fd.conn.server.RenameMu.RUnlock()
+ return fd.FilePathLocked()
+}
+
+// FilePathLocked is the same as FilePath with the additonal precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) FilePathLocked() string {
+ // Walk upwards and prepend name to res.
+ var res fspath.Builder
+ for fd != nil {
+ res.PrependComponent(fd.name)
+ fd = fd.parent
+ }
+ return res.String()
+}
+
+// ForEachOpenFD executes fn on each FD opened on fd.
+func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) {
+ fd.openFDsMu.RLock()
+ defer fd.openFDsMu.RUnlock()
+ for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() {
+ fn(ofd.impl)
+ }
+}
+
+// OpenFD represents an open file descriptor on the protocol. It resonates
+// closely with a Linux file descriptor. Its operations are limited to the
+// file. Its operations are not allowed to modify or traverse the filesystem
+// tree. See OpenFDImpl for the supported operations.
+//
+// Reference Model:
+// * An OpenFD takes a reference on the control FD it was opened on.
+type OpenFD struct {
+ openFDRefs
+ openFDEntry
+
+ // All the following fields are immutable.
+
+ // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref
+ // on controlFD for its entire lifetime.
+ controlFD *ControlFD
+
+ // id is the unique FD identifier which identifies this FD on its connection.
+ id FDID
+
+ // Access mode for this FD.
+ readable bool
+ writable bool
+
+ // impl is the open FD implementation which embeds this struct. It
+ // contains all the implementation specific details.
+ impl OpenFDImpl
+}
+
+var _ genericFD = (*OpenFD)(nil)
+
+// ID returns fd's ID.
+func (fd *OpenFD) ID() FDID {
+ return fd.id
+}
+
+// ControlFD returns the control FD on which this FD was opened.
+func (fd *OpenFD) ControlFD() ControlFDImpl {
+ return fd.controlFD.impl
+}
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+func (fd *OpenFD) DecRef(context.Context) {
+ fd.openFDRefs.DecRef(func() {
+ fd.controlFD.openFDsMu.Lock()
+ fd.controlFD.openFDs.Remove(fd)
+ fd.controlFD.openFDsMu.Unlock()
+ fd.controlFD.DecRef(nil) // Drop the ref on the control FD.
+ fd.impl.Close(fd.controlFD.conn)
+ })
+}
+
+// Init must be called before first use of fd.
+func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) {
+ // Initialize fd with 1 ref which is transferred to c via c.insertFD().
+ fd.openFDRefs.InitRefs()
+ fd.controlFD = cfd
+ fd.id = cfd.conn.insertFD(fd)
+ accessMode := flags & unix.O_ACCMODE
+ fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR
+ fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR
+ fd.impl = impl
+ cfd.IncRef() // Holds a ref on cfd for its lifetime.
+ cfd.openFDsMu.Lock()
+ cfd.openFDs.PushBack(fd)
+ cfd.openFDsMu.Unlock()
+}
+
+// ControlFDImpl contains implementation details for a ControlFD.
+// Implementations of ControlFDImpl should contain their associated ControlFD
+// by value as their first field.
+//
+// The operations that perform path traversal or any modification to the
+// filesystem tree must synchronize those modifications with the server's
+// rename mutex.
+type ControlFDImpl interface {
+ FD() *ControlFD
+ Close(c *Connection)
+}
+
+// OpenFDImpl contains implementation details for a OpenFD. Implementations of
+// OpenFDImpl should contain their associated OpenFD by value as their first
+// field.
+//
+// Since these operations do not perform any path traversal or any modification
+// to the filesystem tree, there is no need to synchronize with rename
+// operations.
+type OpenFDImpl interface {
+ FD() *OpenFD
+ Close(c *Connection)
+}
diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go
new file mode 100644
index 000000000..9b8d8164a
--- /dev/null
+++ b/pkg/lisafs/handlers.go
@@ -0,0 +1,124 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "path"
+ "path/filepath"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// RPCHandler defines a handler that is invoked when the associated message is
+// received. The handler is responsible for:
+//
+// * Unmarshalling the request from the passed payload and interpreting it.
+// * Marshalling the response into the communicator's payload buffer.
+// * Return the number of payload bytes written.
+// * Donate any FDs (if needed) to comm which will in turn donate it to client.
+type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error)
+
+var handlers = [...]RPCHandler{
+ Error: ErrorHandler,
+ Mount: MountHandler,
+ Channel: ChannelHandler,
+}
+
+// ErrorHandler handles Error message.
+func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ // Client should never send Error.
+ return 0, unix.EINVAL
+}
+
+// MountHandler handles the Mount RPC. Note that there can not be concurrent
+// executions of MountHandler on a connection because the connection enforces
+// that Mount is the first message on the connection. Only after the connection
+// has been successfully mounted can other channels be created.
+func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req MountReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ mountPath := path.Clean(string(req.MountPath))
+ if !filepath.IsAbs(mountPath) {
+ log.Warningf("mountPath %q is not absolute", mountPath)
+ return 0, unix.EINVAL
+ }
+
+ if c.mounted {
+ log.Warningf("connection has already been mounted at %q", mountPath)
+ return 0, unix.EBUSY
+ }
+
+ rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath)
+ if err != nil {
+ return 0, err
+ }
+
+ c.server.addMountPoint(rootFD.FD())
+ c.mounted = true
+ resp := MountResp{
+ Root: rootIno,
+ SupportedMs: c.ServerImpl().SupportedMessages(),
+ MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()),
+ }
+ respPayloadLen := uint32(resp.SizeBytes())
+ resp.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// ChannelHandler handles the Channel RPC.
+func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize())
+ if err != nil {
+ return 0, err
+ }
+
+ // Start servicing the channel in a separate goroutine.
+ c.activeWg.Add(1)
+ go func() {
+ if err := c.service(ch); err != nil {
+ // Don't log shutdown error which is expected during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err)
+ }
+ }
+ c.activeWg.Done()
+ }()
+
+ clientDataFD, err := unix.Dup(desc.FD)
+ if err != nil {
+ unix.Close(fdSock)
+ ch.shutdown()
+ return 0, err
+ }
+
+ // Respond to client with successful channel creation message.
+ if err := comm.DonateFD(clientDataFD); err != nil {
+ return 0, err
+ }
+ if err := comm.DonateFD(fdSock); err != nil {
+ return 0, err
+ }
+ resp := ChannelResp{
+ dataOffset: desc.Offset,
+ dataLength: uint64(desc.Length),
+ }
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go
new file mode 100644
index 000000000..4d8e956ab
--- /dev/null
+++ b/pkg/lisafs/lisafs.go
@@ -0,0 +1,18 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lisafs (LInux SAndbox FileSystem) defines the protocol for
+// filesystem RPCs between an untrusted Sandbox (client) and a trusted
+// filesystem server.
+package lisafs
diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go
new file mode 100644
index 000000000..55fd2c0b1
--- /dev/null
+++ b/pkg/lisafs/message.go
@@ -0,0 +1,258 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// Messages have two parts:
+// * A transport header used to decipher received messages.
+// * A byte array referred to as "payload" which contains the actual message.
+//
+// "dataLen" refers to the size of both combined.
+
+// MID (message ID) is used to identify messages to parse from payload.
+//
+// +marshal slice:MIDSlice
+type MID uint16
+
+// These constants are used to identify their corresponding message types.
+const (
+ // Error is only used in responses to pass errors to client.
+ Error MID = 0
+
+ // Mount is used to establish connection between the client and server mount
+ // point. lisafs requires that the client makes a successful Mount RPC before
+ // making other RPCs.
+ Mount MID = 1
+
+ // Channel requests to start a new communicational channel.
+ Channel MID = 2
+)
+
+const (
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MaxMessageSize is the recommended max message size that can be used by
+// connections. Server implementations may choose to use other values.
+func MaxMessageSize() uint32 {
+ // Return HugePageSize - PageSize so that when flipcall packet window is
+ // created with MaxMessageSize() + flipcall header size + channel header
+ // size, HugePageSize is allocated and can be backed by a single huge page
+ // if supported by the underlying memfd.
+ return uint32(hostarch.HugePageSize - os.Getpagesize())
+}
+
+// TODO(gvisor.dev/issue/6450): Once this is resolved:
+// * Update manual implementations and function signatures.
+// * Update RPC handlers and appropriate callers to handle errors correctly.
+// * Update manual implementations to get rid of buffer shifting.
+
+// UID represents a user ID.
+//
+// +marshal
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+//
+// +marshal
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes.
+func NoopMarshal([]byte) {}
+
+// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes.
+func NoopUnmarshal([]byte) {}
+
+// SizedString represents a string in memory. The marshalled string bytes are
+// preceded by a uint32 signifying the string length.
+type SizedString string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SizedString) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + len(*s)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SizedString) MarshalBytes(dst []byte) {
+ strLen := primitive.Uint32(len(*s))
+ strLen.MarshalUnsafe(dst)
+ dst = dst[strLen.SizeBytes():]
+ // Copy without any allocation.
+ copy(dst[:strLen], *s)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SizedString) UnmarshalBytes(src []byte) {
+ var strLen primitive.Uint32
+ strLen.UnmarshalUnsafe(src)
+ src = src[strLen.SizeBytes():]
+ // Take the hit, this leads to an allocation + memcpy. No way around it.
+ *s = SizedString(src[:strLen])
+}
+
+// StringArray represents an array of SizedStrings in memory. The marshalled
+// array data is preceded by a uint32 signifying the array length.
+type StringArray []string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *StringArray) SizeBytes() int {
+ size := (*primitive.Uint32)(nil).SizeBytes()
+ for _, str := range *s {
+ sstr := SizedString(str)
+ size += sstr.SizeBytes()
+ }
+ return size
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *StringArray) MarshalBytes(dst []byte) {
+ arrLen := primitive.Uint32(len(*s))
+ arrLen.MarshalUnsafe(dst)
+ dst = dst[arrLen.SizeBytes():]
+ for _, str := range *s {
+ sstr := SizedString(str)
+ sstr.MarshalBytes(dst)
+ dst = dst[sstr.SizeBytes():]
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *StringArray) UnmarshalBytes(src []byte) {
+ var arrLen primitive.Uint32
+ arrLen.UnmarshalUnsafe(src)
+ src = src[arrLen.SizeBytes():]
+
+ if cap(*s) < int(arrLen) {
+ *s = make([]string, arrLen)
+ } else {
+ *s = (*s)[:arrLen]
+ }
+
+ for i := primitive.Uint32(0); i < arrLen; i++ {
+ var sstr SizedString
+ sstr.UnmarshalBytes(src)
+ src = src[sstr.SizeBytes():]
+ (*s)[i] = string(sstr)
+ }
+}
+
+// Inode represents an inode on the remote filesystem.
+//
+// +marshal slice:InodeSlice
+type Inode struct {
+ ControlFD FDID
+ _ uint32 // Need to make struct packed.
+ Stat linux.Statx
+}
+
+// MountReq represents a Mount request.
+type MountReq struct {
+ MountPath SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountReq) SizeBytes() int {
+ return m.MountPath.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountReq) MarshalBytes(dst []byte) {
+ m.MountPath.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountReq) UnmarshalBytes(src []byte) {
+ m.MountPath.UnmarshalBytes(src)
+}
+
+// MountResp represents a Mount response.
+type MountResp struct {
+ Root Inode
+ // MaxMessageSize is the maximum size of messages communicated between the
+ // client and server in bytes. This includes the communication header.
+ MaxMessageSize primitive.Uint32
+ // SupportedMs holds all the supported messages.
+ SupportedMs []MID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountResp) SizeBytes() int {
+ return m.Root.SizeBytes() +
+ m.MaxMessageSize.SizeBytes() +
+ (*primitive.Uint16)(nil).SizeBytes() +
+ (len(m.SupportedMs) * (*MID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountResp) MarshalBytes(dst []byte) {
+ m.Root.MarshalUnsafe(dst)
+ dst = dst[m.Root.SizeBytes():]
+ m.MaxMessageSize.MarshalUnsafe(dst)
+ dst = dst[m.MaxMessageSize.SizeBytes():]
+ numSupported := primitive.Uint16(len(m.SupportedMs))
+ numSupported.MarshalBytes(dst)
+ dst = dst[numSupported.SizeBytes():]
+ MarshalUnsafeMIDSlice(m.SupportedMs, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountResp) UnmarshalBytes(src []byte) {
+ m.Root.UnmarshalUnsafe(src)
+ src = src[m.Root.SizeBytes():]
+ m.MaxMessageSize.UnmarshalUnsafe(src)
+ src = src[m.MaxMessageSize.SizeBytes():]
+ var numSupported primitive.Uint16
+ numSupported.UnmarshalBytes(src)
+ src = src[numSupported.SizeBytes():]
+ m.SupportedMs = make([]MID, numSupported)
+ UnmarshalUnsafeMIDSlice(m.SupportedMs, src)
+}
+
+// ChannelResp is the response to the create channel request.
+//
+// +marshal
+type ChannelResp struct {
+ dataOffset int64
+ dataLength uint64
+}
+
+// ErrorResp is returned to represent an error while handling a request.
+//
+// +marshal
+type ErrorResp struct {
+ errno uint32
+}
diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go
new file mode 100644
index 000000000..3868dfa08
--- /dev/null
+++ b/pkg/lisafs/sample_message.go
@@ -0,0 +1,110 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// MsgSimple is a sample packed struct which can be used to test message passing.
+//
+// +marshal slice:Msg1Slice
+type MsgSimple struct {
+ A uint16
+ B uint16
+ C uint32
+ D uint64
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgSimple) Randomize() {
+ m.A = uint16(rand.Uint32())
+ m.B = uint16(rand.Uint32())
+ m.C = rand.Uint32()
+ m.D = rand.Uint64()
+}
+
+// MsgDynamic is a sample dynamic struct which can be used to test message passing.
+//
+// +marshal dynamic
+type MsgDynamic struct {
+ N primitive.Uint32
+ Arr []MsgSimple
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MsgDynamic) SizeBytes() int {
+ return m.N.SizeBytes() +
+ (int(m.N) * (*MsgSimple)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MsgDynamic) MarshalBytes(dst []byte) {
+ m.N.MarshalUnsafe(dst)
+ dst = dst[m.N.SizeBytes():]
+ MarshalUnsafeMsg1Slice(m.Arr, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MsgDynamic) UnmarshalBytes(src []byte) {
+ m.N.UnmarshalUnsafe(src)
+ src = src[m.N.SizeBytes():]
+ m.Arr = make([]MsgSimple, m.N)
+ UnmarshalUnsafeMsg1Slice(m.Arr, src)
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgDynamic) Randomize(arrLen int) {
+ m.N = primitive.Uint32(arrLen)
+ m.Arr = make([]MsgSimple, arrLen)
+ for i := 0; i < arrLen; i++ {
+ m.Arr[i].Randomize()
+ }
+}
+
+// P9Version mimics p9.TVersion and p9.Rversion.
+//
+// +marshal dynamic
+type P9Version struct {
+ MSize primitive.Uint32
+ Version string
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (v *P9Version) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (v *P9Version) MarshalBytes(dst []byte) {
+ v.MSize.MarshalUnsafe(dst)
+ dst = dst[v.MSize.SizeBytes():]
+ versionLen := primitive.Uint16(len(v.Version))
+ versionLen.MarshalUnsafe(dst)
+ dst = dst[versionLen.SizeBytes():]
+ copy(dst, v.Version)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (v *P9Version) UnmarshalBytes(src []byte) {
+ v.MSize.UnmarshalUnsafe(src)
+ src = src[v.MSize.SizeBytes():]
+ var versionLen primitive.Uint16
+ versionLen.UnmarshalUnsafe(src)
+ src = src[versionLen.SizeBytes():]
+ v.Version = string(src[:versionLen])
+}
diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go
new file mode 100644
index 000000000..7515355ec
--- /dev/null
+++ b/pkg/lisafs/server.go
@@ -0,0 +1,113 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Server serves a filesystem tree. Multiple connections on different mount
+// points can be started on a server. The server provides utilities to safely
+// modify the filesystem tree across its connections (mount points). Note that
+// it does not support synchronizing filesystem tree mutations across other
+// servers serving the same filesystem subtree. Server also manages the
+// lifecycle of all connections.
+type Server struct {
+ // connWg counts the number of active connections being tracked.
+ connWg sync.WaitGroup
+
+ // RenameMu synchronizes rename operations within this filesystem tree.
+ RenameMu sync.RWMutex
+
+ // handlers is a list of RPC handlers which can be indexed by the handler's
+ // corresponding MID.
+ handlers []RPCHandler
+
+ // mountPoints keeps track of all the mount points this server serves.
+ mpMu sync.RWMutex
+ mountPoints []*ControlFD
+
+ // impl is the server implementation which embeds this server.
+ impl ServerImpl
+}
+
+// Init must be called before first use of server.
+func (s *Server) Init(impl ServerImpl) {
+ s.impl = impl
+ s.handlers = handlers[:]
+}
+
+// InitTestOnly is the same as Init except that it allows to swap out the
+// underlying handlers with something custom. This is for test only.
+func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) {
+ s.impl = impl
+ s.handlers = handlers
+}
+
+// WithRenameReadLock invokes fn with the server's rename mutex locked for
+// reading. This ensures that no rename operations occur concurrently.
+func (s *Server) WithRenameReadLock(fn func() error) error {
+ s.RenameMu.RLock()
+ err := fn()
+ s.RenameMu.RUnlock()
+ return err
+}
+
+// StartConnection starts the connection on a separate goroutine and tracks it.
+func (s *Server) StartConnection(c *Connection) {
+ s.connWg.Add(1)
+ go func() {
+ c.Run()
+ s.connWg.Done()
+ }()
+}
+
+// Wait waits for all connections started via StartConnection() to terminate.
+func (s *Server) Wait() {
+ s.connWg.Wait()
+}
+
+func (s *Server) addMountPoint(root *ControlFD) {
+ s.mpMu.Lock()
+ defer s.mpMu.Unlock()
+ s.mountPoints = append(s.mountPoints, root)
+}
+
+func (s *Server) forEachMountPoint(fn func(root *ControlFD)) {
+ s.mpMu.RLock()
+ defer s.mpMu.RUnlock()
+ for _, mp := range s.mountPoints {
+ fn(mp)
+ }
+}
+
+// ServerImpl contains the implementation details for a Server.
+// Implementations of ServerImpl should contain their associated Server by
+// value as their first field.
+type ServerImpl interface {
+ // Mount is called when a Mount RPC is made. It mounts the connection at
+ // mountPath.
+ //
+ // Precondition: mountPath == path.Clean(mountPath).
+ Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error)
+
+ // SupportedMessages returns a list of messages that the server
+ // implementation supports.
+ SupportedMessages() []MID
+
+ // MaxMessageSize is the maximum payload length (in bytes) that can be sent
+ // to this server implementation.
+ MaxMessageSize() uint32
+}
diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go
new file mode 100644
index 000000000..88210242f
--- /dev/null
+++ b/pkg/lisafs/sock.go
@@ -0,0 +1,208 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+var (
+ sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
+)
+
+// sockHeader is the header present in front of each message received on a UDS.
+//
+// +marshal
+type sockHeader struct {
+ payloadLen uint32
+ message MID
+ _ uint16 // Need to make struct packed.
+}
+
+// sockCommunicator implements Communicator. This is not thread safe.
+type sockCommunicator struct {
+ fdTracker
+ sock *unet.Socket
+ buf []byte
+}
+
+var _ Communicator = (*sockCommunicator)(nil)
+
+func newSockComm(sock *unet.Socket) *sockCommunicator {
+ return &sockCommunicator{
+ sock: sock,
+ buf: make([]byte, sockHeaderLen),
+ }
+}
+
+func (s *sockCommunicator) FD() int {
+ return s.sock.FD()
+}
+
+func (s *sockCommunicator) destroy() {
+ s.sock.Close()
+}
+
+func (s *sockCommunicator) shutdown() {
+ if err := s.sock.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err)
+ }
+}
+
+func (s *sockCommunicator) resizeBuf(size uint32) {
+ if cap(s.buf) < int(size) {
+ s.buf = s.buf[:cap(s.buf)]
+ s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...)
+ } else {
+ s.buf = s.buf[:size]
+ }
+}
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (s *sockCommunicator) PayloadBuf(size uint32) []byte {
+ s.resizeBuf(sockHeaderLen + size)
+ return s.buf[sockHeaderLen : sockHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil {
+ return 0, 0, err
+ }
+
+ return s.rcvMsg(wantFDs)
+}
+
+// sndPrepopulatedMsg assumes that s.buf has already been populated with
+// `payloadLen` bytes of data.
+func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error {
+ header := sockHeader{payloadLen: payloadLen, message: m}
+ header.MarshalUnsafe(s.buf)
+ dataLen := sockHeaderLen + payloadLen
+ return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds)
+}
+
+// writeTo writes the passed iovec to the UDS and donates any passed FDs.
+func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error {
+ w := sock.Writer(true)
+ if len(fds) > 0 {
+ w.PackFDs(fds...)
+ }
+
+ fdsUnpacked := false
+ for n := 0; n < dataLen; {
+ cur, err := w.WriteVec(iovec)
+ if err != nil {
+ return err
+ }
+ n += cur
+
+ // Fast common path.
+ if n >= dataLen {
+ break
+ }
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(iovec[0]) <= cur-consumed {
+ consumed += len(iovec[0])
+ iovec = iovec[1:]
+ } else {
+ iovec[0] = iovec[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && !fdsUnpacked {
+ // Don't resend any control message.
+ fdsUnpacked = true
+ w.UnpackFDs()
+ }
+ }
+ return nil
+}
+
+// rcvMsg reads the message header and payload from the UDS. It also populates
+// fds with any donated FDs.
+func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) {
+ fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs)
+ if err != nil {
+ return 0, 0, err
+ }
+ for _, fd := range fds {
+ s.TrackFD(fd)
+ }
+
+ var header sockHeader
+ header.UnmarshalUnsafe(s.buf)
+
+ // No payload? We are done.
+ if header.payloadLen == 0 {
+ return header.message, 0, nil
+ }
+
+ if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil {
+ return 0, 0, err
+ }
+
+ return header.message, header.payloadLen, nil
+}
+
+// readFrom fills the passed buffer with data from the socket. It also returns
+// any donated FDs.
+func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) {
+ r := sock.Reader(true)
+ r.EnableFDs(int(wantFDs))
+
+ var (
+ fds []int
+ fdInit bool
+ )
+ n := len(buf)
+ for got := 0; got < n; {
+ cur, err := r.ReadVec([][]byte{buf[got:]})
+
+ // Ignore EOF if cur > 0.
+ if err != nil && (err != io.EOF || cur == 0) {
+ r.CloseFDs()
+ return nil, err
+ }
+
+ if !fdInit && cur > 0 {
+ fds, err = r.ExtractFDs()
+ if err != nil {
+ return nil, err
+ }
+
+ fdInit = true
+ r.EnableFDs(0)
+ }
+
+ got += cur
+ }
+ return fds, nil
+}
+
+func closeFDs(fds []int) {
+ for _, fd := range fds {
+ if fd >= 0 {
+ unix.Close(fd)
+ }
+ }
+}
diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go
new file mode 100644
index 000000000..387f4b7a8
--- /dev/null
+++ b/pkg/lisafs/sock_test.go
@@ -0,0 +1,217 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "bytes"
+ "math/rand"
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) {
+ sock1, sock2, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer sock1.Close()
+ defer sock2.Close()
+
+ var testWg sync.WaitGroup
+ testWg.Add(2)
+
+ go func() {
+ fun1(newSockComm(sock1))
+ testWg.Done()
+ }()
+
+ go func() {
+ fun2(newSockComm(sock2))
+ testWg.Done()
+ }()
+
+ testWg.Wait()
+}
+
+func TestReadWrite(t *testing.T) {
+ // Create random data to send.
+ n := 10000
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ // Scatter that data into two parts using Iovecs while sending.
+ mid := n / 2
+ if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ if _, err := readFrom(comm.sock, gotData, 0); err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+
+ // Make sure we got the right data.
+ if res := bytes.Compare(data, gotData); res != 0 {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+ })
+}
+
+func TestFDDonation(t *testing.T) {
+ n := 10
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ // Try donating FDs to these files.
+ path1 := "/dev/null"
+ path2 := "/dev"
+ path3 := "/dev/random"
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0)
+ defer unix.Close(devNullFD)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path1, err)
+ }
+ devFD, err := unix.Open(path2, unix.O_RDONLY, 0)
+ defer unix.Close(devFD)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path2, err)
+ }
+ devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0)
+ defer unix.Close(devRandomFD)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path2, err)
+ }
+ if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ fds, err := readFrom(comm.sock, gotData, 3)
+ if err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+ defer closeFDs(fds[:])
+
+ if res := bytes.Compare(data, gotData); res != 0 {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+
+ if len(fds) != 3 {
+ t.Fatalf("wanted 3 FD, got %d", len(fds))
+ }
+
+ // Check that the FDs actually point to the correct file.
+ compareFDWithFile(t, fds[0], path1)
+ compareFDWithFile(t, fds[1], path2)
+ compareFDWithFile(t, fds[2], path3)
+ })
+}
+
+func compareFDWithFile(t *testing.T, fd int, path string) {
+ var want unix.Stat_t
+ if err := unix.Stat(path, &want); err != nil {
+ t.Fatalf("stat(%s) failed: %v", path, err)
+ }
+
+ var got unix.Stat_t
+ if err := unix.Fstat(fd, &got); err != nil {
+ t.Fatalf("fstat on donated FD failed: %v", err)
+ }
+
+ if got.Ino != want.Ino || got.Dev != want.Dev {
+ t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got)
+ }
+}
+
+func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error {
+ var payloadLen uint32
+ if msg != nil {
+ payloadLen = uint32(msg.SizeBytes())
+ msg.MarshalUnsafe(comm.PayloadBuf(payloadLen))
+ }
+ return comm.sndPrepopulatedMsg(m, payloadLen, nil)
+}
+
+func TestSndRcvMessage(t *testing.T) {
+ req := &MsgSimple{}
+ req.Randomize()
+ reqM := MID(1)
+
+ // Create a massive random response.
+ var resp MsgDynamic
+ resp.Randomize(100)
+ respM := MID(2)
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, req); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, &resp)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, req)
+ if err := testSndMsg(comm, respM, &resp); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func TestSndRcvMessageNoPayload(t *testing.T) {
+ reqM := MID(1)
+ respM := MID(2)
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, nil)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, nil)
+ if err := testSndMsg(comm, respM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) {
+ gotM, payloadLen, err := comm.rcvMsg(0)
+ if err != nil {
+ t.Fatalf("readMessageFrom failed: %v", err)
+ }
+ if gotM != wantM {
+ t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM)
+ }
+ if wantMsg == nil {
+ if payloadLen != 0 {
+ t.Errorf("no payload expect but got %d bytes", payloadLen)
+ }
+ } else {
+ gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable)
+ gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+ if !reflect.DeepEqual(wantMsg, gotMsg) {
+ t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg)
+ }
+ }
+}
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index eb496f02f..d618da820 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -115,7 +115,7 @@ type Client struct {
// channels is the set of all initialized channels.
channels []*channel
- // availableChannels is a FIFO of inactive channels.
+ // availableChannels is a LIFO of inactive channels.
availableChannels []*channel
// -- below corresponds to sendRecvLegacy --