summaryrefslogtreecommitdiffhomepage
path: root/pkg/lisafs/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lisafs/connection.go')
-rw-r--r--pkg/lisafs/connection.go304
1 files changed, 304 insertions, 0 deletions
diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go
new file mode 100644
index 000000000..8dba4805f
--- /dev/null
+++ b/pkg/lisafs/connection.go
@@ -0,0 +1,304 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Connection represents a connection between a mount point in the client and a
+// mount point in the server. It is owned by the server on which it was started
+// and facilitates communication with the client mount.
+//
+// Each connection is set up using a unix domain socket. One end is owned by
+// the server and the other end is owned by the client. The connection may
+// spawn additional comunicational channels for the same mount for increased
+// RPC concurrency.
+type Connection struct {
+ // server is the server on which this connection was created. It is immutably
+ // associated with it for its entire lifetime.
+ server *Server
+
+ // mounted is a one way flag indicating whether this connection has been
+ // mounted correctly and the server is initialized properly.
+ mounted bool
+
+ // readonly indicates if this connection is readonly. All write operations
+ // will fail with EROFS.
+ readonly bool
+
+ // sockComm is the main socket by which this connections is established.
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+ // channels keeps track of all open channels.
+ channels []*channel
+
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // reqGate counts requests that are still being handled.
+ reqGate sync.Gate
+
+ // channelAlloc is used to allocate memory for channels.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ fdsMu sync.RWMutex
+ // fds keeps tracks of open FDs on this server. It is protected by fdsMu.
+ fds map[FDID]genericFD
+ // nextFDID is the next available FDID. It is protected by fdsMu.
+ nextFDID FDID
+}
+
+// CreateConnection initializes a new connection - creating a server if
+// required. The connection must be started separately.
+func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) {
+ c := &Connection{
+ sockComm: newSockComm(sock),
+ server: s,
+ readonly: readonly,
+ channels: make([]*channel, 0, maxChannels()),
+ fds: make(map[FDID]genericFD),
+ nextFDID: InvalidFDID + 1,
+ }
+
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return nil, err
+ }
+ c.channelAlloc = alloc
+ return c, nil
+}
+
+// Server returns the associated server.
+func (c *Connection) Server() *Server {
+ return c.server
+}
+
+// ServerImpl returns the associated server implementation.
+func (c *Connection) ServerImpl() ServerImpl {
+ return c.server.impl
+}
+
+// Run defines the lifecycle of a connection.
+func (c *Connection) Run() {
+ defer c.close()
+
+ // Start handling requests on this connection.
+ for {
+ m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */)
+ if err != nil {
+ log.Debugf("sock read failed, closing connection: %v", err)
+ return
+ }
+
+ respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen)
+ err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs)
+ closeFDs(respFDs)
+ if err != nil {
+ log.Debugf("sock write failed, closing connection: %v", err)
+ return
+ }
+ }
+}
+
+// service starts servicing the passed channel until the channel is shutdown.
+// This is a blocking method and hence must be called in a separate goroutine.
+func (c *Connection) service(ch *channel) error {
+ rcvDataLen, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rcvDataLen > 0 {
+ m, payloadLen, err := ch.rcvMsg(rcvDataLen)
+ if err != nil {
+ return err
+ }
+ respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen)
+ numFDs := ch.sendFDs(respFDs)
+ closeFDs(respFDs)
+
+ ch.marshalHdr(respM, numFDs)
+ rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) {
+ resp := &ErrorResp{errno: uint32(err)}
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return Error, respLen, nil
+}
+
+func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) {
+ if !c.reqGate.Enter() {
+ // c.close() has been called; the connection is shutting down.
+ return c.respondError(comm, unix.ECONNRESET)
+ }
+ defer c.reqGate.Leave()
+
+ if !c.mounted && m != Mount {
+ log.Warningf("connection must first be mounted")
+ return c.respondError(comm, unix.EINVAL)
+ }
+
+ // Check if the message is supported for forward compatibility.
+ if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil {
+ log.Warningf("received request which is not supported by the server, MID = %d", m)
+ return c.respondError(comm, unix.EOPNOTSUPP)
+ }
+
+ // Try handling the request.
+ respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen)
+ fds := comm.ReleaseFDs()
+ if err != nil {
+ closeFDs(fds)
+ return c.respondError(comm, p9.ExtractErrno(err))
+ }
+
+ return m, respPayloadLen, fds
+}
+
+func (c *Connection) close() {
+ // Wait for completion of all inflight requests. This is mostly so that if
+ // a request is stuck, the sandbox supervisor has the opportunity to kill
+ // us with SIGABRT to get a stack dump of the offending handler.
+ c.reqGate.Close()
+
+ // Shutdown and clean up channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.shutdown()
+ }
+ c.activeWg.Wait()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ // This is to prevent additional channels from being created.
+ c.channels = nil
+ c.channelsMu.Unlock()
+
+ // Free the channel memory.
+ if c.channelAlloc != nil {
+ c.channelAlloc.Destroy()
+ }
+
+ // Ensure the connection is closed.
+ c.sockComm.destroy()
+
+ // Cleanup all FDs.
+ c.fdsMu.Lock()
+ for fdid := range c.fds {
+ fd := c.removeFDLocked(fdid)
+ fd.DecRef(nil) // Drop the ref held by c.
+ }
+ c.fdsMu.Unlock()
+}
+
+// The caller gains a ref on the FD on success.
+func (c *Connection) lookupFD(id FDID) (genericFD, error) {
+ c.fdsMu.RLock()
+ defer c.fdsMu.RUnlock()
+
+ fd, ok := c.fds[id]
+ if !ok {
+ return nil, unix.EBADF
+ }
+ fd.IncRef()
+ return fd, nil
+}
+
+// LookupControlFD retrieves the control FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ cfd, ok := fd.(*ControlFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return cfd, nil
+}
+
+// LookupOpenFD retrieves the open FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ ofd, ok := fd.(*OpenFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return ofd, nil
+}
+
+// insertFD inserts the passed fd into the internal datastructure to track FDs.
+// The caller must hold a ref on fd which is transferred to the connection.
+func (c *Connection) insertFD(fd genericFD) FDID {
+ c.fdsMu.Lock()
+ defer c.fdsMu.Unlock()
+
+ res := c.nextFDID
+ c.nextFDID++
+ if c.nextFDID < res {
+ panic("ran out of FDIDs")
+ }
+ c.fds[res] = fd
+ return res
+}
+
+// RemoveFD makes c stop tracking the passed FDID and drops its ref on it.
+func (c *Connection) RemoveFD(id FDID) {
+ c.fdsMu.Lock()
+ fd := c.removeFDLocked(id)
+ c.fdsMu.Unlock()
+ if fd != nil {
+ // Drop the ref held by c. This can take arbitrarily long. So do not hold
+ // c.fdsMu while calling it.
+ fd.DecRef(nil)
+ }
+}
+
+// removeFDLocked makes c stop tracking the passed FDID. Note that the caller
+// must drop ref on the returned fd (preferably without holding c.fdsMu).
+//
+// Precondition: c.fdsMu is locked.
+func (c *Connection) removeFDLocked(id FDID) genericFD {
+ fd := c.fds[id]
+ if fd == nil {
+ log.Warningf("removeFDLocked called on non-existent FDID %d", id)
+ return nil
+ }
+ delete(c.fds, id)
+ return fd
+}