summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/p9/client.go')
-rw-r--r--pkg/p9/client.go575
1 files changed, 575 insertions, 0 deletions
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
new file mode 100644
index 000000000..71e944c30
--- /dev/null
+++ b/pkg/p9/client.go
@@ -0,0 +1,575 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// ErrOutOfTags indicates no tags are available.
+var ErrOutOfTags = errors.New("out of tags -- messages lost?")
+
+// ErrOutOfFIDs indicates no more FIDs are available.
+var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?")
+
+// ErrUnexpectedTag indicates a response with an unexpected tag was received.
+var ErrUnexpectedTag = errors.New("unexpected tag in response")
+
+// ErrVersionsExhausted indicates that all versions to negotiate have been exhausted.
+var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate")
+
+// ErrBadVersionString indicates that the version string is malformed or unsupported.
+var ErrBadVersionString = errors.New("bad version string")
+
+// ErrBadResponse indicates the response didn't match the request.
+type ErrBadResponse struct {
+ Got MsgType
+ Want MsgType
+}
+
+// Error returns a highly descriptive error.
+func (e *ErrBadResponse) Error() string {
+ return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want)
+}
+
+// response is the asynchronous return from recv.
+//
+// This is used in the pending map below.
+type response struct {
+ r message
+ done chan error
+}
+
+var responsePool = sync.Pool{
+ New: func() interface{} {
+ return &response{
+ done: make(chan error, 1),
+ }
+ },
+}
+
+// Client is at least a 9P2000.L client.
+type Client struct {
+ // socket is the connected socket.
+ socket *unet.Socket
+
+ // tagPool is the collection of available tags.
+ tagPool pool.Pool
+
+ // fidPool is the collection of available fids.
+ fidPool pool.Pool
+
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
+ // pending is the set of pending messages.
+ pending map[Tag]*response
+ pendingMu sync.Mutex
+
+ // sendMu is the lock for sending a request.
+ sendMu sync.Mutex
+
+ // recvr is essentially a mutex for calling recv.
+ //
+ // Whoever writes to this channel is permitted to call recv. When
+ // finished calling recv, this channel should be emptied.
+ recvr chan bool
+}
+
+// NewClient creates a new client. It performs a Tversion exchange with
+// the server to assert that messageSize is ok to use.
+//
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
+func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
+ // Need at least one byte of payload.
+ if messageSize <= msgRegistry.largestFixedSize {
+ return nil, &ErrMessageTooLarge{
+ size: messageSize,
+ msize: msgRegistry.largestFixedSize,
+ }
+ }
+
+ // Compute a payload size and round to 512 (normal block size)
+ // if it's larger than a single block.
+ payloadSize := messageSize - msgRegistry.largestFixedSize
+ if payloadSize > 512 && payloadSize%512 != 0 {
+ payloadSize -= (payloadSize % 512)
+ }
+ c := &Client{
+ socket: socket,
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
+ pending: make(map[Tag]*response),
+ recvr: make(chan bool, 1),
+ messageSize: messageSize,
+ payloadSize: payloadSize,
+ }
+ // Agree upon a version.
+ requested, ok := parseVersion(version)
+ if !ok {
+ return nil, ErrBadVersionString
+ }
+ for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
+ rversion := Rversion{}
+ _, err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
+
+ // The server told us to try again with a lower version.
+ if err == syscall.EAGAIN {
+ if requested == lowestSupportedVersion {
+ return nil, ErrVersionsExhausted
+ }
+ requested--
+ continue
+ }
+
+ // We requested an impossible version or our other parameters were bogus.
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the version.
+ version, ok := parseVersion(rversion.Version)
+ if !ok {
+ // The server gave us a bad version. We return a generically worrisome error.
+ log.Warningf("server returned bad version string %q", rversion.Version)
+ return nil, ErrBadVersionString
+ }
+ c.version = version
+ break
+ }
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacySyscallErr
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacySyscallErr
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
+ return c, nil
+}
+
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if _, err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if _, err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
+// handleOne handles a single incoming message.
+//
+// This should only be called with the token from recvr. Note that the received
+// tag will automatically be cleared from pending.
+func (c *Client) handleOne() {
+ tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) {
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ c.pendingMu.Unlock()
+
+ // Not expecting this message?
+ if resp == nil {
+ log.Warningf("client received unexpected tag %v, ignoring", tag)
+ return nil, ErrUnexpectedTag
+ }
+
+ // Is it an error? We specifically allow this to
+ // go through, and then we deserialize below.
+ if t == MsgRlerror {
+ return &Rlerror{}, nil
+ }
+
+ // Does it match expectations?
+ if t != resp.r.Type() {
+ return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()}
+ }
+
+ // Return the response.
+ return resp.r, nil
+ })
+
+ if err != nil {
+ // No tag was extracted (probably a socket error).
+ //
+ // Likely catastrophic. Notify all waiters and clear pending.
+ c.pendingMu.Lock()
+ for _, resp := range c.pending {
+ resp.done <- err
+ }
+ c.pending = make(map[Tag]*response)
+ c.pendingMu.Unlock()
+ } else {
+ // Process the tag.
+ //
+ // We know that is is contained in the map because our lookup function
+ // above must have succeeded (found the tag) to return nil err.
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ delete(c.pending, tag)
+ c.pendingMu.Unlock()
+ resp.r = r
+ resp.done <- err
+ }
+}
+
+// waitAndRecv co-ordinates with other receivers to handle responses.
+func (c *Client) waitAndRecv(done chan error) error {
+ for {
+ select {
+ case err := <-done:
+ return err
+ case c.recvr <- true:
+ select {
+ case err := <-done:
+ // It's possible that we got the token, despite
+ // done also being available. Check for that.
+ <-c.recvr
+ return err
+ default:
+ // Handle receiving one tag.
+ c.handleOne()
+
+ // Return the token.
+ <-c.recvr
+ }
+ }
+ }
+}
+
+// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
+// non-syscall errors to EIO.
+func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
+ received, err := c.sendRecvLegacy(t, r)
+ if !received {
+ log.Warningf("p9.Client.sendRecvChannel: %v", err)
+ return syscall.EIO
+ }
+ return err
+}
+
+// sendRecvLegacy performs a roundtrip message exchange.
+//
+// sendRecvLegacy returns true if a message was received. This allows us to
+// differentiate between failed receives and successful receives where the
+// response was an error message.
+//
+// This is called by internal functions.
+func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
+ tag, ok := c.tagPool.Get()
+ if !ok {
+ return false, ErrOutOfTags
+ }
+ defer c.tagPool.Put(tag)
+
+ // Indicate we're expecting a response.
+ //
+ // Note that the tag will be cleared from pending
+ // automatically (see handleOne for details).
+ resp := responsePool.Get().(*response)
+ defer responsePool.Put(resp)
+ resp.r = r
+ c.pendingMu.Lock()
+ c.pending[Tag(tag)] = resp
+ c.pendingMu.Unlock()
+
+ // Send the request over the wire.
+ c.sendMu.Lock()
+ err := send(c.socket, Tag(tag), t)
+ c.sendMu.Unlock()
+ if err != nil {
+ return false, err
+ }
+
+ // Co-ordinate with other receivers.
+ if err := c.waitAndRecv(resp.done); err != nil {
+ return false, err
+ }
+
+ // Is it an error message?
+ //
+ // For convenience, we transform these directly
+ // into errors. Handlers need not handle this case.
+ if rlerr, ok := resp.r.(*Rlerror); ok {
+ return true, syscall.Errno(rlerr.Error)
+ }
+
+ // At this point, we know it matches.
+ //
+ // Per recv call above, we will only allow a type
+ // match (and give our r) or an instance of Rlerror.
+ return true, nil
+}
+
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacySyscallErr(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ // Map all transport errors to EIO, but ensure that the real error
+ // is logged.
+ log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err)
+ return syscall.EIO
+ }
+ }
+
+ // Send the request and receive the server's response.
+ rsz, err := ch.send(t)
+ if err != nil {
+ // See above.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err)
+ return syscall.EIO
+ }
+
+ // Parse the server's response.
+ resp, retErr := ch.recv(r, rsz)
+ if resp == nil {
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
+ retErr = syscall.EIO
+ }
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return retErr
+}
+
+// Version returns the negotiated 9P2000.L.Google version number.
+func (c *Client) Version() uint32 {
+ return c.version
+}
+
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
+}