summaryrefslogtreecommitdiffhomepage
path: root/pkg/lisafs/client.go
diff options
context:
space:
mode:
authorAyush Ranjan <ayushranjan@google.com>2021-09-20 11:41:58 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-20 11:44:11 -0700
commitd139087b3f7be1fdc1af151e8548c4bbdc199875 (patch)
tree627c9a4aa9b1d97b4f92895e454541fb9173a8ed /pkg/lisafs/client.go
parent89a0011c100d778cd8a0cc5e1b14996461c66629 (diff)
[lisa] lisafs package POC.
This change mainly aims to define the semantics of communication for the LISAFS (LInux SAndbox Filesystem) protocol. This protocol aims to replace 9P and intends to bring some performance benefits with it. Some of the notable differences from the p9 package are: - Now the server implementations own the handlers. - As a result, there is no verbose interface like `p9.File` that all servers need to implement. Different implementations can extend their File implementations to varying degrees without imposing those extensions to other server implementations that might not have anything to do with those features. - If a server implementation adds a new RPC message, other implementations are not compelled to support it. I wrote a benchmark `BenchmarkSendRecv` in connection_test.go which competes with p9's `BenchmarkSendRecvChannel`. Running these on an AMD Milan machine shows that lisafs is **45%** faster. **With 9P** goos: linux goarch: amd64 pkg: gvisor/pkg/p9/p9 cpu: AMD EPYC 7B13 64-Core Processor BenchmarkSendRecvLegacy-256 82830 14053 ns/op 633 B/op 23 allocs/op BenchmarkSendRecvChannel-256 776971 1551 ns/op 184 B/op 6 allocs/op **With lisafs** goos: linux goarch: amd64 pkg: pkg/lisafs/connection_test cpu: AMD EPYC 7B13 64-Core Processor BenchmarkSendRecv-256 1399610 853.5 ns/op 48 B/op 2 allocs/op Fixes #5464 PiperOrigin-RevId: 397803163
Diffstat (limited to 'pkg/lisafs/client.go')
-rw-r--r--pkg/lisafs/client.go377
1 files changed, 377 insertions, 0 deletions
diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go
new file mode 100644
index 000000000..c99f8c73d
--- /dev/null
+++ b/pkg/lisafs/client.go
@@ -0,0 +1,377 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "math"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Client helps manage a connection to the lisafs server and pass messages
+// efficiently. There is a 1:1 mapping between a Connection and a Client.
+type Client struct {
+ // sockComm is the main socket by which this connections is established.
+ // Communication over the socket is synchronized by sockMu.
+ sockMu sync.Mutex
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels and availableChannels.
+ channelsMu sync.Mutex
+ // channels tracks all the channels.
+ channels []*channel
+ // availableChannels is a LIFO (stack) of channels available to be used.
+ availableChannels []*channel
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // watchdogWg only holds the watchdog goroutine.
+ watchdogWg sync.WaitGroup
+
+ // supported caches information about which messages are supported. It is
+ // indexed by MID. An MID is supported if supported[MID] is true.
+ supported []bool
+
+ // maxMessageSize is the maximum payload length (in bytes) that can be sent.
+ // It is initialized on Mount and is immutable.
+ maxMessageSize uint32
+}
+
+// NewClient creates a new client for communication with the server. It mounts
+// the server and creates channels for fast IPC. NewClient takes ownership over
+// the passed socket. On success, it returns the initialized client along with
+// the root Inode.
+func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) {
+ maxChans := maxChannels()
+ c := &Client{
+ sockComm: newSockComm(sock),
+ channels: make([]*channel, 0, maxChans),
+ availableChannels: make([]*channel, 0, maxChans),
+ maxMessageSize: 1 << 20, // 1 MB for now.
+ }
+
+ // Start a goroutine to check socket health. This goroutine is also
+ // responsible for client cleanup.
+ c.watchdogWg.Add(1)
+ go c.watchdog()
+
+ // Clean everything up if anything fails.
+ cu := cleanup.Make(func() {
+ c.Close()
+ })
+ defer cu.Clean()
+
+ // Mount the server first. Assume Mount is supported so that we can make the
+ // Mount RPC below.
+ c.supported = make([]bool, Mount+1)
+ c.supported[Mount] = true
+ mountMsg := MountReq{
+ MountPath: SizedString(mountPath),
+ }
+ var mountResp MountResp
+ if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil {
+ return nil, nil, err
+ }
+
+ // Initialize client.
+ c.maxMessageSize = uint32(mountResp.MaxMessageSize)
+ var maxSuppMID MID
+ for _, suppMID := range mountResp.SupportedMs {
+ if suppMID > maxSuppMID {
+ maxSuppMID = suppMID
+ }
+ }
+ c.supported = make([]bool, maxSuppMID+1)
+ for _, suppMID := range mountResp.SupportedMs {
+ c.supported[suppMID] = true
+ }
+
+ // Create channels parallely so that channels can be used to create more
+ // channels and costly initialization like flipcall.Endpoint.Connect can
+ // proceed parallely.
+ var channelsWg sync.WaitGroup
+ channelErrs := make([]error, maxChans)
+ for i := 0; i < maxChans; i++ {
+ channelsWg.Add(1)
+ curChanID := i
+ go func() {
+ defer channelsWg.Done()
+ ch, err := c.createChannel()
+ if err != nil {
+ log.Warningf("channel creation failed: %v", err)
+ channelErrs[curChanID] = err
+ return
+ }
+ c.channelsMu.Lock()
+ c.channels = append(c.channels, ch)
+ c.availableChannels = append(c.availableChannels, ch)
+ c.channelsMu.Unlock()
+ }()
+ }
+ channelsWg.Wait()
+
+ for _, channelErr := range channelErrs {
+ // Return the first non-nil channel creation error.
+ if channelErr != nil {
+ return nil, nil, channelErr
+ }
+ }
+ cu.Release()
+
+ return c, &mountResp.Root, nil
+}
+
+func (c *Client) watchdog() {
+ defer c.watchdogWg.Done()
+
+ events := []unix.PollFd{
+ {
+ Fd: int32(c.sockComm.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("lisafs.Client.watch(): %v", err)
+ } else if n != 1 {
+ log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Shutdown all active channels and wait for them to complete.
+ c.shutdownActiveChans()
+ c.activeWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ c.channelsMu.Unlock()
+
+ // Close main socket.
+ c.sockComm.destroy()
+}
+
+func (c *Client) shutdownActiveChans() {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ availableChans := make(map[*channel]bool)
+ for _, ch := range c.availableChannels {
+ availableChans[ch] = true
+ }
+ for _, ch := range c.channels {
+ // A channel that is not available is active.
+ if _, ok := availableChans[ch]; !ok {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.shutdown()
+ }
+ }
+
+ // Prevent channels from becoming available and serving new requests.
+ c.availableChannels = nil
+}
+
+// Close shuts down the main socket and waits for the watchdog to clean up.
+func (c *Client) Close() {
+ // This shutdown has no effect if the watchdog has already fired and closed
+ // the main socket.
+ c.sockComm.shutdown()
+ c.watchdogWg.Wait()
+}
+
+func (c *Client) createChannel() (*channel, error) {
+ var chanResp ChannelResp
+ var fds [2]int
+ if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil {
+ return nil, err
+ }
+ if fds[0] < 0 || fds[1] < 0 {
+ closeFDs(fds[:])
+ return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds)
+ }
+
+ // Lets create the channel.
+ defer closeFDs(fds[:1]) // The data FD is not needed after this.
+ desc := flipcall.PacketWindowDescriptor{
+ FD: fds[0],
+ Offset: chanResp.dataOffset,
+ Length: int(chanResp.dataLength),
+ }
+
+ ch := &channel{}
+ if err := ch.data.Init(flipcall.ClientSide, desc); err != nil {
+ closeFDs(fds[1:])
+ return nil, err
+ }
+ ch.fdChan.Init(fds[1]) // fdChan now owns this FD.
+
+ // Only a connected channel is usable.
+ if err := ch.data.Connect(); err != nil {
+ ch.destroy()
+ return nil, err
+ }
+ return ch, nil
+}
+
+// IsSupported returns true if this connection supports the passed message.
+func (c *Client) IsSupported(m MID) bool {
+ return int(m) < len(c.supported) && c.supported[m]
+}
+
+// SndRcvMessage invokes reqMarshal to marshal the request onto the payload
+// buffer, wakes up the server to process the request, waits for the response
+// and invokes respUnmarshal with the response payload. respFDs is populated
+// with the received FDs, extra fields are set to -1.
+//
+// Note that the function arguments intentionally accept marshal.Marshallable
+// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of
+// directly accepting the marshal.Marshallable interface. Even though just
+// accepting marshal.Marshallable is cleaner, it leads to a heap allocation
+// (even if that interface variable itself does not escape). In other words,
+// implicit conversion to an interface leads to an allocation.
+//
+// Precondition: reqMarshal and respUnmarshal must be non-nil.
+func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error {
+ if !c.IsSupported(m) {
+ return unix.EOPNOTSUPP
+ }
+ if payloadLen > c.maxMessageSize {
+ log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize)
+ return unix.EIO
+ }
+ wantFDs := len(respFDs)
+ if wantFDs > math.MaxUint8 {
+ log.Warningf("want too many FDs: %d", wantFDs)
+ return unix.EINVAL
+ }
+
+ // Acquire a communicator.
+ comm := c.acquireCommunicator()
+ defer c.releaseCommunicator(comm)
+
+ // Marshal the request into comm's payload buffer and make the RPC.
+ reqMarshal(comm.PayloadBuf(payloadLen))
+ respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs))
+
+ // Handle FD donation.
+ rcvFDs := comm.ReleaseFDs()
+ if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 {
+ // releasedFDs is memory owned by comm which can not be returned to caller.
+ // Copy it into the caller's buffer.
+ numFDCopied := copy(respFDs, rcvFDs)
+ if numFDCopied < numRcvFDs {
+ log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs)
+ closeFDs(rcvFDs[numFDCopied:])
+ }
+ if numFDCopied < wantFDs {
+ for i := numFDCopied; i < wantFDs; i++ {
+ respFDs[i] = -1
+ }
+ }
+ }
+
+ // Error cases.
+ if err != nil {
+ closeFDs(respFDs)
+ return err
+ }
+ if respM == Error {
+ closeFDs(respFDs)
+ var resp ErrorResp
+ resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen))
+ return unix.Errno(resp.errno)
+ }
+ if respM != m {
+ closeFDs(respFDs)
+ log.Warningf("sent %d message but got %d in response", m, respM)
+ return unix.EINVAL
+ }
+
+ // Success. The payload must be unmarshalled *before* comm is released.
+ respUnmarshal(comm.PayloadBuf(respPayloadLen))
+ return nil
+}
+
+// Postcondition: releaseCommunicator() must be called on the returned value.
+func (c *Client) acquireCommunicator() Communicator {
+ // Prefer using channel over socket because:
+ // - Channel uses a shared memory region for passing messages. IO from shared
+ // memory is faster and does not involve making a syscall.
+ // - No intermediate buffer allocation needed. With a channel, the message
+ // can be directly pasted into the shared memory region.
+ if ch := c.getChannel(); ch != nil {
+ return ch
+ }
+
+ c.sockMu.Lock()
+ return c.sockComm
+}
+
+// Precondition: comm must have been acquired via acquireCommunicator().
+func (c *Client) releaseCommunicator(comm Communicator) {
+ switch t := comm.(type) {
+ case *sockCommunicator:
+ c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator().
+ case *channel:
+ c.releaseChannel(t)
+ default:
+ panic(fmt.Sprintf("unknown communicator type %T", t))
+ }
+}
+
+// getChannel pops a channel from the available channels stack. The caller must
+// release the channel after use.
+func (c *Client) getChannel() *channel {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ if len(c.availableChannels) == 0 {
+ return nil
+ }
+
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ c.activeWg.Add(1)
+ return ch
+}
+
+// releaseChannel pushes the passed channel onto the available channel stack if
+// reinsert is true.
+func (c *Client) releaseChannel(ch *channel) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ // If availableChannels is nil, then watchdog has fired and the client is
+ // shutting down. So don't make this channel available again.
+ if !ch.dead && c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.activeWg.Done()
+}