diff options
Diffstat (limited to 'pkg/lisafs')
-rw-r--r-- | pkg/lisafs/BUILD | 117 | ||||
-rw-r--r-- | pkg/lisafs/README.md | 3 | ||||
-rw-r--r-- | pkg/lisafs/channel.go | 190 | ||||
-rw-r--r-- | pkg/lisafs/client.go | 432 | ||||
-rw-r--r-- | pkg/lisafs/client_file.go | 475 | ||||
-rw-r--r-- | pkg/lisafs/communicator.go | 80 | ||||
-rw-r--r-- | pkg/lisafs/connection.go | 320 | ||||
-rw-r--r-- | pkg/lisafs/connection_test.go | 194 | ||||
-rw-r--r-- | pkg/lisafs/fd.go | 374 | ||||
-rw-r--r-- | pkg/lisafs/handlers.go | 768 | ||||
-rw-r--r-- | pkg/lisafs/lisafs.go | 18 | ||||
-rw-r--r-- | pkg/lisafs/message.go | 1251 | ||||
-rw-r--r-- | pkg/lisafs/sample_message.go | 110 | ||||
-rw-r--r-- | pkg/lisafs/server.go | 113 | ||||
-rw-r--r-- | pkg/lisafs/sock.go | 208 | ||||
-rw-r--r-- | pkg/lisafs/sock_test.go | 217 | ||||
-rw-r--r-- | pkg/lisafs/testsuite/BUILD | 20 | ||||
-rw-r--r-- | pkg/lisafs/testsuite/testsuite.go | 637 |
18 files changed, 5527 insertions, 0 deletions
diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD new file mode 100644 index 000000000..313c1756d --- /dev/null +++ b/pkg/lisafs/BUILD @@ -0,0 +1,117 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_template_instance( + name = "control_fd_refs", + out = "control_fd_refs.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "ControlFD", + }, +) + +go_template_instance( + name = "open_fd_refs", + out = "open_fd_refs.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "OpenFD", + }, +) + +go_template_instance( + name = "control_fd_list", + out = "control_fd_list.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*ControlFD", + "Linker": "*ControlFD", + }, +) + +go_template_instance( + name = "open_fd_list", + out = "open_fd_list.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*OpenFD", + "Linker": "*OpenFD", + }, +) + +go_library( + name = "lisafs", + srcs = [ + "channel.go", + "client.go", + "client_file.go", + "communicator.go", + "connection.go", + "control_fd_list.go", + "control_fd_refs.go", + "fd.go", + "handlers.go", + "lisafs.go", + "message.go", + "open_fd_list.go", + "open_fd_refs.go", + "sample_message.go", + "server.go", + "sock.go", + ], + marshal = True, + deps = [ + "//pkg/abi/linux", + "//pkg/cleanup", + "//pkg/context", + "//pkg/fdchannel", + "//pkg/flipcall", + "//pkg/fspath", + "//pkg/hostarch", + "//pkg/log", + "//pkg/marshal/primitive", + "//pkg/p9", + "//pkg/refsvfs2", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "sock_test", + size = "small", + srcs = ["sock_test.go"], + library = ":lisafs", + deps = [ + "//pkg/marshal", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "connection_test", + size = "small", + srcs = ["connection_test.go"], + deps = [ + ":lisafs", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md index 51d0d40e5..6b857321a 100644 --- a/pkg/lisafs/README.md +++ b/pkg/lisafs/README.md @@ -1,5 +1,8 @@ # Replacing 9P +NOTE: LISAFS is **NOT** production ready. There are still some security concerns +that must be resolved first. + ## Background The Linux filesystem model consists of the following key aspects (modulo mounts, diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go new file mode 100644 index 000000000..301212e51 --- /dev/null +++ b/pkg/lisafs/channel.go @@ -0,0 +1,190 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "runtime" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +var ( + chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes()) +) + +// maxChannels returns the number of channels a client can create. +// +// The server will reject channel creation requests beyond this (per client). +// Note that we don't want the number of channels to be too large, because each +// accounts for a large region of shared memory. +// TODO(gvisor.dev/issue/6313): Tune the number of channels. +func maxChannels() int { + maxChans := runtime.GOMAXPROCS(0) + if maxChans < 2 { + maxChans = 2 + } + if maxChans > 4 { + maxChans = 4 + } + return maxChans +} + +// channel implements Communicator and represents the communication endpoint +// for the client and server and is used to perform fast IPC. Apart from +// communicating data, a channel is also capable of donating file descriptors. +type channel struct { + fdTracker + dead bool + data flipcall.Endpoint + fdChan fdchannel.Endpoint +} + +var _ Communicator = (*channel)(nil) + +// PayloadBuf implements Communicator.PayloadBuf. +func (ch *channel) PayloadBuf(size uint32) []byte { + return ch.data.Data()[chanHeaderLen : chanHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + // Write header. Requests can not donate FDs. + ch.marshalHdr(m, 0 /* numFDs */) + + // One-shot communication. RPCs are expected to be quick rather than block. + rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen) + if err != nil { + // This channel is now unusable. + ch.dead = true + // Map the transport errors to EIO, but also log the real error. + log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err) + return 0, 0, unix.EIO + } + + return ch.rcvMsg(rcvDataLen) +} + +func (ch *channel) shutdown() { + ch.data.Shutdown() +} + +func (ch *channel) destroy() { + ch.dead = true + ch.fdChan.Destroy() + ch.data.Destroy() +} + +// createChannel creates a server side channel. It returns a packet window +// descriptor (for the data channel) and an open socket for the FD channel. +func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + // If c.channels is nil, the connection has closed. + if c.channels == nil || len(c.channels) >= maxChannels() { + return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS + } + ch := &channel{} + + // Set up data channel. + desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize)) + if err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + if err := ch.data.Init(flipcall.ServerSide, desc); err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + + // Set up FD channel. + fdSocks, err := fdchannel.NewConnectedSockets() + if err != nil { + ch.data.Destroy() + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + ch.fdChan.Init(fdSocks[0]) + clientFDSock := fdSocks[1] + + c.channels = append(c.channels, ch) + return ch, desc, clientFDSock, nil +} + +// sendFDs sends as many FDs as it can. The failure to send an FD does not +// cause an error and fail the entire RPC. FDs are considered supplementary +// responses that are not critical to the RPC response itself. The failure to +// send the (i)th FD will cause all the following FDs to not be sent as well +// because the order in which FDs are donated is important. +func (ch *channel) sendFDs(fds []int) uint8 { + numFDs := len(fds) + if numFDs == 0 { + return 0 + } + + if numFDs > math.MaxUint8 { + log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs) + return 0 + } + + for i, fd := range fds { + if err := ch.fdChan.SendFD(fd); err != nil { + log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err) + return uint8(i) + } + } + return uint8(numFDs) +} + +// channelHeader is the header present in front of each message received on +// flipcall endpoint when the protocol version being used is 1. +// +// +marshal +type channelHeader struct { + message MID + numFDs uint8 + _ uint8 // Need to make struct packed. +} + +func (ch *channel) marshalHdr(m MID, numFDs uint8) { + header := &channelHeader{ + message: m, + numFDs: numFDs, + } + header.MarshalUnsafe(ch.data.Data()) +} + +func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) { + if dataLen < chanHeaderLen { + log.Warningf("received data has size smaller than header length: %d", dataLen) + return 0, 0, unix.EIO + } + + // Read header first. + var header channelHeader + header.UnmarshalUnsafe(ch.data.Data()) + + // Read any FDs. + for i := 0; i < int(header.numFDs); i++ { + fd, err := ch.fdChan.RecvFDNonblock() + if err != nil { + log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err) + break + } + ch.TrackFD(fd) + } + + return header.message, dataLen - chanHeaderLen, nil +} diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go new file mode 100644 index 000000000..ccf1b9f72 --- /dev/null +++ b/pkg/lisafs/client.go @@ -0,0 +1,432 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "math" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + // fdsToCloseBatchSize is the number of closed FDs batched before an Close + // RPC is made to close them all. fdsToCloseBatchSize is immutable. + fdsToCloseBatchSize = 100 +) + +// Client helps manage a connection to the lisafs server and pass messages +// efficiently. There is a 1:1 mapping between a Connection and a Client. +type Client struct { + // sockComm is the main socket by which this connections is established. + // Communication over the socket is synchronized by sockMu. + sockMu sync.Mutex + sockComm *sockCommunicator + + // channelsMu protects channels and availableChannels. + channelsMu sync.Mutex + // channels tracks all the channels. + channels []*channel + // availableChannels is a LIFO (stack) of channels available to be used. + availableChannels []*channel + // activeWg represents active channels. + activeWg sync.WaitGroup + + // watchdogWg only holds the watchdog goroutine. + watchdogWg sync.WaitGroup + + // supported caches information about which messages are supported. It is + // indexed by MID. An MID is supported if supported[MID] is true. + supported []bool + + // maxMessageSize is the maximum payload length (in bytes) that can be sent. + // It is initialized on Mount and is immutable. + maxMessageSize uint32 + + // fdsToClose tracks the FDs to close. It caches the FDs no longer being used + // by the client and closes them in one shot. It is not preserved across + // checkpoint/restore as FDIDs are not preserved. + fdsMu sync.Mutex + fdsToClose []FDID +} + +// NewClient creates a new client for communication with the server. It mounts +// the server and creates channels for fast IPC. NewClient takes ownership over +// the passed socket. On success, it returns the initialized client along with +// the root Inode. +func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) { + maxChans := maxChannels() + c := &Client{ + sockComm: newSockComm(sock), + channels: make([]*channel, 0, maxChans), + availableChannels: make([]*channel, 0, maxChans), + maxMessageSize: 1 << 20, // 1 MB for now. + fdsToClose: make([]FDID, 0, fdsToCloseBatchSize), + } + + // Start a goroutine to check socket health. This goroutine is also + // responsible for client cleanup. + c.watchdogWg.Add(1) + go c.watchdog() + + // Clean everything up if anything fails. + cu := cleanup.Make(func() { + c.Close() + }) + defer cu.Clean() + + // Mount the server first. Assume Mount is supported so that we can make the + // Mount RPC below. + c.supported = make([]bool, Mount+1) + c.supported[Mount] = true + mountMsg := MountReq{ + MountPath: SizedString(mountPath), + } + var mountResp MountResp + if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil { + return nil, nil, err + } + + // Initialize client. + c.maxMessageSize = uint32(mountResp.MaxMessageSize) + var maxSuppMID MID + for _, suppMID := range mountResp.SupportedMs { + if suppMID > maxSuppMID { + maxSuppMID = suppMID + } + } + c.supported = make([]bool, maxSuppMID+1) + for _, suppMID := range mountResp.SupportedMs { + c.supported[suppMID] = true + } + + // Create channels parallely so that channels can be used to create more + // channels and costly initialization like flipcall.Endpoint.Connect can + // proceed parallely. + var channelsWg sync.WaitGroup + channelErrs := make([]error, maxChans) + for i := 0; i < maxChans; i++ { + channelsWg.Add(1) + curChanID := i + go func() { + defer channelsWg.Done() + ch, err := c.createChannel() + if err != nil { + log.Warningf("channel creation failed: %v", err) + channelErrs[curChanID] = err + return + } + c.channelsMu.Lock() + c.channels = append(c.channels, ch) + c.availableChannels = append(c.availableChannels, ch) + c.channelsMu.Unlock() + }() + } + channelsWg.Wait() + + for _, channelErr := range channelErrs { + // Return the first non-nil channel creation error. + if channelErr != nil { + return nil, nil, channelErr + } + } + cu.Release() + + return c, &mountResp.Root, nil +} + +func (c *Client) watchdog() { + defer c.watchdogWg.Done() + + events := []unix.PollFd{ + { + Fd: int32(c.sockComm.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + if err != nil { + log.Warningf("lisafs.Client.watch(): %v", err) + } else if n != 1 { + log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Shutdown all active channels and wait for them to complete. + c.shutdownActiveChans() + c.activeWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.destroy() + } + c.channelsMu.Unlock() + + // Close main socket. + c.sockComm.destroy() +} + +func (c *Client) shutdownActiveChans() { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + availableChans := make(map[*channel]bool) + for _, ch := range c.availableChannels { + availableChans[ch] = true + } + for _, ch := range c.channels { + // A channel that is not available is active. + if _, ok := availableChans[ch]; !ok { + log.Debugf("shutting down active channel@%p...", ch) + ch.shutdown() + } + } + + // Prevent channels from becoming available and serving new requests. + c.availableChannels = nil +} + +// Close shuts down the main socket and waits for the watchdog to clean up. +func (c *Client) Close() { + // This shutdown has no effect if the watchdog has already fired and closed + // the main socket. + c.sockComm.shutdown() + c.watchdogWg.Wait() +} + +func (c *Client) createChannel() (*channel, error) { + var chanResp ChannelResp + var fds [2]int + if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil { + return nil, err + } + if fds[0] < 0 || fds[1] < 0 { + closeFDs(fds[:]) + return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds) + } + + // Lets create the channel. + defer closeFDs(fds[:1]) // The data FD is not needed after this. + desc := flipcall.PacketWindowDescriptor{ + FD: fds[0], + Offset: chanResp.dataOffset, + Length: int(chanResp.dataLength), + } + + ch := &channel{} + if err := ch.data.Init(flipcall.ClientSide, desc); err != nil { + closeFDs(fds[1:]) + return nil, err + } + ch.fdChan.Init(fds[1]) // fdChan now owns this FD. + + // Only a connected channel is usable. + if err := ch.data.Connect(); err != nil { + ch.destroy() + return nil, err + } + return ch, nil +} + +// IsSupported returns true if this connection supports the passed message. +func (c *Client) IsSupported(m MID) bool { + return int(m) < len(c.supported) && c.supported[m] +} + +// CloseFDBatched either queues the passed FD to be closed or makes a batch +// RPC to close all the accumulated FDs-to-close. +func (c *Client) CloseFDBatched(ctx context.Context, fd FDID) { + c.fdsMu.Lock() + c.fdsToClose = append(c.fdsToClose, fd) + if len(c.fdsToClose) < fdsToCloseBatchSize { + c.fdsMu.Unlock() + return + } + + // Flush the cache. We should not hold fdsMu while making an RPC, so be sure + // to copy the fdsToClose to another buffer before unlocking fdsMu. + var toCloseArr [fdsToCloseBatchSize]FDID + toClose := toCloseArr[:len(c.fdsToClose)] + copy(toClose, c.fdsToClose) + + // Clear fdsToClose so other FDIDs can be appended. + c.fdsToClose = c.fdsToClose[:0] + c.fdsMu.Unlock() + + req := CloseReq{FDs: toClose} + ctx.UninterruptibleSleepStart(false) + err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + log.Warningf("lisafs: batch closing FDs returned error: %v", err) + } +} + +// SyncFDs makes a Fsync RPC to sync multiple FDs. +func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error { + if len(fds) == 0 { + return nil + } + req := FsyncReq{FDs: fds} + ctx.UninterruptibleSleepStart(false) + err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// SndRcvMessage invokes reqMarshal to marshal the request onto the payload +// buffer, wakes up the server to process the request, waits for the response +// and invokes respUnmarshal with the response payload. respFDs is populated +// with the received FDs, extra fields are set to -1. +// +// Note that the function arguments intentionally accept marshal.Marshallable +// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of +// directly accepting the marshal.Marshallable interface. Even though just +// accepting marshal.Marshallable is cleaner, it leads to a heap allocation +// (even if that interface variable itself does not escape). In other words, +// implicit conversion to an interface leads to an allocation. +// +// Precondition: reqMarshal and respUnmarshal must be non-nil. +func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error { + if !c.IsSupported(m) { + return unix.EOPNOTSUPP + } + if payloadLen > c.maxMessageSize { + log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize) + return unix.EIO + } + wantFDs := len(respFDs) + if wantFDs > math.MaxUint8 { + log.Warningf("want too many FDs: %d", wantFDs) + return unix.EINVAL + } + + // Acquire a communicator. + comm := c.acquireCommunicator() + defer c.releaseCommunicator(comm) + + // Marshal the request into comm's payload buffer and make the RPC. + reqMarshal(comm.PayloadBuf(payloadLen)) + respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs)) + + // Handle FD donation. + rcvFDs := comm.ReleaseFDs() + if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 { + // releasedFDs is memory owned by comm which can not be returned to caller. + // Copy it into the caller's buffer. + numFDCopied := copy(respFDs, rcvFDs) + if numFDCopied < numRcvFDs { + log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs) + closeFDs(rcvFDs[numFDCopied:]) + } + if numFDCopied < wantFDs { + for i := numFDCopied; i < wantFDs; i++ { + respFDs[i] = -1 + } + } + } + + // Error cases. + if err != nil { + closeFDs(respFDs) + return err + } + if respM == Error { + closeFDs(respFDs) + var resp ErrorResp + resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen)) + return unix.Errno(resp.errno) + } + if respM != m { + closeFDs(respFDs) + log.Warningf("sent %d message but got %d in response", m, respM) + return unix.EINVAL + } + + // Success. The payload must be unmarshalled *before* comm is released. + respUnmarshal(comm.PayloadBuf(respPayloadLen)) + return nil +} + +// Postcondition: releaseCommunicator() must be called on the returned value. +func (c *Client) acquireCommunicator() Communicator { + // Prefer using channel over socket because: + // - Channel uses a shared memory region for passing messages. IO from shared + // memory is faster and does not involve making a syscall. + // - No intermediate buffer allocation needed. With a channel, the message + // can be directly pasted into the shared memory region. + if ch := c.getChannel(); ch != nil { + return ch + } + + c.sockMu.Lock() + return c.sockComm +} + +// Precondition: comm must have been acquired via acquireCommunicator(). +func (c *Client) releaseCommunicator(comm Communicator) { + switch t := comm.(type) { + case *sockCommunicator: + c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator(). + case *channel: + c.releaseChannel(t) + default: + panic(fmt.Sprintf("unknown communicator type %T", t)) + } +} + +// getChannel pops a channel from the available channels stack. The caller must +// release the channel after use. +func (c *Client) getChannel() *channel { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + if len(c.availableChannels) == 0 { + return nil + } + + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + c.activeWg.Add(1) + return ch +} + +// releaseChannel pushes the passed channel onto the available channel stack if +// reinsert is true. +func (c *Client) releaseChannel(ch *channel) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + // If availableChannels is nil, then watchdog has fired and the client is + // shutting down. So don't make this channel available again. + if !ch.dead && c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.activeWg.Done() +} diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go new file mode 100644 index 000000000..0f8788f3b --- /dev/null +++ b/pkg/lisafs/client_file.go @@ -0,0 +1,475 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// ClientFD is a wrapper around FDID that provides client-side utilities +// so that RPC making is easier. +type ClientFD struct { + fd FDID + client *Client +} + +// ID returns the underlying FDID. +func (f *ClientFD) ID() FDID { + return f.fd +} + +// Client returns the backing Client. +func (f *ClientFD) Client() *Client { + return f.client +} + +// NewFD initializes a new ClientFD. +func (c *Client) NewFD(fd FDID) ClientFD { + return ClientFD{ + client: c, + fd: fd, + } +} + +// Ok returns true if the underlying FD is ok. +func (f *ClientFD) Ok() bool { + return f.fd.Ok() +} + +// CloseBatched queues this FD to be closed on the server and resets f.fd. +// This maybe invoke the Close RPC if the queue is full. +func (f *ClientFD) CloseBatched(ctx context.Context) { + f.client.CloseFDBatched(ctx, f.fd) + f.fd = InvalidFDID +} + +// Close closes this FD immediately (invoking a Close RPC). Consider using +// CloseBatched if closing this FD on remote right away is not critical. +func (f *ClientFD) Close(ctx context.Context) error { + fdArr := [1]FDID{f.fd} + req := CloseReq{FDs: fdArr[:]} + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// OpenAt makes the OpenAt RPC. +func (f *ClientFD) OpenAt(ctx context.Context, flags uint32) (FDID, int, error) { + req := OpenAtReq{ + FD: f.fd, + Flags: flags, + } + var respFD [1]int + var resp OpenAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(OpenAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, respFD[:]) + ctx.UninterruptibleSleepFinish(false) + return resp.NewFD, respFD[0], err +} + +// OpenCreateAt makes the OpenCreateAt RPC. +func (f *ClientFD) OpenCreateAt(ctx context.Context, name string, flags uint32, mode linux.FileMode, uid UID, gid GID) (Inode, FDID, int, error) { + var req OpenCreateAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Flags = primitive.Uint32(flags) + req.Mode = mode + req.UID = uid + req.GID = gid + + var respFD [1]int + var resp OpenCreateAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(OpenCreateAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, respFD[:]) + ctx.UninterruptibleSleepFinish(false) + return resp.Child, resp.NewFD, respFD[0], err +} + +// StatTo makes the Fstat RPC and populates stat with the result. +func (f *ClientFD) StatTo(ctx context.Context, stat *linux.Statx) error { + req := StatReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Sync makes the Fsync RPC. +func (f *ClientFD) Sync(ctx context.Context) error { + req := FsyncReq{FDs: []FDID{f.fd}} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Read makes the PRead RPC. +func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) { + req := PReadReq{ + Offset: offset, + FD: f.fd, + Count: uint32(len(dst)), + } + + resp := PReadResp{ + // This will be unmarshalled into. Already set Buf so that we don't need to + // allocate a temporary buffer during unmarshalling. + // PReadResp.UnmarshalBytes expects this to be set. + Buf: dst, + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return uint64(resp.NumBytes), err +} + +// Write makes the PWrite RPC. +func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) { + req := PWriteReq{ + Offset: primitive.Uint64(offset), + FD: f.fd, + NumBytes: primitive.Uint32(len(src)), + Buf: src, + } + + var resp PWriteResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Count, err +} + +// MkdirAt makes the MkdirAt RPC. +func (f *ClientFD) MkdirAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID) (*Inode, error) { + var req MkdirAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Mode = mode + req.UID = uid + req.GID = gid + + var resp MkdirAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(MkdirAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.ChildDir, err +} + +// SymlinkAt makes the SymlinkAt RPC. +func (f *ClientFD) SymlinkAt(ctx context.Context, name, target string, uid UID, gid GID) (*Inode, error) { + req := SymlinkAtReq{ + DirFD: f.fd, + Name: SizedString(name), + Target: SizedString(target), + UID: uid, + GID: gid, + } + + var resp SymlinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(SymlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Symlink, err +} + +// LinkAt makes the LinkAt RPC. +func (f *ClientFD) LinkAt(ctx context.Context, targetFD FDID, name string) (*Inode, error) { + req := LinkAtReq{ + DirFD: f.fd, + Target: targetFD, + Name: SizedString(name), + } + + var resp LinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(LinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Link, err +} + +// MknodAt makes the MknodAt RPC. +func (f *ClientFD) MknodAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID, minor, major uint32) (*Inode, error) { + var req MknodAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Mode = mode + req.UID = uid + req.GID = gid + req.Minor = primitive.Uint32(minor) + req.Major = primitive.Uint32(major) + + var resp MknodAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(MknodAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Child, err +} + +// SetStat makes the SetStat RPC. +func (f *ClientFD) SetStat(ctx context.Context, stat *linux.Statx) (uint32, error, error) { + req := SetStatReq{ + FD: f.fd, + Mask: stat.Mask, + Mode: uint32(stat.Mode), + UID: UID(stat.UID), + GID: GID(stat.GID), + Size: stat.Size, + Atime: linux.Timespec{ + Sec: stat.Atime.Sec, + Nsec: int64(stat.Atime.Nsec), + }, + Mtime: linux.Timespec{ + Sec: stat.Mtime.Sec, + Nsec: int64(stat.Mtime.Nsec), + }, + } + + var resp SetStatResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(SetStat, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.FailureMask, unix.Errno(resp.FailureErrNo), err +} + +// WalkMultiple makes the Walk RPC with multiple path components. +func (f *ClientFD) WalkMultiple(ctx context.Context, names []string) (WalkStatus, []Inode, error) { + req := WalkReq{ + DirFD: f.fd, + Path: StringArray(names), + } + + var resp WalkResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Status, resp.Inodes, err +} + +// Walk makes the Walk RPC with just one path component to walk. +func (f *ClientFD) Walk(ctx context.Context, name string) (*Inode, error) { + req := WalkReq{ + DirFD: f.fd, + Path: []string{name}, + } + + var inode [1]Inode + resp := WalkResp{Inodes: inode[:]} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return nil, err + } + + switch resp.Status { + case WalkComponentDoesNotExist: + return nil, unix.ENOENT + case WalkComponentSymlink: + // f is not a directory which can be walked on. + return nil, unix.ENOTDIR + } + + if n := len(resp.Inodes); n > 1 { + for i := range resp.Inodes { + f.client.CloseFDBatched(ctx, resp.Inodes[i].ControlFD) + } + log.Warningf("requested to walk one component, but got %d results", n) + return nil, unix.EIO + } else if n == 0 { + log.Warningf("walk has success status but no results returned") + return nil, unix.ENOENT + } + return &inode[0], err +} + +// WalkStat makes the WalkStat RPC with multiple path components to walk. +func (f *ClientFD) WalkStat(ctx context.Context, names []string) ([]linux.Statx, error) { + req := WalkReq{ + DirFD: f.fd, + Path: StringArray(names), + } + + var resp WalkStatResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(WalkStat, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Stats, err +} + +// StatFSTo makes the FStatFS RPC and populates statFS with the result. +func (f *ClientFD) StatFSTo(ctx context.Context, statFS *StatFS) error { + req := FStatFSReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FStatFS, uint32(req.SizeBytes()), req.MarshalUnsafe, statFS.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Allocate makes the FAllocate RPC. +func (f *ClientFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + req := FAllocateReq{ + FD: f.fd, + Mode: mode, + Offset: offset, + Length: length, + } + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FAllocate, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// ReadLinkAt makes the ReadLinkAt RPC. +func (f *ClientFD) ReadLinkAt(ctx context.Context) (string, error) { + req := ReadLinkAtReq{FD: f.fd} + var resp ReadLinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(ReadLinkAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return string(resp.Target), err +} + +// Flush makes the Flush RPC. +func (f *ClientFD) Flush(ctx context.Context) error { + if !f.client.IsSupported(Flush) { + // If Flush is not supported, it probably means that it would be a noop. + return nil + } + req := FlushReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Flush, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Connect makes the Connect RPC. +func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) { + req := ConnectReq{FD: f.fd, SockType: uint32(sockType)} + var sockFD [1]int + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, sockFD[:]) + ctx.UninterruptibleSleepFinish(false) + if err == nil && sockFD[0] < 0 { + err = unix.EBADF + } + return sockFD[0], err +} + +// UnlinkAt makes the UnlinkAt RPC. +func (f *ClientFD) UnlinkAt(ctx context.Context, name string, flags uint32) error { + req := UnlinkAtReq{ + DirFD: f.fd, + Name: SizedString(name), + Flags: primitive.Uint32(flags), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(UnlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// RenameTo makes the RenameAt RPC which renames f to newDirFD directory with +// name newName. +func (f *ClientFD) RenameTo(ctx context.Context, newDirFD FDID, newName string) error { + req := RenameAtReq{ + Renamed: f.fd, + NewDir: newDirFD, + NewName: SizedString(newName), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(RenameAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Getdents64 makes the Getdents64 RPC. +func (f *ClientFD) Getdents64(ctx context.Context, count int32) ([]Dirent64, error) { + req := Getdents64Req{ + DirFD: f.fd, + Count: count, + } + + var resp Getdents64Resp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Getdents64, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Dirents, err +} + +// ListXattr makes the FListXattr RPC. +func (f *ClientFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { + req := FListXattrReq{ + FD: f.fd, + Size: size, + } + + var resp FListXattrResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FListXattr, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Xattrs, err +} + +// GetXattr makes the FGetXattr RPC. +func (f *ClientFD) GetXattr(ctx context.Context, name string, size uint64) (string, error) { + req := FGetXattrReq{ + FD: f.fd, + Name: SizedString(name), + BufSize: primitive.Uint32(size), + } + + var resp FGetXattrResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FGetXattr, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return string(resp.Value), err +} + +// SetXattr makes the FSetXattr RPC. +func (f *ClientFD) SetXattr(ctx context.Context, name string, value string, flags uint32) error { + req := FSetXattrReq{ + FD: f.fd, + Name: SizedString(name), + Value: SizedString(value), + Flags: primitive.Uint32(flags), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FSetXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// RemoveXattr makes the FRemoveXattr RPC. +func (f *ClientFD) RemoveXattr(ctx context.Context, name string) error { + req := FRemoveXattrReq{ + FD: f.fd, + Name: SizedString(name), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FRemoveXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go new file mode 100644 index 000000000..ec2035158 --- /dev/null +++ b/pkg/lisafs/communicator.go @@ -0,0 +1,80 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import "golang.org/x/sys/unix" + +// Communicator is a server side utility which represents exactly how the +// server is communicating with the client. +type Communicator interface { + // PayloadBuf returns a slice to the payload section of its internal buffer + // where the message can be marshalled. The handlers should use this to + // populate the payload buffer with the message. + // + // The payload buffer contents *should* be preserved across calls with + // different sizes. Note that this is not a guarantee, because a compromised + // owner of a "shared" payload buffer can tamper with its contents anytime, + // even when it's not its turn to do so. + PayloadBuf(size uint32) []byte + + // SndRcvMessage sends message m. The caller must have populated PayloadBuf() + // with payloadLen bytes. The caller expects to receive wantFDs FDs. + // Any received FDs must be accessible via ReleaseFDs(). It returns the + // response message along with the response payload length. + SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) + + // DonateFD makes fd non-blocking and starts tracking it. The next call to + // ReleaseFDs will include fd in the order it was added. Communicator takes + // ownership of fd. Server side should call this. + DonateFD(fd int) error + + // Track starts tracking fd. The next call to ReleaseFDs will include fd in + // the order it was added. Communicator takes ownership of fd. Client side + // should use this for accumulating received FDs. + TrackFD(fd int) + + // ReleaseFDs returns the accumulated FDs and stops tracking them. The + // ownership of the FDs is transferred to the caller. + ReleaseFDs() []int +} + +// fdTracker is a partial implementation of Communicator. It can be embedded in +// Communicator implementations to keep track of FD donations. +type fdTracker struct { + fds []int +} + +// DonateFD implements Communicator.DonateFD. +func (d *fdTracker) DonateFD(fd int) error { + // Make sure the FD is non-blocking. + if err := unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return err + } + d.TrackFD(fd) + return nil +} + +// TrackFD implements Communicator.TrackFD. +func (d *fdTracker) TrackFD(fd int) { + d.fds = append(d.fds, fd) +} + +// ReleaseFDs implements Communicator.ReleaseFDs. +func (d *fdTracker) ReleaseFDs() []int { + ret := d.fds + d.fds = d.fds[:0] + return ret +} diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go new file mode 100644 index 000000000..f6e5ecb4f --- /dev/null +++ b/pkg/lisafs/connection.go @@ -0,0 +1,320 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Connection represents a connection between a mount point in the client and a +// mount point in the server. It is owned by the server on which it was started +// and facilitates communication with the client mount. +// +// Each connection is set up using a unix domain socket. One end is owned by +// the server and the other end is owned by the client. The connection may +// spawn additional comunicational channels for the same mount for increased +// RPC concurrency. +type Connection struct { + // server is the server on which this connection was created. It is immutably + // associated with it for its entire lifetime. + server *Server + + // mounted is a one way flag indicating whether this connection has been + // mounted correctly and the server is initialized properly. + mounted bool + + // readonly indicates if this connection is readonly. All write operations + // will fail with EROFS. + readonly bool + + // sockComm is the main socket by which this connections is established. + sockComm *sockCommunicator + + // channelsMu protects channels. + channelsMu sync.Mutex + // channels keeps track of all open channels. + channels []*channel + + // activeWg represents active channels. + activeWg sync.WaitGroup + + // reqGate counts requests that are still being handled. + reqGate sync.Gate + + // channelAlloc is used to allocate memory for channels. + channelAlloc *flipcall.PacketWindowAllocator + + fdsMu sync.RWMutex + // fds keeps tracks of open FDs on this server. It is protected by fdsMu. + fds map[FDID]genericFD + // nextFDID is the next available FDID. It is protected by fdsMu. + nextFDID FDID +} + +// CreateConnection initializes a new connection - creating a server if +// required. The connection must be started separately. +func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) { + c := &Connection{ + sockComm: newSockComm(sock), + server: s, + readonly: readonly, + channels: make([]*channel, 0, maxChannels()), + fds: make(map[FDID]genericFD), + nextFDID: InvalidFDID + 1, + } + + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return nil, err + } + c.channelAlloc = alloc + return c, nil +} + +// Server returns the associated server. +func (c *Connection) Server() *Server { + return c.server +} + +// ServerImpl returns the associated server implementation. +func (c *Connection) ServerImpl() ServerImpl { + return c.server.impl +} + +// Run defines the lifecycle of a connection. +func (c *Connection) Run() { + defer c.close() + + // Start handling requests on this connection. + for { + m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */) + if err != nil { + log.Debugf("sock read failed, closing connection: %v", err) + return + } + + respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen) + err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs) + closeFDs(respFDs) + if err != nil { + log.Debugf("sock write failed, closing connection: %v", err) + return + } + } +} + +// service starts servicing the passed channel until the channel is shutdown. +// This is a blocking method and hence must be called in a separate goroutine. +func (c *Connection) service(ch *channel) error { + rcvDataLen, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rcvDataLen > 0 { + m, payloadLen, err := ch.rcvMsg(rcvDataLen) + if err != nil { + return err + } + respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen) + numFDs := ch.sendFDs(respFDs) + closeFDs(respFDs) + + ch.marshalHdr(respM, numFDs) + rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen) + if err != nil { + return err + } + } + return nil +} + +func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) { + resp := &ErrorResp{errno: uint32(err)} + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return Error, respLen, nil +} + +func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) { + if !c.reqGate.Enter() { + // c.close() has been called; the connection is shutting down. + return c.respondError(comm, unix.ECONNRESET) + } + defer c.reqGate.Leave() + + if !c.mounted && m != Mount { + log.Warningf("connection must first be mounted") + return c.respondError(comm, unix.EINVAL) + } + + // Check if the message is supported for forward compatibility. + if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil { + log.Warningf("received request which is not supported by the server, MID = %d", m) + return c.respondError(comm, unix.EOPNOTSUPP) + } + + // Try handling the request. + respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen) + fds := comm.ReleaseFDs() + if err != nil { + closeFDs(fds) + return c.respondError(comm, p9.ExtractErrno(err)) + } + + return m, respPayloadLen, fds +} + +func (c *Connection) close() { + // Wait for completion of all inflight requests. This is mostly so that if + // a request is stuck, the sandbox supervisor has the opportunity to kill + // us with SIGABRT to get a stack dump of the offending handler. + c.reqGate.Close() + + // Shutdown and clean up channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.shutdown() + } + c.activeWg.Wait() + for _, ch := range c.channels { + ch.destroy() + } + // This is to prevent additional channels from being created. + c.channels = nil + c.channelsMu.Unlock() + + // Free the channel memory. + if c.channelAlloc != nil { + c.channelAlloc.Destroy() + } + + // Ensure the connection is closed. + c.sockComm.destroy() + + // Cleanup all FDs. + c.fdsMu.Lock() + for fdid := range c.fds { + fd := c.removeFDLocked(fdid) + fd.DecRef(nil) // Drop the ref held by c. + } + c.fdsMu.Unlock() +} + +// The caller gains a ref on the FD on success. +func (c *Connection) lookupFD(id FDID) (genericFD, error) { + c.fdsMu.RLock() + defer c.fdsMu.RUnlock() + + fd, ok := c.fds[id] + if !ok { + return nil, unix.EBADF + } + fd.IncRef() + return fd, nil +} + +// LookupControlFD retrieves the control FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + cfd, ok := fd.(*ControlFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return cfd, nil +} + +// LookupOpenFD retrieves the open FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + ofd, ok := fd.(*OpenFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return ofd, nil +} + +// insertFD inserts the passed fd into the internal datastructure to track FDs. +// The caller must hold a ref on fd which is transferred to the connection. +func (c *Connection) insertFD(fd genericFD) FDID { + c.fdsMu.Lock() + defer c.fdsMu.Unlock() + + res := c.nextFDID + c.nextFDID++ + if c.nextFDID < res { + panic("ran out of FDIDs") + } + c.fds[res] = fd + return res +} + +// RemoveFD makes c stop tracking the passed FDID and drops its ref on it. +func (c *Connection) RemoveFD(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.DecRef(nil) + } +} + +// RemoveControlFDLocked is the same as RemoveFD with added preconditions. +// +// Preconditions: +// * server's rename mutex must at least be read locked. +// * id must be pointing to a control FD. +func (c *Connection) RemoveControlFDLocked(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.(*ControlFD).DecRefLocked() + } +} + +// removeFDLocked makes c stop tracking the passed FDID. Note that the caller +// must drop ref on the returned fd (preferably without holding c.fdsMu). +// +// Precondition: c.fdsMu is locked. +func (c *Connection) removeFDLocked(id FDID) genericFD { + fd := c.fds[id] + if fd == nil { + log.Warningf("removeFDLocked called on non-existent FDID %d", id) + return nil + } + delete(c.fds, id) + return fd +} diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go new file mode 100644 index 000000000..28ba47112 --- /dev/null +++ b/pkg/lisafs/connection_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connection_test + +import ( + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + dynamicMsgID = lisafs.Channel + 1 + versionMsgID = dynamicMsgID + 1 +) + +var handlers = [...]lisafs.RPCHandler{ + lisafs.Error: lisafs.ErrorHandler, + lisafs.Mount: lisafs.MountHandler, + lisafs.Channel: lisafs.ChannelHandler, + dynamicMsgID: dynamicMsgHandler, + versionMsgID: versionHandler, +} + +// testServer implements lisafs.ServerImpl. +type testServer struct { + lisafs.Server +} + +var _ lisafs.ServerImpl = (*testServer)(nil) + +type testControlFD struct { + lisafs.ControlFD + lisafs.ControlFDImpl +} + +func (fd *testControlFD) FD() *lisafs.ControlFD { + return &fd.ControlFD +} + +// Mount implements lisafs.Mount. +func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) { + return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil +} + +// MaxMessageSize implements lisafs.MaxMessageSize. +func (s *testServer) MaxMessageSize() uint32 { + return lisafs.MaxMessageSize() +} + +// SupportedMessages implements lisafs.ServerImpl.SupportedMessages. +func (s *testServer) SupportedMessages() []lisafs.MID { + return []lisafs.MID{ + lisafs.Mount, + lisafs.Channel, + dynamicMsgID, + versionMsgID, + } +} + +func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) { + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + ts := &testServer{} + ts.Server.InitTestOnly(ts, handlers[:]) + conn, err := ts.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + ts.StartConnection(conn) + + c, _, err := lisafs.NewClient(clientSocket, "/") + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + clientFn(c) + + c.Close() // This should trigger client and server shutdown. + ts.Wait() +} + +// TestStartUp tests that the server and client can be started up correctly. +func TestStartUp(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + if c.IsSupported(lisafs.Error) { + t.Errorf("sending error messages should not be supported") + } + }) +} + +func TestUnsupportedMessage(t *testing.T) { + unsupportedM := lisafs.MID(len(handlers)) + runServerClient(t, func(c *lisafs.Client) { + if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP { + t.Errorf("expected EOPNOTSUPP but got err: %v", err) + } + }) +} + +func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + var req lisafs.MsgDynamic + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Just echo back the message. + respPayloadLen := uint32(req.SizeBytes()) + req.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// TestStress stress tests sending many messages from various goroutines. +func TestStress(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + concurrency := 8 + numMsgPerGoroutine := 5000 + var clientWg sync.WaitGroup + for i := 0; i < concurrency; i++ { + clientWg.Add(1) + go func() { + defer clientWg.Done() + + for j := 0; j < numMsgPerGoroutine; j++ { + // Create a massive random message. + var req lisafs.MsgDynamic + req.Randomize(100) + + var resp lisafs.MsgDynamic + if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil { + t.Errorf("SndRcvMessage: received unexpected error %v", err) + return + } + if !reflect.DeepEqual(&req, &resp) { + t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp) + } + } + }() + } + + clientWg.Wait() + }) +} + +func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + // To be fair, usually handlers will create their own objects and return a + // pointer to those. Might be tempting to reuse above variables, but don't. + var rv lisafs.P9Version + rv.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Create a new response. + sv := lisafs.P9Version{ + MSize: rv.MSize, + Version: "9P2000.L.Google.11", + } + respPayloadLen := uint32(sv.SizeBytes()) + sv.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel. +func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + sendV := lisafs.P9Version{ + MSize: 1 << 20, + Version: "9P2000.L.Google.12", + } + + var recvV lisafs.P9Version + runServerClient(b, func(c *lisafs.Client) { + for i := 0; i < b.N; i++ { + if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil { + b.Fatalf("unexpected error occurred: %v", err) + } + } + }) +} diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go new file mode 100644 index 000000000..cc6919a1b --- /dev/null +++ b/pkg/lisafs/fd.go @@ -0,0 +1,374 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/sync" +) + +// FDID (file descriptor identifier) is used to identify FDs on a connection. +// Each connection has its own FDID namespace. +// +// +marshal slice:FDIDSlice +type FDID uint32 + +// InvalidFDID represents an invalid FDID. +const InvalidFDID FDID = 0 + +// Ok returns true if f is a valid FDID. +func (f FDID) Ok() bool { + return f != InvalidFDID +} + +// genericFD can represent a ControlFD or OpenFD. +type genericFD interface { + refsvfs2.RefCounter +} + +// A ControlFD is the gateway to the backing filesystem tree node. It is an +// unusual concept. This exists to provide a safe way to do path-based +// operations on the file. It performs operations that can modify the +// filesystem tree and synchronizes these operations. See ControlFDImpl for +// supported operations. +// +// It is not an inode, because multiple control FDs are allowed to exist on the +// same file. It is not a file descriptor because it is not tied to any access +// mode, i.e. a control FD can change its access mode based on the operation +// being performed. +// +// Reference Model: +// * When a control FD is created, the connection takes a ref on it which +// represents the client's ref on the FD. +// * The client can drop its ref via the Close RPC which will in turn make the +// connection drop its ref. +// * Each control FD holds a ref on its parent for its entire life time. +type ControlFD struct { + controlFDRefs + controlFDEntry + + // parent is the parent directory FD containing the file this FD represents. + // A ControlFD holds a ref on parent for its entire lifetime. If this FD + // represents the root, then parent is nil. parent may be a control FD from + // another connection (another mount point). parent is protected by the + // backing server's rename mutex. + parent *ControlFD + + // name is the file path's last component name. If this FD represents the + // root directory, then name is the mount path. name is protected by the + // backing server's rename mutex. + name string + + // children is a linked list of all children control FDs. As per reference + // model, all children hold a ref on this FD. + // children is protected by childrenMu and server's rename mutex. To have + // mutual exclusion, it is sufficient to: + // * Hold rename mutex for reading and lock childrenMu. OR + // * Or hold rename mutex for writing. + childrenMu sync.Mutex + children controlFDList + + // openFDs is a linked list of all FDs opened on this FD. As per reference + // model, all open FDs hold a ref on this FD. + openFDsMu sync.RWMutex + openFDs openFDList + + // All the following fields are immutable. + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // conn is the backing connection owning this FD. + conn *Connection + + // ftype is the file type of the backing inode. ftype.FileType() == ftype. + ftype linux.FileMode + + // impl is the control FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl ControlFDImpl +} + +var _ genericFD = (*ControlFD)(nil) + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *ControlFD) DecRef(context.Context) { + fd.controlFDRefs.DecRef(func() { + if fd.parent != nil { + fd.conn.server.RenameMu.RLock() + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.conn.server.RenameMu.RUnlock() + fd.parent.DecRef(nil) // Drop the ref on the parent. + } + fd.impl.Close(fd.conn) + }) +} + +// DecRefLocked is the same as DecRef except the added precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) DecRefLocked() { + fd.controlFDRefs.DecRef(func() { + fd.clearParentLocked() + fd.impl.Close(fd.conn) + }) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) clearParentLocked() { + if fd.parent == nil { + return + } + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.parent.DecRefLocked() // Drop the ref on the parent. +} + +// Init must be called before first use of fd. It inserts fd into the +// filesystem tree. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.controlFDRefs.InitRefs() + fd.conn = c + fd.id = c.insertFD(fd) + fd.name = name + fd.ftype = mode.FileType() + fd.impl = impl + fd.setParentLocked(parent) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) setParentLocked(parent *ControlFD) { + fd.parent = parent + if parent != nil { + parent.IncRef() // Hold a ref on parent. + parent.childrenMu.Lock() + parent.children.PushBack(fd) + parent.childrenMu.Unlock() + } +} + +// FileType returns the file mode only containing the file type bits. +func (fd *ControlFD) FileType() linux.FileMode { + return fd.ftype +} + +// IsDir indicates whether fd represents a directory. +func (fd *ControlFD) IsDir() bool { + return fd.ftype == unix.S_IFDIR +} + +// IsRegular indicates whether fd represents a regular file. +func (fd *ControlFD) IsRegular() bool { + return fd.ftype == unix.S_IFREG +} + +// IsSymlink indicates whether fd represents a symbolic link. +func (fd *ControlFD) IsSymlink() bool { + return fd.ftype == unix.S_IFLNK +} + +// IsSocket indicates whether fd represents a socket. +func (fd *ControlFD) IsSocket() bool { + return fd.ftype == unix.S_IFSOCK +} + +// NameLocked returns the backing file's last component name. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) NameLocked() string { + return fd.name +} + +// ParentLocked returns the parent control FD. Note that parent might be a +// control FD from another connection on this server. So its ID must not +// returned on this connection because FDIDs are local to their connection. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) ParentLocked() ControlFDImpl { + if fd.parent == nil { + return nil + } + return fd.parent.impl +} + +// ID returns fd's ID. +func (fd *ControlFD) ID() FDID { + return fd.id +} + +// FilePath returns the absolute path of the file fd was opened on. This is +// expensive and must not be called on hot paths. FilePath acquires the rename +// mutex for reading so callers should not be holding it. +func (fd *ControlFD) FilePath() string { + // Lock the rename mutex for reading to ensure that the filesystem tree is not + // changed while we traverse it upwards. + fd.conn.server.RenameMu.RLock() + defer fd.conn.server.RenameMu.RUnlock() + return fd.FilePathLocked() +} + +// FilePathLocked is the same as FilePath with the additional precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) FilePathLocked() string { + // Walk upwards and prepend name to res. + var res fspath.Builder + for fd != nil { + res.PrependComponent(fd.name) + fd = fd.parent + } + return res.String() +} + +// ForEachOpenFD executes fn on each FD opened on fd. +func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) { + fd.openFDsMu.RLock() + defer fd.openFDsMu.RUnlock() + for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() { + fn(ofd.impl) + } +} + +// OpenFD represents an open file descriptor on the protocol. It resonates +// closely with a Linux file descriptor. Its operations are limited to the +// file. Its operations are not allowed to modify or traverse the filesystem +// tree. See OpenFDImpl for the supported operations. +// +// Reference Model: +// * An OpenFD takes a reference on the control FD it was opened on. +type OpenFD struct { + openFDRefs + openFDEntry + + // All the following fields are immutable. + + // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref + // on controlFD for its entire lifetime. + controlFD *ControlFD + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // Access mode for this FD. + readable bool + writable bool + + // impl is the open FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl OpenFDImpl +} + +var _ genericFD = (*OpenFD)(nil) + +// ID returns fd's ID. +func (fd *OpenFD) ID() FDID { + return fd.id +} + +// ControlFD returns the control FD on which this FD was opened. +func (fd *OpenFD) ControlFD() ControlFDImpl { + return fd.controlFD.impl +} + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *OpenFD) DecRef(context.Context) { + fd.openFDRefs.DecRef(func() { + fd.controlFD.openFDsMu.Lock() + fd.controlFD.openFDs.Remove(fd) + fd.controlFD.openFDsMu.Unlock() + fd.controlFD.DecRef(nil) // Drop the ref on the control FD. + fd.impl.Close(fd.controlFD.conn) + }) +} + +// Init must be called before first use of fd. +func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.openFDRefs.InitRefs() + fd.controlFD = cfd + fd.id = cfd.conn.insertFD(fd) + accessMode := flags & unix.O_ACCMODE + fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR + fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR + fd.impl = impl + cfd.IncRef() // Holds a ref on cfd for its lifetime. + cfd.openFDsMu.Lock() + cfd.openFDs.PushBack(fd) + cfd.openFDsMu.Unlock() +} + +// ControlFDImpl contains implementation details for a ControlFD. +// Implementations of ControlFDImpl should contain their associated ControlFD +// by value as their first field. +// +// The operations that perform path traversal or any modification to the +// filesystem tree must synchronize those modifications with the server's +// rename mutex. +type ControlFDImpl interface { + FD() *ControlFD + Close(c *Connection) + Stat(c *Connection, comm Communicator) (uint32, error) + SetStat(c *Connection, comm Communicator, stat SetStatReq) (uint32, error) + Walk(c *Connection, comm Communicator, path StringArray) (uint32, error) + WalkStat(c *Connection, comm Communicator, path StringArray) (uint32, error) + Open(c *Connection, comm Communicator, flags uint32) (uint32, error) + OpenCreate(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, flags uint32) (uint32, error) + Mkdir(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string) (uint32, error) + Mknod(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, minor uint32, major uint32) (uint32, error) + Symlink(c *Connection, comm Communicator, name string, target string, uid UID, gid GID) (uint32, error) + Link(c *Connection, comm Communicator, dir ControlFDImpl, name string) (uint32, error) + StatFS(c *Connection, comm Communicator) (uint32, error) + Readlink(c *Connection, comm Communicator) (uint32, error) + Connect(c *Connection, comm Communicator, sockType uint32) error + Unlink(c *Connection, name string, flags uint32) error + RenameLocked(c *Connection, newDir ControlFDImpl, newName string) (func(ControlFDImpl), func(), error) + GetXattr(c *Connection, comm Communicator, name string, size uint32) (uint32, error) + SetXattr(c *Connection, name string, value string, flags uint32) error + ListXattr(c *Connection, comm Communicator, size uint64) (uint32, error) + RemoveXattr(c *Connection, comm Communicator, name string) error +} + +// OpenFDImpl contains implementation details for a OpenFD. Implementations of +// OpenFDImpl should contain their associated OpenFD by value as their first +// field. +// +// Since these operations do not perform any path traversal or any modification +// to the filesystem tree, there is no need to synchronize with rename +// operations. +type OpenFDImpl interface { + FD() *OpenFD + Close(c *Connection) + Stat(c *Connection, comm Communicator) (uint32, error) + Sync(c *Connection) error + Write(c *Connection, comm Communicator, buf []byte, off uint64) (uint32, error) + Read(c *Connection, comm Communicator, off uint64, count uint32) (uint32, error) + Allocate(c *Connection, mode, off, length uint64) error + Flush(c *Connection) error + Getdent64(c *Connection, comm Communicator, count uint32, seek0 bool) (uint32, error) +} diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go new file mode 100644 index 000000000..82807734d --- /dev/null +++ b/pkg/lisafs/handlers.go @@ -0,0 +1,768 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "path" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +const ( + allowedOpenFlags = unix.O_ACCMODE | unix.O_TRUNC + setStatSupportedMask = unix.STATX_MODE | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_ATIME | unix.STATX_MTIME +) + +// RPCHandler defines a handler that is invoked when the associated message is +// received. The handler is responsible for: +// +// * Unmarshalling the request from the passed payload and interpreting it. +// * Marshalling the response into the communicator's payload buffer. +// * Return the number of payload bytes written. +// * Donate any FDs (if needed) to comm which will in turn donate it to client. +type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) + +var handlers = [...]RPCHandler{ + Error: ErrorHandler, + Mount: MountHandler, + Channel: ChannelHandler, + FStat: FStatHandler, + SetStat: SetStatHandler, + Walk: WalkHandler, + WalkStat: WalkStatHandler, + OpenAt: OpenAtHandler, + OpenCreateAt: OpenCreateAtHandler, + Close: CloseHandler, + FSync: FSyncHandler, + PWrite: PWriteHandler, + PRead: PReadHandler, + MkdirAt: MkdirAtHandler, + MknodAt: MknodAtHandler, + SymlinkAt: SymlinkAtHandler, + LinkAt: LinkAtHandler, + FStatFS: FStatFSHandler, + FAllocate: FAllocateHandler, + ReadLinkAt: ReadLinkAtHandler, + Flush: FlushHandler, + Connect: ConnectHandler, + UnlinkAt: UnlinkAtHandler, + RenameAt: RenameAtHandler, + Getdents64: Getdents64Handler, + FGetXattr: FGetXattrHandler, + FSetXattr: FSetXattrHandler, + FListXattr: FListXattrHandler, + FRemoveXattr: FRemoveXattrHandler, +} + +// ErrorHandler handles Error message. +func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + // Client should never send Error. + return 0, unix.EINVAL +} + +// MountHandler handles the Mount RPC. Note that there can not be concurrent +// executions of MountHandler on a connection because the connection enforces +// that Mount is the first message on the connection. Only after the connection +// has been successfully mounted can other channels be created. +func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req MountReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + mountPath := path.Clean(string(req.MountPath)) + if !filepath.IsAbs(mountPath) { + log.Warningf("mountPath %q is not absolute", mountPath) + return 0, unix.EINVAL + } + + if c.mounted { + log.Warningf("connection has already been mounted at %q", mountPath) + return 0, unix.EBUSY + } + + rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath) + if err != nil { + return 0, err + } + + c.server.addMountPoint(rootFD.FD()) + c.mounted = true + resp := MountResp{ + Root: rootIno, + SupportedMs: c.ServerImpl().SupportedMessages(), + MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()), + } + respPayloadLen := uint32(resp.SizeBytes()) + resp.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// ChannelHandler handles the Channel RPC. +func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize()) + if err != nil { + return 0, err + } + + // Start servicing the channel in a separate goroutine. + c.activeWg.Add(1) + go func() { + if err := c.service(ch); err != nil { + // Don't log shutdown error which is expected during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err) + } + } + c.activeWg.Done() + }() + + clientDataFD, err := unix.Dup(desc.FD) + if err != nil { + unix.Close(fdSock) + ch.shutdown() + return 0, err + } + + // Respond to client with successful channel creation message. + if err := comm.DonateFD(clientDataFD); err != nil { + return 0, err + } + if err := comm.DonateFD(fdSock); err != nil { + return 0, err + } + resp := ChannelResp{ + dataOffset: desc.Offset, + dataLength: uint64(desc.Length), + } + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return respLen, nil +} + +// FStatHandler handles the FStat RPC. +func FStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req StatReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.lookupFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + switch t := fd.(type) { + case *ControlFD: + return t.impl.Stat(c, comm) + case *OpenFD: + return t.impl.Stat(c, comm) + default: + panic(fmt.Sprintf("unknown fd type %T", t)) + } +} + +// SetStatHandler handles the SetStat RPC. +func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + + var req SetStatReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + if req.Mask&^setStatSupportedMask != 0 { + return 0, unix.EPERM + } + + return fd.impl.SetStat(c, comm, req) +} + +// WalkHandler handles the Walk RPC. +func WalkHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req WalkReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + for _, name := range req.Path { + if err := checkSafeName(name); err != nil { + return 0, err + } + } + + return fd.impl.Walk(c, comm, req.Path) +} + +// WalkStatHandler handles the WalkStat RPC. +func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req WalkReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + // Note that this fd is allowed to not actually be a directory when the + // only path component to walk is "" (self). + if !fd.IsDir() { + if len(req.Path) > 1 || (len(req.Path) == 1 && len(req.Path[0]) > 0) { + return 0, unix.ENOTDIR + } + } + for i, name := range req.Path { + // First component is allowed to be "". + if i == 0 && len(name) == 0 { + continue + } + if err := checkSafeName(name); err != nil { + return 0, err + } + } + + return fd.impl.WalkStat(c, comm, req.Path) +} + +// OpenAtHandler handles the OpenAt RPC. +func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req OpenAtReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + // Only keep allowed open flags. + if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags { + log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags) + req.Flags = allowedFlags + } + + accessMode := req.Flags & unix.O_ACCMODE + trunc := req.Flags&unix.O_TRUNC != 0 + if c.readonly && (accessMode != unix.O_RDONLY || trunc) { + return 0, unix.EROFS + } + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if fd.IsDir() { + // Directory is not truncatable and must be opened with O_RDONLY. + if accessMode != unix.O_RDONLY || trunc { + return 0, unix.EISDIR + } + } + + return fd.impl.Open(c, comm, req.Flags) +} + +// OpenCreateAtHandler handles the OpenCreateAt RPC. +func OpenCreateAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req OpenCreateAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Only keep allowed open flags. + if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags { + log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags) + req.Flags = allowedFlags + } + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + + return fd.impl.OpenCreate(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Flags)) +} + +// CloseHandler handles the Close RPC. +func CloseHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req CloseReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + for _, fd := range req.FDs { + c.RemoveFD(fd) + } + + // There is no response message for this. + return 0, nil +} + +// FSyncHandler handles the FSync RPC. +func FSyncHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FsyncReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Return the first error we encounter, but sync everything we can + // regardless. + var retErr error + for _, fdid := range req.FDs { + if err := c.fsyncFD(fdid); err != nil && retErr == nil { + retErr = err + } + } + + // There is no response message for this. + return 0, retErr +} + +func (c *Connection) fsyncFD(id FDID) error { + fd, err := c.LookupOpenFD(id) + if err != nil { + return err + } + return fd.impl.Sync(c) +} + +// PWriteHandler handles the PWrite RPC. +func PWriteHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req PWriteReq + // Note that it is an optimized Unmarshal operation which avoids any buffer + // allocation and copying. req.Buf just points to payload. This is safe to do + // as the handler owns payload and req's lifetime is limited to the handler. + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + if !fd.writable { + return 0, unix.EBADF + } + return fd.impl.Write(c, comm, req.Buf, uint64(req.Offset)) +} + +// PReadHandler handles the PRead RPC. +func PReadHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req PReadReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.readable { + return 0, unix.EBADF + } + return fd.impl.Read(c, comm, req.Offset, req.Count) +} + +// MkdirAtHandler handles the MkdirAt RPC. +func MkdirAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req MkdirAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Mkdir(c, comm, req.Mode, req.UID, req.GID, name) +} + +// MknodAtHandler handles the MknodAt RPC. +func MknodAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req MknodAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Mknod(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Minor), uint32(req.Major)) +} + +// SymlinkAtHandler handles the SymlinkAt RPC. +func SymlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req SymlinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Symlink(c, comm, name, string(req.Target), req.UID, req.GID) +} + +// LinkAtHandler handles the LinkAt RPC. +func LinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req LinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + + targetFD, err := c.LookupControlFD(req.Target) + if err != nil { + return 0, err + } + return targetFD.impl.Link(c, comm, fd.impl, name) +} + +// FStatFSHandler handles the FStatFS RPC. +func FStatFSHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FStatFSReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.StatFS(c, comm) +} + +// FAllocateHandler handles the FAllocate RPC. +func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FAllocateReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.writable { + return 0, unix.EBADF + } + return 0, fd.impl.Allocate(c, req.Mode, req.Offset, req.Length) +} + +// ReadLinkAtHandler handles the ReadLinkAt RPC. +func ReadLinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ReadLinkAtReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSymlink() { + return 0, unix.EINVAL + } + return fd.impl.Readlink(c, comm) +} + +// FlushHandler handles the Flush RPC. +func FlushHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FlushReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + return 0, fd.impl.Flush(c) +} + +// ConnectHandler handles the Connect RPC. +func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ConnectReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSocket() { + return 0, unix.ENOTSOCK + } + return 0, fd.impl.Connect(c, comm, req.SockType) +} + +// UnlinkAtHandler handles the UnlinkAt RPC. +func UnlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req UnlinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return 0, fd.impl.Unlink(c, name, uint32(req.Flags)) +} + +// RenameAtHandler handles the RenameAt RPC. +func RenameAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req RenameAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + newName := string(req.NewName) + if err := checkSafeName(newName); err != nil { + return 0, err + } + + renamed, err := c.LookupControlFD(req.Renamed) + if err != nil { + return 0, err + } + defer renamed.DecRef(nil) + + newDir, err := c.LookupControlFD(req.NewDir) + if err != nil { + return 0, err + } + defer newDir.DecRef(nil) + if !newDir.IsDir() { + return 0, unix.ENOTDIR + } + + // Hold RenameMu for writing during rename, this is important. + c.server.RenameMu.Lock() + defer c.server.RenameMu.Unlock() + + if renamed.parent == nil { + // renamed is root. + return 0, unix.EBUSY + } + + oldParentPath := renamed.parent.FilePathLocked() + oldPath := oldParentPath + "/" + renamed.name + if newName == renamed.name && oldParentPath == newDir.FilePathLocked() { + // Nothing to do. + return 0, nil + } + + updateControlFD, cleanUp, err := renamed.impl.RenameLocked(c, newDir.impl, newName) + if err != nil { + return 0, err + } + + c.server.forEachMountPoint(func(root *ControlFD) { + if !strings.HasPrefix(oldPath, root.name) { + return + } + pit := fspath.Parse(oldPath[len(root.name):]).Begin + root.renameRecursiveLocked(newDir, newName, pit, updateControlFD) + }) + + if cleanUp != nil { + cleanUp() + } + return 0, nil +} + +// Precondition: rename mutex must be locked for writing. +func (fd *ControlFD) renameRecursiveLocked(newDir *ControlFD, newName string, pit fspath.Iterator, updateControlFD func(ControlFDImpl)) { + if !pit.Ok() { + // fd should be renamed. + fd.clearParentLocked() + fd.setParentLocked(newDir) + fd.name = newName + if updateControlFD != nil { + updateControlFD(fd.impl) + } + return + } + + cur := pit.String() + next := pit.Next() + // No need to hold fd.childrenMu because RenameMu is locked for writing. + for child := fd.children.Front(); child != nil; child = child.Next() { + if child.name == cur { + child.renameRecursiveLocked(newDir, newName, next, updateControlFD) + } + } +} + +// Getdents64Handler handles the Getdents64 RPC. +func Getdents64Handler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req Getdents64Req + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.controlFD.IsDir() { + return 0, unix.ENOTDIR + } + + seek0 := false + if req.Count < 0 { + seek0 = true + req.Count = -req.Count + } + return fd.impl.Getdent64(c, comm, uint32(req.Count), seek0) +} + +// FGetXattrHandler handles the FGetXattr RPC. +func FGetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FGetXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.GetXattr(c, comm, string(req.Name), uint32(req.BufSize)) +} + +// FSetXattrHandler handles the FSetXattr RPC. +func FSetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FSetXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return 0, fd.impl.SetXattr(c, string(req.Name), string(req.Value), uint32(req.Flags)) +} + +// FListXattrHandler handles the FListXattr RPC. +func FListXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FListXattrReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.ListXattr(c, comm, req.Size) +} + +// FRemoveXattrHandler handles the FRemoveXattr RPC. +func FRemoveXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FRemoveXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return 0, fd.impl.RemoveXattr(c, comm, string(req.Name)) +} + +// checkSafeName validates the name and returns nil or returns an error. +func checkSafeName(name string) error { + if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." { + return nil + } + return unix.EINVAL +} diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go new file mode 100644 index 000000000..4d8e956ab --- /dev/null +++ b/pkg/lisafs/lisafs.go @@ -0,0 +1,18 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lisafs (LInux SAndbox FileSystem) defines the protocol for +// filesystem RPCs between an untrusted Sandbox (client) and a trusted +// filesystem server. +package lisafs diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go new file mode 100644 index 000000000..722afd0be --- /dev/null +++ b/pkg/lisafs/message.go @@ -0,0 +1,1251 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "os" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// Messages have two parts: +// * A transport header used to decipher received messages. +// * A byte array referred to as "payload" which contains the actual message. +// +// "dataLen" refers to the size of both combined. + +// MID (message ID) is used to identify messages to parse from payload. +// +// +marshal slice:MIDSlice +type MID uint16 + +// These constants are used to identify their corresponding message types. +const ( + // Error is only used in responses to pass errors to client. + Error MID = 0 + + // Mount is used to establish connection between the client and server mount + // point. lisafs requires that the client makes a successful Mount RPC before + // making other RPCs. + Mount MID = 1 + + // Channel requests to start a new communicational channel. + Channel MID = 2 + + // FStat requests the stat(2) results for a specified file. + FStat MID = 3 + + // SetStat requests to change file attributes. Note that there is no one + // corresponding Linux syscall. This is a conglomeration of fchmod(2), + // fchown(2), ftruncate(2) and futimesat(2). + SetStat MID = 4 + + // Walk requests to walk the specified path starting from the specified + // directory. Server-side path traversal is terminated preemptively on + // symlinks entries because they can cause non-linear traversal. + Walk MID = 5 + + // WalkStat is the same as Walk, except the following differences: + // * If the first path component is "", then it also returns stat results + // for the directory where the walk starts. + // * Does not return Inode, just the Stat results for each path component. + WalkStat MID = 6 + + // OpenAt is analogous to openat(2). It does not perform any walk. It merely + // duplicates the control FD with the open flags passed. + OpenAt MID = 7 + + // OpenCreateAt is analogous to openat(2) with O_CREAT|O_EXCL added to flags. + // It also returns the newly created file inode. + OpenCreateAt MID = 8 + + // Close is analogous to close(2) but can work on multiple FDs. + Close MID = 9 + + // FSync is analogous to fsync(2) but can work on multiple FDs. + FSync MID = 10 + + // PWrite is analogous to pwrite(2). + PWrite MID = 11 + + // PRead is analogous to pread(2). + PRead MID = 12 + + // MkdirAt is analogous to mkdirat(2). + MkdirAt MID = 13 + + // MknodAt is analogous to mknodat(2). + MknodAt MID = 14 + + // SymlinkAt is analogous to symlinkat(2). + SymlinkAt MID = 15 + + // LinkAt is analogous to linkat(2). + LinkAt MID = 16 + + // FStatFS is analogous to fstatfs(2). + FStatFS MID = 17 + + // FAllocate is analogous to fallocate(2). + FAllocate MID = 18 + + // ReadLinkAt is analogous to readlinkat(2). + ReadLinkAt MID = 19 + + // Flush cleans up the file state. Its behavior is implementation + // dependent and might not even be supported in server implementations. + Flush MID = 20 + + // Connect is loosely analogous to connect(2). + Connect MID = 21 + + // UnlinkAt is analogous to unlinkat(2). + UnlinkAt MID = 22 + + // RenameAt is loosely analogous to renameat(2). + RenameAt MID = 23 + + // Getdents64 is analogous to getdents64(2). + Getdents64 MID = 24 + + // FGetXattr is analogous to fgetxattr(2). + FGetXattr MID = 25 + + // FSetXattr is analogous to fsetxattr(2). + FSetXattr MID = 26 + + // FListXattr is analogous to flistxattr(2). + FListXattr MID = 27 + + // FRemoveXattr is analogous to fremovexattr(2). + FRemoveXattr MID = 28 +) + +const ( + // NoUID is a sentinel used to indicate no valid UID. + NoUID UID = math.MaxUint32 + + // NoGID is a sentinel used to indicate no valid GID. + NoGID GID = math.MaxUint32 +) + +// MaxMessageSize is the recommended max message size that can be used by +// connections. Server implementations may choose to use other values. +func MaxMessageSize() uint32 { + // Return HugePageSize - PageSize so that when flipcall packet window is + // created with MaxMessageSize() + flipcall header size + channel header + // size, HugePageSize is allocated and can be backed by a single huge page + // if supported by the underlying memfd. + return uint32(hostarch.HugePageSize - os.Getpagesize()) +} + +// TODO(gvisor.dev/issue/6450): Once this is resolved: +// * Update manual implementations and function signatures. +// * Update RPC handlers and appropriate callers to handle errors correctly. +// * Update manual implementations to get rid of buffer shifting. + +// UID represents a user ID. +// +// +marshal +type UID uint32 + +// Ok returns true if uid is not NoUID. +func (uid UID) Ok() bool { + return uid != NoUID +} + +// GID represents a group ID. +// +// +marshal +type GID uint32 + +// Ok returns true if gid is not NoGID. +func (gid GID) Ok() bool { + return gid != NoGID +} + +// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes. +func NoopMarshal([]byte) {} + +// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes. +func NoopUnmarshal([]byte) {} + +// SizedString represents a string in memory. The marshalled string bytes are +// preceded by a uint32 signifying the string length. +type SizedString string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SizedString) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + len(*s) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SizedString) MarshalBytes(dst []byte) { + strLen := primitive.Uint32(len(*s)) + strLen.MarshalUnsafe(dst) + dst = dst[strLen.SizeBytes():] + // Copy without any allocation. + copy(dst[:strLen], *s) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SizedString) UnmarshalBytes(src []byte) { + var strLen primitive.Uint32 + strLen.UnmarshalUnsafe(src) + src = src[strLen.SizeBytes():] + // Take the hit, this leads to an allocation + memcpy. No way around it. + *s = SizedString(src[:strLen]) +} + +// StringArray represents an array of SizedStrings in memory. The marshalled +// array data is preceded by a uint32 signifying the array length. +type StringArray []string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *StringArray) SizeBytes() int { + size := (*primitive.Uint32)(nil).SizeBytes() + for _, str := range *s { + sstr := SizedString(str) + size += sstr.SizeBytes() + } + return size +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *StringArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*s)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + for _, str := range *s { + sstr := SizedString(str) + sstr.MarshalBytes(dst) + dst = dst[sstr.SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *StringArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + + if cap(*s) < int(arrLen) { + *s = make([]string, arrLen) + } else { + *s = (*s)[:arrLen] + } + + for i := primitive.Uint32(0); i < arrLen; i++ { + var sstr SizedString + sstr.UnmarshalBytes(src) + src = src[sstr.SizeBytes():] + (*s)[i] = string(sstr) + } +} + +// Inode represents an inode on the remote filesystem. +// +// +marshal slice:InodeSlice +type Inode struct { + ControlFD FDID + _ uint32 // Need to make struct packed. + Stat linux.Statx +} + +// MountReq represents a Mount request. +type MountReq struct { + MountPath SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountReq) SizeBytes() int { + return m.MountPath.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountReq) MarshalBytes(dst []byte) { + m.MountPath.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountReq) UnmarshalBytes(src []byte) { + m.MountPath.UnmarshalBytes(src) +} + +// MountResp represents a Mount response. +type MountResp struct { + Root Inode + // MaxMessageSize is the maximum size of messages communicated between the + // client and server in bytes. This includes the communication header. + MaxMessageSize primitive.Uint32 + // SupportedMs holds all the supported messages. + SupportedMs []MID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountResp) SizeBytes() int { + return m.Root.SizeBytes() + + m.MaxMessageSize.SizeBytes() + + (*primitive.Uint16)(nil).SizeBytes() + + (len(m.SupportedMs) * (*MID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountResp) MarshalBytes(dst []byte) { + m.Root.MarshalUnsafe(dst) + dst = dst[m.Root.SizeBytes():] + m.MaxMessageSize.MarshalUnsafe(dst) + dst = dst[m.MaxMessageSize.SizeBytes():] + numSupported := primitive.Uint16(len(m.SupportedMs)) + numSupported.MarshalBytes(dst) + dst = dst[numSupported.SizeBytes():] + MarshalUnsafeMIDSlice(m.SupportedMs, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountResp) UnmarshalBytes(src []byte) { + m.Root.UnmarshalUnsafe(src) + src = src[m.Root.SizeBytes():] + m.MaxMessageSize.UnmarshalUnsafe(src) + src = src[m.MaxMessageSize.SizeBytes():] + var numSupported primitive.Uint16 + numSupported.UnmarshalBytes(src) + src = src[numSupported.SizeBytes():] + m.SupportedMs = make([]MID, numSupported) + UnmarshalUnsafeMIDSlice(m.SupportedMs, src) +} + +// ChannelResp is the response to the create channel request. +// +// +marshal +type ChannelResp struct { + dataOffset int64 + dataLength uint64 +} + +// ErrorResp is returned to represent an error while handling a request. +// +// +marshal +type ErrorResp struct { + errno uint32 +} + +// StatReq requests the stat results for the specified FD. +// +// +marshal +type StatReq struct { + FD FDID +} + +// SetStatReq is used to set attributeds on FDs. +// +// +marshal +type SetStatReq struct { + FD FDID + _ uint32 + Mask uint32 + Mode uint32 // Only permissions part is settable. + UID UID + GID GID + Size uint64 + Atime linux.Timespec + Mtime linux.Timespec +} + +// SetStatResp is used to communicate SetStat results. It contains a mask +// representing the failed changes. It also contains the errno of the failed +// set attribute operation. If multiple operations failed then any of those +// errnos can be returned. +// +// +marshal +type SetStatResp struct { + FailureMask uint32 + FailureErrNo uint32 +} + +// WalkReq is used to request to walk multiple path components at once. This +// is used for both Walk and WalkStat. +type WalkReq struct { + DirFD FDID + Path StringArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkReq) SizeBytes() int { + return w.DirFD.SizeBytes() + w.Path.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkReq) MarshalBytes(dst []byte) { + w.DirFD.MarshalUnsafe(dst) + dst = dst[w.DirFD.SizeBytes():] + w.Path.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkReq) UnmarshalBytes(src []byte) { + w.DirFD.UnmarshalUnsafe(src) + src = src[w.DirFD.SizeBytes():] + w.Path.UnmarshalBytes(src) +} + +// WalkStatus is used to indicate the reason for partial/unsuccessful server +// side Walk operations. Please note that partial/unsuccessful walk operations +// do not necessarily fail the RPC. The RPC is successful with a failure hint +// which can be used by the client to infer server-side state. +type WalkStatus = primitive.Uint8 + +const ( + // WalkSuccess indicates that all path components were successfully walked. + WalkSuccess WalkStatus = iota + + // WalkComponentDoesNotExist indicates that the walk was prematurely + // terminated because an intermediate path component does not exist on + // server. The results of all previous existing path components is returned. + WalkComponentDoesNotExist + + // WalkComponentSymlink indicates that the walk was prematurely + // terminated because an intermediate path component was a symlink. It is not + // safe to resolve symlinks remotely (unaware of mount points). + WalkComponentSymlink +) + +// WalkResp is used to communicate the inodes walked by the server. +type WalkResp struct { + Status WalkStatus + Inodes []Inode +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkResp) SizeBytes() int { + return w.Status.SizeBytes() + + (*primitive.Uint32)(nil).SizeBytes() + (len(w.Inodes) * (*Inode)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkResp) MarshalBytes(dst []byte) { + w.Status.MarshalUnsafe(dst) + dst = dst[w.Status.SizeBytes():] + + numInodes := primitive.Uint32(len(w.Inodes)) + numInodes.MarshalUnsafe(dst) + dst = dst[numInodes.SizeBytes():] + + MarshalUnsafeInodeSlice(w.Inodes, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkResp) UnmarshalBytes(src []byte) { + w.Status.UnmarshalUnsafe(src) + src = src[w.Status.SizeBytes():] + + var numInodes primitive.Uint32 + numInodes.UnmarshalUnsafe(src) + src = src[numInodes.SizeBytes():] + + if cap(w.Inodes) < int(numInodes) { + w.Inodes = make([]Inode, numInodes) + } else { + w.Inodes = w.Inodes[:numInodes] + } + UnmarshalUnsafeInodeSlice(w.Inodes, src) +} + +// WalkStatResp is used to communicate stat results for WalkStat. +type WalkStatResp struct { + Stats []linux.Statx +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkStatResp) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (len(w.Stats) * linux.SizeOfStatx) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkStatResp) MarshalBytes(dst []byte) { + numStats := primitive.Uint32(len(w.Stats)) + numStats.MarshalUnsafe(dst) + dst = dst[numStats.SizeBytes():] + + linux.MarshalUnsafeStatxSlice(w.Stats, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkStatResp) UnmarshalBytes(src []byte) { + var numStats primitive.Uint32 + numStats.UnmarshalUnsafe(src) + src = src[numStats.SizeBytes():] + + if cap(w.Stats) < int(numStats) { + w.Stats = make([]linux.Statx, numStats) + } else { + w.Stats = w.Stats[:numStats] + } + linux.UnmarshalUnsafeStatxSlice(w.Stats, src) +} + +// OpenAtReq is used to open existing FDs with the specified flags. +// +// +marshal +type OpenAtReq struct { + FD FDID + Flags uint32 +} + +// OpenAtResp is used to communicate the newly created FD. +// +// +marshal +type OpenAtResp struct { + NewFD FDID +} + +// +marshal +type createCommon struct { + DirFD FDID + Mode linux.FileMode + _ uint16 // Need to make struct packed. + UID UID + GID GID +} + +// OpenCreateAtReq is used to make OpenCreateAt requests. +type OpenCreateAtReq struct { + createCommon + Name SizedString + Flags primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (o *OpenCreateAtReq) SizeBytes() int { + return o.createCommon.SizeBytes() + o.Name.SizeBytes() + o.Flags.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (o *OpenCreateAtReq) MarshalBytes(dst []byte) { + o.createCommon.MarshalUnsafe(dst) + dst = dst[o.createCommon.SizeBytes():] + o.Name.MarshalBytes(dst) + dst = dst[o.Name.SizeBytes():] + o.Flags.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) { + o.createCommon.UnmarshalUnsafe(src) + src = src[o.createCommon.SizeBytes():] + o.Name.UnmarshalBytes(src) + src = src[o.Name.SizeBytes():] + o.Flags.UnmarshalUnsafe(src) +} + +// OpenCreateAtResp is used to communicate successful OpenCreateAt results. +// +// +marshal +type OpenCreateAtResp struct { + Child Inode + NewFD FDID + _ uint32 // Need to make struct packed. +} + +// FdArray is a utility struct which implements a marshallable type for +// communicating an array of FDIDs. In memory, the array data is preceded by a +// uint32 denoting the array length. +type FdArray []FDID + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FdArray) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (len(*f) * (*FDID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FdArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*f)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + MarshalUnsafeFDIDSlice(*f, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FdArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + if cap(*f) < int(arrLen) { + *f = make(FdArray, arrLen) + } else { + *f = (*f)[:arrLen] + } + UnmarshalUnsafeFDIDSlice(*f, src) +} + +// CloseReq is used to close(2) FDs. +type CloseReq struct { + FDs FdArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *CloseReq) SizeBytes() int { + return c.FDs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *CloseReq) MarshalBytes(dst []byte) { + c.FDs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *CloseReq) UnmarshalBytes(src []byte) { + c.FDs.UnmarshalBytes(src) +} + +// FsyncReq is used to fsync(2) FDs. +type FsyncReq struct { + FDs FdArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FsyncReq) SizeBytes() int { + return f.FDs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FsyncReq) MarshalBytes(dst []byte) { + f.FDs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FsyncReq) UnmarshalBytes(src []byte) { + f.FDs.UnmarshalBytes(src) +} + +// PReadReq is used to pread(2) on an FD. +// +// +marshal +type PReadReq struct { + Offset uint64 + FD FDID + Count uint32 +} + +// PReadResp is used to return the result of pread(2). +type PReadResp struct { + NumBytes primitive.Uint32 + Buf []byte +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *PReadResp) SizeBytes() int { + return r.NumBytes.SizeBytes() + int(r.NumBytes) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *PReadResp) MarshalBytes(dst []byte) { + r.NumBytes.MarshalUnsafe(dst) + dst = dst[r.NumBytes.SizeBytes():] + copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *PReadResp) UnmarshalBytes(src []byte) { + r.NumBytes.UnmarshalUnsafe(src) + src = src[r.NumBytes.SizeBytes():] + + // We expect the client to have already allocated r.Buf. r.Buf probably + // (optimally) points to usermem. Directly copy into that. + copy(r.Buf[:r.NumBytes], src[:r.NumBytes]) +} + +// PWriteReq is used to pwrite(2) on an FD. +type PWriteReq struct { + Offset primitive.Uint64 + FD FDID + NumBytes primitive.Uint32 + Buf []byte +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *PWriteReq) SizeBytes() int { + return w.Offset.SizeBytes() + w.FD.SizeBytes() + w.NumBytes.SizeBytes() + int(w.NumBytes) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *PWriteReq) MarshalBytes(dst []byte) { + w.Offset.MarshalUnsafe(dst) + dst = dst[w.Offset.SizeBytes():] + w.FD.MarshalUnsafe(dst) + dst = dst[w.FD.SizeBytes():] + w.NumBytes.MarshalUnsafe(dst) + dst = dst[w.NumBytes.SizeBytes():] + copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *PWriteReq) UnmarshalBytes(src []byte) { + w.Offset.UnmarshalUnsafe(src) + src = src[w.Offset.SizeBytes():] + w.FD.UnmarshalUnsafe(src) + src = src[w.FD.SizeBytes():] + w.NumBytes.UnmarshalUnsafe(src) + src = src[w.NumBytes.SizeBytes():] + + // This is an optimization. Assuming that the server is making this call, it + // is safe to just point to src rather than allocating and copying. + w.Buf = src[:w.NumBytes] +} + +// PWriteResp is used to return the result of pwrite(2). +// +// +marshal +type PWriteResp struct { + Count uint64 +} + +// MkdirAtReq is used to make MkdirAt requests. +type MkdirAtReq struct { + createCommon + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MkdirAtReq) SizeBytes() int { + return m.createCommon.SizeBytes() + m.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MkdirAtReq) MarshalBytes(dst []byte) { + m.createCommon.MarshalUnsafe(dst) + dst = dst[m.createCommon.SizeBytes():] + m.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MkdirAtReq) UnmarshalBytes(src []byte) { + m.createCommon.UnmarshalUnsafe(src) + src = src[m.createCommon.SizeBytes():] + m.Name.UnmarshalBytes(src) +} + +// MkdirAtResp is the response to a successful MkdirAt request. +// +// +marshal +type MkdirAtResp struct { + ChildDir Inode +} + +// MknodAtReq is used to make MknodAt requests. +type MknodAtReq struct { + createCommon + Name SizedString + Minor primitive.Uint32 + Major primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MknodAtReq) SizeBytes() int { + return m.createCommon.SizeBytes() + m.Name.SizeBytes() + m.Minor.SizeBytes() + m.Major.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MknodAtReq) MarshalBytes(dst []byte) { + m.createCommon.MarshalUnsafe(dst) + dst = dst[m.createCommon.SizeBytes():] + m.Name.MarshalBytes(dst) + dst = dst[m.Name.SizeBytes():] + m.Minor.MarshalUnsafe(dst) + dst = dst[m.Minor.SizeBytes():] + m.Major.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MknodAtReq) UnmarshalBytes(src []byte) { + m.createCommon.UnmarshalUnsafe(src) + src = src[m.createCommon.SizeBytes():] + m.Name.UnmarshalBytes(src) + src = src[m.Name.SizeBytes():] + m.Minor.UnmarshalUnsafe(src) + src = src[m.Minor.SizeBytes():] + m.Major.UnmarshalUnsafe(src) +} + +// MknodAtResp is the response to a successful MknodAt request. +// +// +marshal +type MknodAtResp struct { + Child Inode +} + +// SymlinkAtReq is used to make SymlinkAt request. +type SymlinkAtReq struct { + DirFD FDID + Name SizedString + Target SizedString + UID UID + GID GID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SymlinkAtReq) SizeBytes() int { + return s.DirFD.SizeBytes() + s.Name.SizeBytes() + s.Target.SizeBytes() + s.UID.SizeBytes() + s.GID.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SymlinkAtReq) MarshalBytes(dst []byte) { + s.DirFD.MarshalUnsafe(dst) + dst = dst[s.DirFD.SizeBytes():] + s.Name.MarshalBytes(dst) + dst = dst[s.Name.SizeBytes():] + s.Target.MarshalBytes(dst) + dst = dst[s.Target.SizeBytes():] + s.UID.MarshalUnsafe(dst) + dst = dst[s.UID.SizeBytes():] + s.GID.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SymlinkAtReq) UnmarshalBytes(src []byte) { + s.DirFD.UnmarshalUnsafe(src) + src = src[s.DirFD.SizeBytes():] + s.Name.UnmarshalBytes(src) + src = src[s.Name.SizeBytes():] + s.Target.UnmarshalBytes(src) + src = src[s.Target.SizeBytes():] + s.UID.UnmarshalUnsafe(src) + src = src[s.UID.SizeBytes():] + s.GID.UnmarshalUnsafe(src) +} + +// SymlinkAtResp is the response to a successful SymlinkAt request. +// +// +marshal +type SymlinkAtResp struct { + Symlink Inode +} + +// LinkAtReq is used to make LinkAt requests. +type LinkAtReq struct { + DirFD FDID + Target FDID + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *LinkAtReq) SizeBytes() int { + return l.DirFD.SizeBytes() + l.Target.SizeBytes() + l.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *LinkAtReq) MarshalBytes(dst []byte) { + l.DirFD.MarshalUnsafe(dst) + dst = dst[l.DirFD.SizeBytes():] + l.Target.MarshalUnsafe(dst) + dst = dst[l.Target.SizeBytes():] + l.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *LinkAtReq) UnmarshalBytes(src []byte) { + l.DirFD.UnmarshalUnsafe(src) + src = src[l.DirFD.SizeBytes():] + l.Target.UnmarshalUnsafe(src) + src = src[l.Target.SizeBytes():] + l.Name.UnmarshalBytes(src) +} + +// LinkAtResp is used to respond to a successful LinkAt request. +// +// +marshal +type LinkAtResp struct { + Link Inode +} + +// FStatFSReq is used to request StatFS results for the specified FD. +// +// +marshal +type FStatFSReq struct { + FD FDID +} + +// StatFS is responded to a successful FStatFS request. +// +// +marshal +type StatFS struct { + Type uint64 + BlockSize int64 + Blocks uint64 + BlocksFree uint64 + BlocksAvailable uint64 + Files uint64 + FilesFree uint64 + NameLength uint64 +} + +// FAllocateReq is used to request to fallocate(2) an FD. This has no response. +// +// +marshal +type FAllocateReq struct { + FD FDID + _ uint32 + Mode uint64 + Offset uint64 + Length uint64 +} + +// ReadLinkAtReq is used to readlinkat(2) at the specified FD. +// +// +marshal +type ReadLinkAtReq struct { + FD FDID +} + +// ReadLinkAtResp is used to communicate ReadLinkAt results. +type ReadLinkAtResp struct { + Target SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *ReadLinkAtResp) SizeBytes() int { + return r.Target.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *ReadLinkAtResp) MarshalBytes(dst []byte) { + r.Target.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) { + r.Target.UnmarshalBytes(src) +} + +// FlushReq is used to make Flush requests. +// +// +marshal +type FlushReq struct { + FD FDID +} + +// ConnectReq is used to make a Connect request. +// +// +marshal +type ConnectReq struct { + FD FDID + // SockType is used to specify the socket type to connect to. As a special + // case, SockType = 0 means that the socket type does not matter and the + // requester will accept any socket type. + SockType uint32 +} + +// UnlinkAtReq is used to make UnlinkAt request. +type UnlinkAtReq struct { + DirFD FDID + Name SizedString + Flags primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *UnlinkAtReq) SizeBytes() int { + return u.DirFD.SizeBytes() + u.Name.SizeBytes() + u.Flags.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *UnlinkAtReq) MarshalBytes(dst []byte) { + u.DirFD.MarshalUnsafe(dst) + dst = dst[u.DirFD.SizeBytes():] + u.Name.MarshalBytes(dst) + dst = dst[u.Name.SizeBytes():] + u.Flags.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *UnlinkAtReq) UnmarshalBytes(src []byte) { + u.DirFD.UnmarshalUnsafe(src) + src = src[u.DirFD.SizeBytes():] + u.Name.UnmarshalBytes(src) + src = src[u.Name.SizeBytes():] + u.Flags.UnmarshalUnsafe(src) +} + +// RenameAtReq is used to make Rename requests. Note that the request takes in +// the to-be-renamed file's FD instead of oldDir and oldName like renameat(2). +type RenameAtReq struct { + Renamed FDID + NewDir FDID + NewName SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RenameAtReq) SizeBytes() int { + return r.Renamed.SizeBytes() + r.NewDir.SizeBytes() + r.NewName.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RenameAtReq) MarshalBytes(dst []byte) { + r.Renamed.MarshalUnsafe(dst) + dst = dst[r.Renamed.SizeBytes():] + r.NewDir.MarshalUnsafe(dst) + dst = dst[r.NewDir.SizeBytes():] + r.NewName.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RenameAtReq) UnmarshalBytes(src []byte) { + r.Renamed.UnmarshalUnsafe(src) + src = src[r.Renamed.SizeBytes():] + r.NewDir.UnmarshalUnsafe(src) + src = src[r.NewDir.SizeBytes():] + r.NewName.UnmarshalBytes(src) +} + +// Getdents64Req is used to make Getdents64 requests. +// +// +marshal +type Getdents64Req struct { + DirFD FDID + // Count is the number of bytes to read. A negative value of Count is used to + // indicate that the implementation must lseek(0, SEEK_SET) before calling + // getdents64(2). Implementations must use the absolute value of Count to + // determine the number of bytes to read. + Count int32 +} + +// Dirent64 is analogous to struct linux_dirent64. +type Dirent64 struct { + Ino primitive.Uint64 + DevMinor primitive.Uint32 + DevMajor primitive.Uint32 + Off primitive.Uint64 + Type primitive.Uint8 + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (d *Dirent64) SizeBytes() int { + return d.Ino.SizeBytes() + d.DevMinor.SizeBytes() + d.DevMajor.SizeBytes() + d.Off.SizeBytes() + d.Type.SizeBytes() + d.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (d *Dirent64) MarshalBytes(dst []byte) { + d.Ino.MarshalUnsafe(dst) + dst = dst[d.Ino.SizeBytes():] + d.DevMinor.MarshalUnsafe(dst) + dst = dst[d.DevMinor.SizeBytes():] + d.DevMajor.MarshalUnsafe(dst) + dst = dst[d.DevMajor.SizeBytes():] + d.Off.MarshalUnsafe(dst) + dst = dst[d.Off.SizeBytes():] + d.Type.MarshalUnsafe(dst) + dst = dst[d.Type.SizeBytes():] + d.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (d *Dirent64) UnmarshalBytes(src []byte) { + d.Ino.UnmarshalUnsafe(src) + src = src[d.Ino.SizeBytes():] + d.DevMinor.UnmarshalUnsafe(src) + src = src[d.DevMinor.SizeBytes():] + d.DevMajor.UnmarshalUnsafe(src) + src = src[d.DevMajor.SizeBytes():] + d.Off.UnmarshalUnsafe(src) + src = src[d.Off.SizeBytes():] + d.Type.UnmarshalUnsafe(src) + src = src[d.Type.SizeBytes():] + d.Name.UnmarshalBytes(src) +} + +// Getdents64Resp is used to communicate getdents64 results. +type Getdents64Resp struct { + Dirents []Dirent64 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *Getdents64Resp) SizeBytes() int { + ret := (*primitive.Uint32)(nil).SizeBytes() + for i := range g.Dirents { + ret += g.Dirents[i].SizeBytes() + } + return ret +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *Getdents64Resp) MarshalBytes(dst []byte) { + numDirents := primitive.Uint32(len(g.Dirents)) + numDirents.MarshalUnsafe(dst) + dst = dst[numDirents.SizeBytes():] + for i := range g.Dirents { + g.Dirents[i].MarshalBytes(dst) + dst = dst[g.Dirents[i].SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *Getdents64Resp) UnmarshalBytes(src []byte) { + var numDirents primitive.Uint32 + numDirents.UnmarshalUnsafe(src) + if cap(g.Dirents) < int(numDirents) { + g.Dirents = make([]Dirent64, numDirents) + } else { + g.Dirents = g.Dirents[:numDirents] + } + + src = src[numDirents.SizeBytes():] + for i := range g.Dirents { + g.Dirents[i].UnmarshalBytes(src) + src = src[g.Dirents[i].SizeBytes():] + } +} + +// FGetXattrReq is used to make FGetXattr requests. The response to this is +// just a SizedString containing the xattr value. +type FGetXattrReq struct { + FD FDID + BufSize primitive.Uint32 + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *FGetXattrReq) SizeBytes() int { + return g.FD.SizeBytes() + g.BufSize.SizeBytes() + g.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *FGetXattrReq) MarshalBytes(dst []byte) { + g.FD.MarshalUnsafe(dst) + dst = dst[g.FD.SizeBytes():] + g.BufSize.MarshalUnsafe(dst) + dst = dst[g.BufSize.SizeBytes():] + g.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *FGetXattrReq) UnmarshalBytes(src []byte) { + g.FD.UnmarshalUnsafe(src) + src = src[g.FD.SizeBytes():] + g.BufSize.UnmarshalUnsafe(src) + src = src[g.BufSize.SizeBytes():] + g.Name.UnmarshalBytes(src) +} + +// FGetXattrResp is used to respond to FGetXattr request. +type FGetXattrResp struct { + Value SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *FGetXattrResp) SizeBytes() int { + return g.Value.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *FGetXattrResp) MarshalBytes(dst []byte) { + g.Value.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *FGetXattrResp) UnmarshalBytes(src []byte) { + g.Value.UnmarshalBytes(src) +} + +// FSetXattrReq is used to make FSetXattr requests. It has no response. +type FSetXattrReq struct { + FD FDID + Flags primitive.Uint32 + Name SizedString + Value SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *FSetXattrReq) SizeBytes() int { + return s.FD.SizeBytes() + s.Flags.SizeBytes() + s.Name.SizeBytes() + s.Value.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *FSetXattrReq) MarshalBytes(dst []byte) { + s.FD.MarshalUnsafe(dst) + dst = dst[s.FD.SizeBytes():] + s.Flags.MarshalUnsafe(dst) + dst = dst[s.Flags.SizeBytes():] + s.Name.MarshalBytes(dst) + dst = dst[s.Name.SizeBytes():] + s.Value.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *FSetXattrReq) UnmarshalBytes(src []byte) { + s.FD.UnmarshalUnsafe(src) + src = src[s.FD.SizeBytes():] + s.Flags.UnmarshalUnsafe(src) + src = src[s.Flags.SizeBytes():] + s.Name.UnmarshalBytes(src) + src = src[s.Name.SizeBytes():] + s.Value.UnmarshalBytes(src) +} + +// FRemoveXattrReq is used to make FRemoveXattr requests. It has no response. +type FRemoveXattrReq struct { + FD FDID + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *FRemoveXattrReq) SizeBytes() int { + return r.FD.SizeBytes() + r.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *FRemoveXattrReq) MarshalBytes(dst []byte) { + r.FD.MarshalUnsafe(dst) + dst = dst[r.FD.SizeBytes():] + r.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) { + r.FD.UnmarshalUnsafe(src) + src = src[r.FD.SizeBytes():] + r.Name.UnmarshalBytes(src) +} + +// FListXattrReq is used to make FListXattr requests. +// +// +marshal +type FListXattrReq struct { + FD FDID + _ uint32 + Size uint64 +} + +// FListXattrResp is used to respond to FListXattr requests. +type FListXattrResp struct { + Xattrs StringArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *FListXattrResp) SizeBytes() int { + return l.Xattrs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *FListXattrResp) MarshalBytes(dst []byte) { + l.Xattrs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *FListXattrResp) UnmarshalBytes(src []byte) { + l.Xattrs.UnmarshalBytes(src) +} diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go new file mode 100644 index 000000000..3868dfa08 --- /dev/null +++ b/pkg/lisafs/sample_message.go @@ -0,0 +1,110 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math/rand" + + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// MsgSimple is a sample packed struct which can be used to test message passing. +// +// +marshal slice:Msg1Slice +type MsgSimple struct { + A uint16 + B uint16 + C uint32 + D uint64 +} + +// Randomize randomizes the contents of m. +func (m *MsgSimple) Randomize() { + m.A = uint16(rand.Uint32()) + m.B = uint16(rand.Uint32()) + m.C = rand.Uint32() + m.D = rand.Uint64() +} + +// MsgDynamic is a sample dynamic struct which can be used to test message passing. +// +// +marshal dynamic +type MsgDynamic struct { + N primitive.Uint32 + Arr []MsgSimple +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MsgDynamic) SizeBytes() int { + return m.N.SizeBytes() + + (int(m.N) * (*MsgSimple)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MsgDynamic) MarshalBytes(dst []byte) { + m.N.MarshalUnsafe(dst) + dst = dst[m.N.SizeBytes():] + MarshalUnsafeMsg1Slice(m.Arr, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MsgDynamic) UnmarshalBytes(src []byte) { + m.N.UnmarshalUnsafe(src) + src = src[m.N.SizeBytes():] + m.Arr = make([]MsgSimple, m.N) + UnmarshalUnsafeMsg1Slice(m.Arr, src) +} + +// Randomize randomizes the contents of m. +func (m *MsgDynamic) Randomize(arrLen int) { + m.N = primitive.Uint32(arrLen) + m.Arr = make([]MsgSimple, arrLen) + for i := 0; i < arrLen; i++ { + m.Arr[i].Randomize() + } +} + +// P9Version mimics p9.TVersion and p9.Rversion. +// +// +marshal dynamic +type P9Version struct { + MSize primitive.Uint32 + Version string +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (v *P9Version) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (v *P9Version) MarshalBytes(dst []byte) { + v.MSize.MarshalUnsafe(dst) + dst = dst[v.MSize.SizeBytes():] + versionLen := primitive.Uint16(len(v.Version)) + versionLen.MarshalUnsafe(dst) + dst = dst[versionLen.SizeBytes():] + copy(dst, v.Version) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (v *P9Version) UnmarshalBytes(src []byte) { + v.MSize.UnmarshalUnsafe(src) + src = src[v.MSize.SizeBytes():] + var versionLen primitive.Uint16 + versionLen.UnmarshalUnsafe(src) + src = src[versionLen.SizeBytes():] + v.Version = string(src[:versionLen]) +} diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go new file mode 100644 index 000000000..7515355ec --- /dev/null +++ b/pkg/lisafs/server.go @@ -0,0 +1,113 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// Server serves a filesystem tree. Multiple connections on different mount +// points can be started on a server. The server provides utilities to safely +// modify the filesystem tree across its connections (mount points). Note that +// it does not support synchronizing filesystem tree mutations across other +// servers serving the same filesystem subtree. Server also manages the +// lifecycle of all connections. +type Server struct { + // connWg counts the number of active connections being tracked. + connWg sync.WaitGroup + + // RenameMu synchronizes rename operations within this filesystem tree. + RenameMu sync.RWMutex + + // handlers is a list of RPC handlers which can be indexed by the handler's + // corresponding MID. + handlers []RPCHandler + + // mountPoints keeps track of all the mount points this server serves. + mpMu sync.RWMutex + mountPoints []*ControlFD + + // impl is the server implementation which embeds this server. + impl ServerImpl +} + +// Init must be called before first use of server. +func (s *Server) Init(impl ServerImpl) { + s.impl = impl + s.handlers = handlers[:] +} + +// InitTestOnly is the same as Init except that it allows to swap out the +// underlying handlers with something custom. This is for test only. +func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) { + s.impl = impl + s.handlers = handlers +} + +// WithRenameReadLock invokes fn with the server's rename mutex locked for +// reading. This ensures that no rename operations occur concurrently. +func (s *Server) WithRenameReadLock(fn func() error) error { + s.RenameMu.RLock() + err := fn() + s.RenameMu.RUnlock() + return err +} + +// StartConnection starts the connection on a separate goroutine and tracks it. +func (s *Server) StartConnection(c *Connection) { + s.connWg.Add(1) + go func() { + c.Run() + s.connWg.Done() + }() +} + +// Wait waits for all connections started via StartConnection() to terminate. +func (s *Server) Wait() { + s.connWg.Wait() +} + +func (s *Server) addMountPoint(root *ControlFD) { + s.mpMu.Lock() + defer s.mpMu.Unlock() + s.mountPoints = append(s.mountPoints, root) +} + +func (s *Server) forEachMountPoint(fn func(root *ControlFD)) { + s.mpMu.RLock() + defer s.mpMu.RUnlock() + for _, mp := range s.mountPoints { + fn(mp) + } +} + +// ServerImpl contains the implementation details for a Server. +// Implementations of ServerImpl should contain their associated Server by +// value as their first field. +type ServerImpl interface { + // Mount is called when a Mount RPC is made. It mounts the connection at + // mountPath. + // + // Precondition: mountPath == path.Clean(mountPath). + Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error) + + // SupportedMessages returns a list of messages that the server + // implementation supports. + SupportedMessages() []MID + + // MaxMessageSize is the maximum payload length (in bytes) that can be sent + // to this server implementation. + MaxMessageSize() uint32 +} diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go new file mode 100644 index 000000000..88210242f --- /dev/null +++ b/pkg/lisafs/sock.go @@ -0,0 +1,208 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +var ( + sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes()) +) + +// sockHeader is the header present in front of each message received on a UDS. +// +// +marshal +type sockHeader struct { + payloadLen uint32 + message MID + _ uint16 // Need to make struct packed. +} + +// sockCommunicator implements Communicator. This is not thread safe. +type sockCommunicator struct { + fdTracker + sock *unet.Socket + buf []byte +} + +var _ Communicator = (*sockCommunicator)(nil) + +func newSockComm(sock *unet.Socket) *sockCommunicator { + return &sockCommunicator{ + sock: sock, + buf: make([]byte, sockHeaderLen), + } +} + +func (s *sockCommunicator) FD() int { + return s.sock.FD() +} + +func (s *sockCommunicator) destroy() { + s.sock.Close() +} + +func (s *sockCommunicator) shutdown() { + if err := s.sock.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) + } +} + +func (s *sockCommunicator) resizeBuf(size uint32) { + if cap(s.buf) < int(size) { + s.buf = s.buf[:cap(s.buf)] + s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) + } else { + s.buf = s.buf[:size] + } +} + +// PayloadBuf implements Communicator.PayloadBuf. +func (s *sockCommunicator) PayloadBuf(size uint32) []byte { + s.resizeBuf(sockHeaderLen + size) + return s.buf[sockHeaderLen : sockHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { + return 0, 0, err + } + + return s.rcvMsg(wantFDs) +} + +// sndPrepopulatedMsg assumes that s.buf has already been populated with +// `payloadLen` bytes of data. +func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { + header := sockHeader{payloadLen: payloadLen, message: m} + header.MarshalUnsafe(s.buf) + dataLen := sockHeaderLen + payloadLen + return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) +} + +// writeTo writes the passed iovec to the UDS and donates any passed FDs. +func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { + w := sock.Writer(true) + if len(fds) > 0 { + w.PackFDs(fds...) + } + + fdsUnpacked := false + for n := 0; n < dataLen; { + cur, err := w.WriteVec(iovec) + if err != nil { + return err + } + n += cur + + // Fast common path. + if n >= dataLen { + break + } + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(iovec[0]) <= cur-consumed { + consumed += len(iovec[0]) + iovec = iovec[1:] + } else { + iovec[0] = iovec[0][cur-consumed:] + break + } + } + + if n > 0 && !fdsUnpacked { + // Don't resend any control message. + fdsUnpacked = true + w.UnpackFDs() + } + } + return nil +} + +// rcvMsg reads the message header and payload from the UDS. It also populates +// fds with any donated FDs. +func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { + fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) + if err != nil { + return 0, 0, err + } + for _, fd := range fds { + s.TrackFD(fd) + } + + var header sockHeader + header.UnmarshalUnsafe(s.buf) + + // No payload? We are done. + if header.payloadLen == 0 { + return header.message, 0, nil + } + + if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { + return 0, 0, err + } + + return header.message, header.payloadLen, nil +} + +// readFrom fills the passed buffer with data from the socket. It also returns +// any donated FDs. +func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { + r := sock.Reader(true) + r.EnableFDs(int(wantFDs)) + + var ( + fds []int + fdInit bool + ) + n := len(buf) + for got := 0; got < n; { + cur, err := r.ReadVec([][]byte{buf[got:]}) + + // Ignore EOF if cur > 0. + if err != nil && (err != io.EOF || cur == 0) { + r.CloseFDs() + return nil, err + } + + if !fdInit && cur > 0 { + fds, err = r.ExtractFDs() + if err != nil { + return nil, err + } + + fdInit = true + r.EnableFDs(0) + } + + got += cur + } + return fds, nil +} + +func closeFDs(fds []int) { + for _, fd := range fds { + if fd >= 0 { + unix.Close(fd) + } + } +} diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go new file mode 100644 index 000000000..387f4b7a8 --- /dev/null +++ b/pkg/lisafs/sock_test.go @@ -0,0 +1,217 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) { + sock1, sock2, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer sock1.Close() + defer sock2.Close() + + var testWg sync.WaitGroup + testWg.Add(2) + + go func() { + fun1(newSockComm(sock1)) + testWg.Done() + }() + + go func() { + fun2(newSockComm(sock2)) + testWg.Done() + }() + + testWg.Wait() +} + +func TestReadWrite(t *testing.T) { + // Create random data to send. + n := 10000 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + runSocketTest(t, func(comm *sockCommunicator) { + // Scatter that data into two parts using Iovecs while sending. + mid := n / 2 + if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + if _, err := readFrom(comm.sock, gotData, 0); err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + + // Make sure we got the right data. + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + }) +} + +func TestFDDonation(t *testing.T) { + n := 10 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + // Try donating FDs to these files. + path1 := "/dev/null" + path2 := "/dev" + path3 := "/dev/random" + + runSocketTest(t, func(comm *sockCommunicator) { + devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0) + defer unix.Close(devNullFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path1, err) + } + devFD, err := unix.Open(path2, unix.O_RDONLY, 0) + defer unix.Close(devFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0) + defer unix.Close(devRandomFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + fds, err := readFrom(comm.sock, gotData, 3) + if err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + defer closeFDs(fds[:]) + + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + + if len(fds) != 3 { + t.Fatalf("wanted 3 FD, got %d", len(fds)) + } + + // Check that the FDs actually point to the correct file. + compareFDWithFile(t, fds[0], path1) + compareFDWithFile(t, fds[1], path2) + compareFDWithFile(t, fds[2], path3) + }) +} + +func compareFDWithFile(t *testing.T, fd int, path string) { + var want unix.Stat_t + if err := unix.Stat(path, &want); err != nil { + t.Fatalf("stat(%s) failed: %v", path, err) + } + + var got unix.Stat_t + if err := unix.Fstat(fd, &got); err != nil { + t.Fatalf("fstat on donated FD failed: %v", err) + } + + if got.Ino != want.Ino || got.Dev != want.Dev { + t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got) + } +} + +func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error { + var payloadLen uint32 + if msg != nil { + payloadLen = uint32(msg.SizeBytes()) + msg.MarshalUnsafe(comm.PayloadBuf(payloadLen)) + } + return comm.sndPrepopulatedMsg(m, payloadLen, nil) +} + +func TestSndRcvMessage(t *testing.T) { + req := &MsgSimple{} + req.Randomize() + reqM := MID(1) + + // Create a massive random response. + var resp MsgDynamic + resp.Randomize(100) + respM := MID(2) + + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, req); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, &resp) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, req) + if err := testSndMsg(comm, respM, &resp); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func TestSndRcvMessageNoPayload(t *testing.T) { + reqM := MID(1) + respM := MID(2) + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, nil) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, nil) + if err := testSndMsg(comm, respM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) { + gotM, payloadLen, err := comm.rcvMsg(0) + if err != nil { + t.Fatalf("readMessageFrom failed: %v", err) + } + if gotM != wantM { + t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM) + } + if wantMsg == nil { + if payloadLen != 0 { + t.Errorf("no payload expect but got %d bytes", payloadLen) + } + } else { + gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable) + gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + if !reflect.DeepEqual(wantMsg, gotMsg) { + t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg) + } + } +} diff --git a/pkg/lisafs/testsuite/BUILD b/pkg/lisafs/testsuite/BUILD new file mode 100644 index 000000000..b4a542b3a --- /dev/null +++ b/pkg/lisafs/testsuite/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "testsuite", + testonly = True, + srcs = ["testsuite.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/lisafs", + "//pkg/unet", + "@com_github_syndtr_gocapability//capability:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go new file mode 100644 index 000000000..476ff76a5 --- /dev/null +++ b/pkg/lisafs/testsuite/testsuite.go @@ -0,0 +1,637 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testsuite provides a integration testing suite for lisafs. +// These tests are intended for servers serving the local filesystem. +package testsuite + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "testing" + "time" + + "github.com/syndtr/gocapability/capability" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/unet" +) + +// Tester is the client code using this test suite. This interface abstracts +// away all the caller specific details. +type Tester interface { + // NewServer returns a new instance of the tester server. + NewServer(t *testing.T) *lisafs.Server + + // LinkSupported returns true if the backing server supports LinkAt. + LinkSupported() bool + + // SetUserGroupIDSupported returns true if the backing server supports + // changing UID/GID for files. + SetUserGroupIDSupported() bool +} + +// RunAllLocalFSTests runs all local FS tests as subtests. +func RunAllLocalFSTests(t *testing.T, tester Tester) { + for name, testFn := range localFSTests { + t.Run(name, func(t *testing.T) { + runServerClient(t, tester, testFn) + }) + } +} + +type testFunc func(context.Context, *testing.T, Tester, lisafs.ClientFD) + +var localFSTests map[string]testFunc = map[string]testFunc{ + "Stat": testStat, + "RegularFileIO": testRegularFileIO, + "RegularFileOpen": testRegularFileOpen, + "SetStat": testSetStat, + "Allocate": testAllocate, + "StatFS": testStatFS, + "Unlink": testUnlink, + "Symlink": testSymlink, + "HardLink": testHardLink, + "Walk": testWalk, + "Rename": testRename, + "Mknod": testMknod, + "Getdents": testGetdents, +} + +func runServerClient(t *testing.T, tester Tester, testFn testFunc) { + mountPath, err := ioutil.TempDir(os.Getenv("TEST_TMPDIR"), "") + if err != nil { + t.Fatalf("creation of temporary mountpoint failed: %v", err) + } + defer os.RemoveAll(mountPath) + + // fsgofer should run with a umask of 0, because we want to preserve file + // modes exactly for testing purposes. + unix.Umask(0) + + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + server := tester.NewServer(t) + conn, err := server.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + server.StartConnection(conn) + + c, root, err := lisafs.NewClient(clientSocket, mountPath) + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + if !root.ControlFD.Ok() { + t.Fatalf("root control FD is not valid") + } + rootFile := c.NewFD(root.ControlFD) + + ctx := context.Background() + testFn(ctx, t, tester, rootFile) + closeFD(ctx, t, rootFile) + + c.Close() // This should trigger client and server shutdown. + server.Wait() +} + +func closeFD(ctx context.Context, t testing.TB, fdLisa lisafs.ClientFD) { + if err := fdLisa.Close(ctx); err != nil { + t.Errorf("failed to close FD: %v", err) + } +} + +func statTo(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, stat *linux.Statx) { + if err := fdLisa.StatTo(ctx, stat); err != nil { + t.Fatalf("stat failed: %v", err) + } +} + +func openCreateFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx, lisafs.ClientFD, int) { + child, childFD, childHostFD, err := fdLisa.OpenCreateAt(ctx, name, unix.O_RDWR, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("OpenCreateAt failed: %v", err) + } + if childHostFD == -1 { + t.Error("no host FD donated") + } + client := fdLisa.Client() + return client.NewFD(child.ControlFD), child.Stat, fdLisa.Client().NewFD(childFD), childHostFD +} + +func openFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, flags uint32, isReg bool) (lisafs.ClientFD, int) { + newFD, hostFD, err := fdLisa.OpenAt(ctx, flags) + if err != nil { + t.Fatalf("OpenAt failed: %v", err) + } + if hostFD == -1 && isReg { + t.Error("no host FD donated") + } + return fdLisa.Client().NewFD(newFD), hostFD +} + +func unlinkFile(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, isDir bool) { + var flags uint32 + if isDir { + flags = unix.AT_REMOVEDIR + } + if err := dir.UnlinkAt(ctx, name, flags); err != nil { + t.Errorf("unlinking file %s failed: %v", name, err) + } +} + +func symlink(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name, target string) (lisafs.ClientFD, linux.Statx) { + linkIno, err := dir.SymlinkAt(ctx, name, target, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("symlink failed: %v", err) + } + return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat +} + +func link(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, target lisafs.ClientFD) (lisafs.ClientFD, linux.Statx) { + linkIno, err := dir.LinkAt(ctx, target.ID(), name) + if err != nil { + t.Fatalf("link failed: %v", err) + } + return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat +} + +func mkdir(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) { + childIno, err := dir.MkdirAt(ctx, name, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("mkdir failed: %v", err) + } + return dir.Client().NewFD(childIno.ControlFD), childIno.Stat +} + +func mknod(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) { + nodeIno, err := dir.MknodAt(ctx, name, unix.S_IFREG|0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()), 0, 0) + if err != nil { + t.Fatalf("mknod failed: %v", err) + } + return dir.Client().NewFD(nodeIno.ControlFD), nodeIno.Stat +} + +func walk(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []lisafs.Inode { + _, inodes, err := dir.WalkMultiple(ctx, names) + if err != nil { + t.Fatalf("walk failed while trying to walk components %+v: %v", names, err) + } + return inodes +} + +func walkStat(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []linux.Statx { + stats, err := dir.WalkStat(ctx, names) + if err != nil { + t.Fatalf("walk failed while trying to walk components %+v: %v", names, err) + } + return stats +} + +func writeFD(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, buf []byte) error { + count, err := fdLisa.Write(ctx, buf, off) + if err != nil { + return err + } + if int(count) != len(buf) { + t.Errorf("partial write: buf size = %d, written = %d", len(buf), count) + } + return nil +} + +func readFDAndCmp(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, want []byte) { + buf := make([]byte, len(want)) + n, err := fdLisa.Read(ctx, buf, off) + if err != nil { + t.Errorf("read failed: %v", err) + return + } + if int(n) != len(want) { + t.Errorf("partial read: buf size = %d, read = %d", len(want), n) + return + } + if bytes.Compare(buf, want) != 0 { + t.Errorf("bytes read differ from what was expected: want = %v, got = %v", want, buf) + } +} + +func allocateAndVerify(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, length uint64) { + if err := fdLisa.Allocate(ctx, 0, off, length); err != nil { + t.Fatalf("fallocate failed: %v", err) + } + + var stat linux.Statx + statTo(ctx, t, fdLisa, &stat) + if want := off + length; stat.Size != want { + t.Errorf("incorrect file size after allocate: expected %d, got %d", off+length, stat.Size) + } +} + +func cmpStatx(t *testing.T, want, got linux.Statx) { + if got.Mask&unix.STATX_MODE != 0 && want.Mask&unix.STATX_MODE != 0 { + if got.Mode != want.Mode { + t.Errorf("mode differs: want %d, got %d", want.Mode, got.Mode) + } + } + if got.Mask&unix.STATX_INO != 0 && want.Mask&unix.STATX_INO != 0 { + if got.Ino != want.Ino { + t.Errorf("inode number differs: want %d, got %d", want.Ino, got.Ino) + } + } + if got.Mask&unix.STATX_NLINK != 0 && want.Mask&unix.STATX_NLINK != 0 { + if got.Nlink != want.Nlink { + t.Errorf("nlink differs: want %d, got %d", want.Nlink, got.Nlink) + } + } + if got.Mask&unix.STATX_UID != 0 && want.Mask&unix.STATX_UID != 0 { + if got.UID != want.UID { + t.Errorf("UID differs: want %d, got %d", want.UID, got.UID) + } + } + if got.Mask&unix.STATX_GID != 0 && want.Mask&unix.STATX_GID != 0 { + if got.GID != want.GID { + t.Errorf("GID differs: want %d, got %d", want.GID, got.GID) + } + } + if got.Mask&unix.STATX_SIZE != 0 && want.Mask&unix.STATX_SIZE != 0 { + if got.Size != want.Size { + t.Errorf("size differs: want %d, got %d", want.Size, got.Size) + } + } + if got.Mask&unix.STATX_BLOCKS != 0 && want.Mask&unix.STATX_BLOCKS != 0 { + if got.Blocks != want.Blocks { + t.Errorf("blocks differs: want %d, got %d", want.Blocks, got.Blocks) + } + } + if got.Mask&unix.STATX_ATIME != 0 && want.Mask&unix.STATX_ATIME != 0 { + if got.Atime != want.Atime { + t.Errorf("atime differs: want %d, got %d", want.Atime, got.Atime) + } + } + if got.Mask&unix.STATX_MTIME != 0 && want.Mask&unix.STATX_MTIME != 0 { + if got.Mtime != want.Mtime { + t.Errorf("mtime differs: want %d, got %d", want.Mtime, got.Mtime) + } + } + if got.Mask&unix.STATX_CTIME != 0 && want.Mask&unix.STATX_CTIME != 0 { + if got.Ctime != want.Ctime { + t.Errorf("ctime differs: want %d, got %d", want.Ctime, got.Ctime) + } + } +} + +func hasCapability(c capability.Cap) bool { + caps, err := capability.NewPid2(os.Getpid()) + if err != nil { + return false + } + if err := caps.Load(); err != nil { + return false + } + return caps.Get(capability.EFFECTIVE, c) +} + +func testStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + var rootStat linux.Statx + if err := root.StatTo(ctx, &rootStat); err != nil { + t.Errorf("stat on the root dir failed: %v", err) + } + + if ftype := rootStat.Mode & unix.S_IFMT; ftype != unix.S_IFDIR { + t.Errorf("root inode is not a directory, file type = %d", ftype) + } +} + +func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + // Test Read/Write RPCs. + data := make([]byte, 100) + rand.Read(data) + if err := writeFD(ctx, t, fd, 0, data); err != nil { + t.Fatalf("write failed: %v", err) + } + readFDAndCmp(ctx, t, fd, 0, data) + readFDAndCmp(ctx, t, fd, 50, data[50:]) + + // Make sure the host FD is configured properly. + hostReadData := make([]byte, len(data)) + if n, err := unix.Pread(hostFD, hostReadData, 0); err != nil { + t.Errorf("host read failed: %v", err) + } else if n != len(hostReadData) { + t.Errorf("partial read: buf size = %d, read = %d", len(hostReadData), n) + } else if bytes.Compare(hostReadData, data) != 0 { + t.Errorf("bytes read differ from what was expected: want = %v, got = %v", data, hostReadData) + } + + // Test syncing the writable FD. + if err := fd.Sync(ctx); err != nil { + t.Errorf("syncing the FD failed: %v", err) + } +} + +func testRegularFileOpen(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + // Open a readonly FD and try writing to it to get an EBADF. + roFile, roHostFD := openFile(ctx, t, controlFile, unix.O_RDONLY, true /* isReg */) + defer closeFD(ctx, t, roFile) + defer unix.Close(roHostFD) + if err := writeFD(ctx, t, roFile, 0, []byte{1, 2, 3}); err != unix.EBADF { + t.Errorf("writing to read only FD should generate EBADF, but got %v", err) + } +} + +func testSetStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + now := time.Now() + wantStat := linux.Statx{ + Mask: unix.STATX_MODE | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_SIZE, + Mode: 0760, + UID: uint32(unix.Getuid()), + GID: uint32(unix.Getgid()), + Size: 50, + Atime: linux.NsecToStatxTimestamp(now.UnixNano()), + Mtime: linux.NsecToStatxTimestamp(now.UnixNano()), + } + if tester.SetUserGroupIDSupported() { + wantStat.Mask |= unix.STATX_UID | unix.STATX_GID + } + failureMask, failureErr, err := controlFile.SetStat(ctx, &wantStat) + if err != nil { + t.Fatalf("setstat failed: %v", err) + } + if failureMask != 0 { + t.Fatalf("some setstat operations failed: failureMask = %#b, failureErr = %v", failureMask, failureErr) + } + + // Verify that attributes were updated. + var gotStat linux.Statx + statTo(ctx, t, controlFile, &gotStat) + if gotStat.Mode&07777 != wantStat.Mode || + gotStat.Size != wantStat.Size || + gotStat.Atime.ToNsec() != wantStat.Atime.ToNsec() || + gotStat.Mtime.ToNsec() != wantStat.Mtime.ToNsec() || + (tester.SetUserGroupIDSupported() && (uint32(gotStat.UID) != wantStat.UID || uint32(gotStat.GID) != wantStat.GID)) { + t.Errorf("setStat did not update file correctly: setStat = %+v, stat = %+v", wantStat, gotStat) + } +} + +func testAllocate(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + allocateAndVerify(ctx, t, fd, 0, 40) + allocateAndVerify(ctx, t, fd, 20, 100) +} + +func testStatFS(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + var statFS lisafs.StatFS + if err := root.StatFSTo(ctx, &statFS); err != nil { + t.Errorf("statfs failed: %v", err) + } +} + +func testUnlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + unlinkFile(ctx, t, root, name, false /* isDir */) + if inodes := walk(ctx, t, root, []string{name}); len(inodes) > 0 { + t.Errorf("deleted file should not be generating inodes on walk: %+v", inodes) + } +} + +func testSymlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + target := "/tmp/some/path" + name := "symlinkFile" + link, linkStat := symlink(ctx, t, root, name, target) + defer closeFD(ctx, t, link) + + if linkStat.Mode&unix.S_IFMT != unix.S_IFLNK { + t.Errorf("stat return from symlink RPC indicates that the inode is not a symlink: mode = %d", linkStat.Mode) + } + + if gotTarget, err := link.ReadLinkAt(ctx); err != nil { + t.Fatalf("readlink failed: %v", err) + } else if gotTarget != target { + t.Errorf("readlink return incorrect target: expected %q, got %q", target, gotTarget) + } +} + +func testHardLink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + if !tester.LinkSupported() { + t.Skipf("server does not support LinkAt RPC") + } + if !hasCapability(capability.CAP_DAC_READ_SEARCH) { + t.Skipf("TestHardLink requires CAP_DAC_READ_SEARCH, running as %d", unix.Getuid()) + } + name := "tempFile" + controlFile, fileIno, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + link, linkStat := link(ctx, t, root, name, controlFile) + defer closeFD(ctx, t, link) + + if linkStat.Ino != fileIno.Ino { + t.Errorf("hard linked files have different inode numbers: %d %d", linkStat.Ino, fileIno.Ino) + } + if linkStat.DevMinor != fileIno.DevMinor { + t.Errorf("hard linked files have different minor device numbers: %d %d", linkStat.DevMinor, fileIno.DevMinor) + } + if linkStat.DevMajor != fileIno.DevMajor { + t.Errorf("hard linked files have different major device numbers: %d %d", linkStat.DevMajor, fileIno.DevMajor) + } +} + +func testWalk(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + // Create 10 nested directories. + n := 10 + curDir := root + + dirNames := make([]string, 0, n) + for i := 0; i < n; i++ { + name := fmt.Sprintf("tmpdir-%d", i) + childDir, _ := mkdir(ctx, t, curDir, name) + defer closeFD(ctx, t, childDir) + defer unlinkFile(ctx, t, curDir, name, true /* isDir */) + + curDir = childDir + dirNames = append(dirNames, name) + } + + // Walk all these directories. Add some junk at the end which should not be + // walked on. + dirNames = append(dirNames, []string{"a", "b", "c"}...) + inodes := walk(ctx, t, root, dirNames) + if len(inodes) != n { + t.Errorf("walk returned the incorrect number of inodes: wanted %d, got %d", n, len(inodes)) + } + + // Close all control FDs and collect stat results for all dirs including + // the root directory. + dirStats := make([]linux.Statx, 0, n+1) + var stat linux.Statx + statTo(ctx, t, root, &stat) + dirStats = append(dirStats, stat) + for _, inode := range inodes { + dirStats = append(dirStats, inode.Stat) + closeFD(ctx, t, root.Client().NewFD(inode.ControlFD)) + } + + // Test WalkStat which additonally returns Statx for root because the first + // path component is "". + dirNames = append([]string{""}, dirNames...) + gotStats := walkStat(ctx, t, root, dirNames) + if len(gotStats) != len(dirStats) { + t.Errorf("walkStat returned the incorrect number of statx: wanted %d, got %d", len(dirStats), len(gotStats)) + } else { + for i := range gotStats { + cmpStatx(t, dirStats[i], gotStats[i]) + } + } +} + +func testRename(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + tempFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, tempFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + tempDir, _ := mkdir(ctx, t, root, "tempDir") + defer closeFD(ctx, t, tempDir) + + // Move tempFile into tempDir. + if err := tempFile.RenameTo(ctx, tempDir.ID(), "movedFile"); err != nil { + t.Fatalf("rename failed: %v", err) + } + + inodes := walkStat(ctx, t, root, []string{"tempDir", "movedFile"}) + if len(inodes) != 2 { + t.Errorf("expected 2 files on walk but only found %d", len(inodes)) + } +} + +func testMknod(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "namedPipe" + pipeFile, pipeStat := mknod(ctx, t, root, name) + defer closeFD(ctx, t, pipeFile) + + var stat linux.Statx + statTo(ctx, t, pipeFile, &stat) + + if stat.Mode != pipeStat.Mode { + t.Errorf("mknod mode is incorrect: want %d, got %d", pipeStat.Mode, stat.Mode) + } + if stat.UID != pipeStat.UID { + t.Errorf("mknod UID is incorrect: want %d, got %d", pipeStat.UID, stat.UID) + } + if stat.GID != pipeStat.GID { + t.Errorf("mknod GID is incorrect: want %d, got %d", pipeStat.GID, stat.GID) + } +} + +func testGetdents(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + tempDir, _ := mkdir(ctx, t, root, "tempDir") + defer closeFD(ctx, t, tempDir) + defer unlinkFile(ctx, t, root, "tempDir", true /* isDir */) + + // Create 10 files in tempDir. + n := 10 + fileStats := make(map[string]linux.Statx) + for i := 0; i < n; i++ { + name := fmt.Sprintf("file-%d", i) + newFile, fileStat := mknod(ctx, t, tempDir, name) + defer closeFD(ctx, t, newFile) + defer unlinkFile(ctx, t, tempDir, name, false /* isDir */) + + fileStats[name] = fileStat + } + + // Use opened directory FD for getdents. + openDirFile, _ := openFile(ctx, t, tempDir, unix.O_RDONLY, false /* isReg */) + defer closeFD(ctx, t, openDirFile) + + dirents := make([]lisafs.Dirent64, 0, n) + for i := 0; i < n+2; i++ { + gotDirents, err := openDirFile.Getdents64(ctx, 40) + if err != nil { + t.Fatalf("getdents failed: %v", err) + } + if len(gotDirents) == 0 { + break + } + for _, dirent := range gotDirents { + if dirent.Name != "." && dirent.Name != ".." { + dirents = append(dirents, dirent) + } + } + } + + if len(dirents) != n { + t.Errorf("got incorrect number of dirents: wanted %d, got %d", n, len(dirents)) + } + for _, dirent := range dirents { + stat, ok := fileStats[string(dirent.Name)] + if !ok { + t.Errorf("received a dirent that was not created: %+v", dirent) + continue + } + + if dirent.Type != unix.DT_REG { + t.Errorf("dirent type of %s is incorrect: %d", dirent.Name, dirent.Type) + } + if uint64(dirent.Ino) != stat.Ino { + t.Errorf("dirent ino of %s is incorrect: want %d, got %d", dirent.Name, stat.Ino, dirent.Ino) + } + if uint32(dirent.DevMinor) != stat.DevMinor { + t.Errorf("dirent dev minor of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMinor, dirent.DevMinor) + } + if uint32(dirent.DevMajor) != stat.DevMajor { + t.Errorf("dirent dev major of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMajor, dirent.DevMajor) + } + } +} |