summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/p9')
-rw-r--r--pkg/p9/BUILD52
-rw-r--r--pkg/p9/buffer.go263
-rw-r--r--pkg/p9/buffer_test.go31
-rw-r--r--pkg/p9/client.go575
-rw-r--r--pkg/p9/client_file.go686
-rw-r--r--pkg/p9/client_test.go109
-rw-r--r--pkg/p9/file.go288
-rw-r--r--pkg/p9/handlers.go1393
-rw-r--r--pkg/p9/messages.go2662
-rw-r--r--pkg/p9/messages_test.go483
-rw-r--r--pkg/p9/p9.go1158
-rw-r--r--pkg/p9/p9_test.go188
-rw-r--r--pkg/p9/p9test/BUILD88
-rw-r--r--pkg/p9/p9test/client_test.go2242
-rw-r--r--pkg/p9/p9test/p9test.go329
-rw-r--r--pkg/p9/path_tree.go222
-rw-r--r--pkg/p9/server.go694
-rw-r--r--pkg/p9/transport.go345
-rw-r--r--pkg/p9/transport_flipcall.go243
-rw-r--r--pkg/p9/transport_test.go231
-rw-r--r--pkg/p9/version.go175
-rw-r--r--pkg/p9/version_test.go145
22 files changed, 12602 insertions, 0 deletions
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
new file mode 100644
index 000000000..8904afad9
--- /dev/null
+++ b/pkg/p9/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "p9",
+ srcs = [
+ "buffer.go",
+ "client.go",
+ "client_file.go",
+ "file.go",
+ "handlers.go",
+ "messages.go",
+ "p9.go",
+ "path_tree.go",
+ "server.go",
+ "transport.go",
+ "transport_flipcall.go",
+ "version.go",
+ ],
+ deps = [
+ "//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
+ "//pkg/log",
+ "//pkg/pool",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "p9_test",
+ size = "small",
+ srcs = [
+ "buffer_test.go",
+ "client_test.go",
+ "messages_test.go",
+ "p9_test.go",
+ "transport_test.go",
+ "version_test.go",
+ ],
+ library = ":p9",
+ deps = [
+ "//pkg/fd",
+ "//pkg/unet",
+ ],
+)
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
new file mode 100644
index 000000000..6a4951821
--- /dev/null
+++ b/pkg/p9/buffer.go
@@ -0,0 +1,263 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "encoding/binary"
+)
+
+// encoder is used for messages and 9P primitives.
+type encoder interface {
+ // decode decodes from the given buffer. decode may be called more than once
+ // to reuse the instance. It must clear any previous state.
+ //
+ // This may not fail, exhaustion will be recorded in the buffer.
+ decode(b *buffer)
+
+ // encode encodes to the given buffer.
+ //
+ // This may not fail.
+ encode(b *buffer)
+}
+
+// order is the byte order used for encoding.
+var order = binary.LittleEndian
+
+// buffer is a slice that is consumed.
+//
+// This is passed to the encoder methods.
+type buffer struct {
+ // data is the underlying data. This may grow during encode.
+ data []byte
+
+ // overflow indicates whether an overflow has occurred.
+ overflow bool
+}
+
+// append appends n bytes to the buffer and returns a slice pointing to the
+// newly appended bytes.
+func (b *buffer) append(n int) []byte {
+ b.data = append(b.data, make([]byte, n)...)
+ return b.data[len(b.data)-n:]
+}
+
+// consume consumes n bytes from the buffer.
+func (b *buffer) consume(n int) ([]byte, bool) {
+ if !b.has(n) {
+ b.markOverrun()
+ return nil, false
+ }
+ rval := b.data[:n]
+ b.data = b.data[n:]
+ return rval, true
+}
+
+// has returns true if n bytes are available.
+func (b *buffer) has(n int) bool {
+ return len(b.data) >= n
+}
+
+// markOverrun immediately marks this buffer as overrun.
+//
+// This is used by ReadString, since some invalid data implies the rest of the
+// buffer is no longer valid either.
+func (b *buffer) markOverrun() {
+ b.overflow = true
+}
+
+// isOverrun returns true if this buffer has run past the end.
+func (b *buffer) isOverrun() bool {
+ return b.overflow
+}
+
+// Read8 reads a byte from the buffer.
+func (b *buffer) Read8() uint8 {
+ v, ok := b.consume(1)
+ if !ok {
+ return 0
+ }
+ return uint8(v[0])
+}
+
+// Read16 reads a 16-bit value from the buffer.
+func (b *buffer) Read16() uint16 {
+ v, ok := b.consume(2)
+ if !ok {
+ return 0
+ }
+ return order.Uint16(v)
+}
+
+// Read32 reads a 32-bit value from the buffer.
+func (b *buffer) Read32() uint32 {
+ v, ok := b.consume(4)
+ if !ok {
+ return 0
+ }
+ return order.Uint32(v)
+}
+
+// Read64 reads a 64-bit value from the buffer.
+func (b *buffer) Read64() uint64 {
+ v, ok := b.consume(8)
+ if !ok {
+ return 0
+ }
+ return order.Uint64(v)
+}
+
+// ReadQIDType reads a QIDType value.
+func (b *buffer) ReadQIDType() QIDType {
+ return QIDType(b.Read8())
+}
+
+// ReadTag reads a Tag value.
+func (b *buffer) ReadTag() Tag {
+ return Tag(b.Read16())
+}
+
+// ReadFID reads a FID value.
+func (b *buffer) ReadFID() FID {
+ return FID(b.Read32())
+}
+
+// ReadUID reads a UID value.
+func (b *buffer) ReadUID() UID {
+ return UID(b.Read32())
+}
+
+// ReadGID reads a GID value.
+func (b *buffer) ReadGID() GID {
+ return GID(b.Read32())
+}
+
+// ReadPermissions reads a file mode value and applies the mask for permissions.
+func (b *buffer) ReadPermissions() FileMode {
+ return b.ReadFileMode() & permissionsMask
+}
+
+// ReadFileMode reads a file mode value.
+func (b *buffer) ReadFileMode() FileMode {
+ return FileMode(b.Read32())
+}
+
+// ReadOpenFlags reads an OpenFlags.
+func (b *buffer) ReadOpenFlags() OpenFlags {
+ return OpenFlags(b.Read32())
+}
+
+// ReadConnectFlags reads a ConnectFlags.
+func (b *buffer) ReadConnectFlags() ConnectFlags {
+ return ConnectFlags(b.Read32())
+}
+
+// ReadMsgType writes a MsgType.
+func (b *buffer) ReadMsgType() MsgType {
+ return MsgType(b.Read8())
+}
+
+// ReadString deserializes a string.
+func (b *buffer) ReadString() string {
+ l := b.Read16()
+ if !b.has(int(l)) {
+ // Mark the buffer as corrupted.
+ b.markOverrun()
+ return ""
+ }
+
+ bs := make([]byte, l)
+ for i := 0; i < int(l); i++ {
+ bs[i] = byte(b.Read8())
+ }
+ return string(bs)
+}
+
+// Write8 writes a byte to the buffer.
+func (b *buffer) Write8(v uint8) {
+ b.append(1)[0] = byte(v)
+}
+
+// Write16 writes a 16-bit value to the buffer.
+func (b *buffer) Write16(v uint16) {
+ order.PutUint16(b.append(2), v)
+}
+
+// Write32 writes a 32-bit value to the buffer.
+func (b *buffer) Write32(v uint32) {
+ order.PutUint32(b.append(4), v)
+}
+
+// Write64 writes a 64-bit value to the buffer.
+func (b *buffer) Write64(v uint64) {
+ order.PutUint64(b.append(8), v)
+}
+
+// WriteQIDType writes a QIDType value.
+func (b *buffer) WriteQIDType(qidType QIDType) {
+ b.Write8(uint8(qidType))
+}
+
+// WriteTag writes a Tag value.
+func (b *buffer) WriteTag(tag Tag) {
+ b.Write16(uint16(tag))
+}
+
+// WriteFID writes a FID value.
+func (b *buffer) WriteFID(fid FID) {
+ b.Write32(uint32(fid))
+}
+
+// WriteUID writes a UID value.
+func (b *buffer) WriteUID(uid UID) {
+ b.Write32(uint32(uid))
+}
+
+// WriteGID writes a GID value.
+func (b *buffer) WriteGID(gid GID) {
+ b.Write32(uint32(gid))
+}
+
+// WritePermissions applies a permissions mask and writes the FileMode.
+func (b *buffer) WritePermissions(perm FileMode) {
+ b.WriteFileMode(perm & permissionsMask)
+}
+
+// WriteFileMode writes a FileMode.
+func (b *buffer) WriteFileMode(mode FileMode) {
+ b.Write32(uint32(mode))
+}
+
+// WriteOpenFlags writes an OpenFlags.
+func (b *buffer) WriteOpenFlags(flags OpenFlags) {
+ b.Write32(uint32(flags))
+}
+
+// WriteConnectFlags writes a ConnectFlags.
+func (b *buffer) WriteConnectFlags(flags ConnectFlags) {
+ b.Write32(uint32(flags))
+}
+
+// WriteMsgType writes a MsgType.
+func (b *buffer) WriteMsgType(t MsgType) {
+ b.Write8(uint8(t))
+}
+
+// WriteString serializes the given string.
+func (b *buffer) WriteString(s string) {
+ b.Write16(uint16(len(s)))
+ for i := 0; i < len(s); i++ {
+ b.Write8(byte(s[i]))
+ }
+}
diff --git a/pkg/p9/buffer_test.go b/pkg/p9/buffer_test.go
new file mode 100644
index 000000000..a9c75f86b
--- /dev/null
+++ b/pkg/p9/buffer_test.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "testing"
+)
+
+func TestBufferOverrun(t *testing.T) {
+ buf := &buffer{
+ // This header indicates that a large string should follow, but
+ // it is only two bytes. Reading a string should cause an
+ // overrun.
+ data: []byte{0x0, 0x16},
+ }
+ if s := buf.ReadString(); s != "" {
+ t.Errorf("overrun read got %s, want empty", s)
+ }
+}
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
new file mode 100644
index 000000000..71e944c30
--- /dev/null
+++ b/pkg/p9/client.go
@@ -0,0 +1,575 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// ErrOutOfTags indicates no tags are available.
+var ErrOutOfTags = errors.New("out of tags -- messages lost?")
+
+// ErrOutOfFIDs indicates no more FIDs are available.
+var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?")
+
+// ErrUnexpectedTag indicates a response with an unexpected tag was received.
+var ErrUnexpectedTag = errors.New("unexpected tag in response")
+
+// ErrVersionsExhausted indicates that all versions to negotiate have been exhausted.
+var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate")
+
+// ErrBadVersionString indicates that the version string is malformed or unsupported.
+var ErrBadVersionString = errors.New("bad version string")
+
+// ErrBadResponse indicates the response didn't match the request.
+type ErrBadResponse struct {
+ Got MsgType
+ Want MsgType
+}
+
+// Error returns a highly descriptive error.
+func (e *ErrBadResponse) Error() string {
+ return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want)
+}
+
+// response is the asynchronous return from recv.
+//
+// This is used in the pending map below.
+type response struct {
+ r message
+ done chan error
+}
+
+var responsePool = sync.Pool{
+ New: func() interface{} {
+ return &response{
+ done: make(chan error, 1),
+ }
+ },
+}
+
+// Client is at least a 9P2000.L client.
+type Client struct {
+ // socket is the connected socket.
+ socket *unet.Socket
+
+ // tagPool is the collection of available tags.
+ tagPool pool.Pool
+
+ // fidPool is the collection of available fids.
+ fidPool pool.Pool
+
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
+ // pending is the set of pending messages.
+ pending map[Tag]*response
+ pendingMu sync.Mutex
+
+ // sendMu is the lock for sending a request.
+ sendMu sync.Mutex
+
+ // recvr is essentially a mutex for calling recv.
+ //
+ // Whoever writes to this channel is permitted to call recv. When
+ // finished calling recv, this channel should be emptied.
+ recvr chan bool
+}
+
+// NewClient creates a new client. It performs a Tversion exchange with
+// the server to assert that messageSize is ok to use.
+//
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
+func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
+ // Need at least one byte of payload.
+ if messageSize <= msgRegistry.largestFixedSize {
+ return nil, &ErrMessageTooLarge{
+ size: messageSize,
+ msize: msgRegistry.largestFixedSize,
+ }
+ }
+
+ // Compute a payload size and round to 512 (normal block size)
+ // if it's larger than a single block.
+ payloadSize := messageSize - msgRegistry.largestFixedSize
+ if payloadSize > 512 && payloadSize%512 != 0 {
+ payloadSize -= (payloadSize % 512)
+ }
+ c := &Client{
+ socket: socket,
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
+ pending: make(map[Tag]*response),
+ recvr: make(chan bool, 1),
+ messageSize: messageSize,
+ payloadSize: payloadSize,
+ }
+ // Agree upon a version.
+ requested, ok := parseVersion(version)
+ if !ok {
+ return nil, ErrBadVersionString
+ }
+ for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
+ rversion := Rversion{}
+ _, err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
+
+ // The server told us to try again with a lower version.
+ if err == syscall.EAGAIN {
+ if requested == lowestSupportedVersion {
+ return nil, ErrVersionsExhausted
+ }
+ requested--
+ continue
+ }
+
+ // We requested an impossible version or our other parameters were bogus.
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the version.
+ version, ok := parseVersion(rversion.Version)
+ if !ok {
+ // The server gave us a bad version. We return a generically worrisome error.
+ log.Warningf("server returned bad version string %q", rversion.Version)
+ return nil, ErrBadVersionString
+ }
+ c.version = version
+ break
+ }
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacySyscallErr
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacySyscallErr
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
+ return c, nil
+}
+
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if _, err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if _, err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
+// handleOne handles a single incoming message.
+//
+// This should only be called with the token from recvr. Note that the received
+// tag will automatically be cleared from pending.
+func (c *Client) handleOne() {
+ tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) {
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ c.pendingMu.Unlock()
+
+ // Not expecting this message?
+ if resp == nil {
+ log.Warningf("client received unexpected tag %v, ignoring", tag)
+ return nil, ErrUnexpectedTag
+ }
+
+ // Is it an error? We specifically allow this to
+ // go through, and then we deserialize below.
+ if t == MsgRlerror {
+ return &Rlerror{}, nil
+ }
+
+ // Does it match expectations?
+ if t != resp.r.Type() {
+ return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()}
+ }
+
+ // Return the response.
+ return resp.r, nil
+ })
+
+ if err != nil {
+ // No tag was extracted (probably a socket error).
+ //
+ // Likely catastrophic. Notify all waiters and clear pending.
+ c.pendingMu.Lock()
+ for _, resp := range c.pending {
+ resp.done <- err
+ }
+ c.pending = make(map[Tag]*response)
+ c.pendingMu.Unlock()
+ } else {
+ // Process the tag.
+ //
+ // We know that is is contained in the map because our lookup function
+ // above must have succeeded (found the tag) to return nil err.
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ delete(c.pending, tag)
+ c.pendingMu.Unlock()
+ resp.r = r
+ resp.done <- err
+ }
+}
+
+// waitAndRecv co-ordinates with other receivers to handle responses.
+func (c *Client) waitAndRecv(done chan error) error {
+ for {
+ select {
+ case err := <-done:
+ return err
+ case c.recvr <- true:
+ select {
+ case err := <-done:
+ // It's possible that we got the token, despite
+ // done also being available. Check for that.
+ <-c.recvr
+ return err
+ default:
+ // Handle receiving one tag.
+ c.handleOne()
+
+ // Return the token.
+ <-c.recvr
+ }
+ }
+ }
+}
+
+// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
+// non-syscall errors to EIO.
+func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
+ received, err := c.sendRecvLegacy(t, r)
+ if !received {
+ log.Warningf("p9.Client.sendRecvChannel: %v", err)
+ return syscall.EIO
+ }
+ return err
+}
+
+// sendRecvLegacy performs a roundtrip message exchange.
+//
+// sendRecvLegacy returns true if a message was received. This allows us to
+// differentiate between failed receives and successful receives where the
+// response was an error message.
+//
+// This is called by internal functions.
+func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
+ tag, ok := c.tagPool.Get()
+ if !ok {
+ return false, ErrOutOfTags
+ }
+ defer c.tagPool.Put(tag)
+
+ // Indicate we're expecting a response.
+ //
+ // Note that the tag will be cleared from pending
+ // automatically (see handleOne for details).
+ resp := responsePool.Get().(*response)
+ defer responsePool.Put(resp)
+ resp.r = r
+ c.pendingMu.Lock()
+ c.pending[Tag(tag)] = resp
+ c.pendingMu.Unlock()
+
+ // Send the request over the wire.
+ c.sendMu.Lock()
+ err := send(c.socket, Tag(tag), t)
+ c.sendMu.Unlock()
+ if err != nil {
+ return false, err
+ }
+
+ // Co-ordinate with other receivers.
+ if err := c.waitAndRecv(resp.done); err != nil {
+ return false, err
+ }
+
+ // Is it an error message?
+ //
+ // For convenience, we transform these directly
+ // into errors. Handlers need not handle this case.
+ if rlerr, ok := resp.r.(*Rlerror); ok {
+ return true, syscall.Errno(rlerr.Error)
+ }
+
+ // At this point, we know it matches.
+ //
+ // Per recv call above, we will only allow a type
+ // match (and give our r) or an instance of Rlerror.
+ return true, nil
+}
+
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacySyscallErr(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ // Map all transport errors to EIO, but ensure that the real error
+ // is logged.
+ log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err)
+ return syscall.EIO
+ }
+ }
+
+ // Send the request and receive the server's response.
+ rsz, err := ch.send(t)
+ if err != nil {
+ // See above.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err)
+ return syscall.EIO
+ }
+
+ // Parse the server's response.
+ resp, retErr := ch.recv(r, rsz)
+ if resp == nil {
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
+ retErr = syscall.EIO
+ }
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return retErr
+}
+
+// Version returns the negotiated 9P2000.L.Google version number.
+func (c *Client) Version() uint32 {
+ return c.version
+}
+
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
+}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
new file mode 100644
index 000000000..2ee07b664
--- /dev/null
+++ b/pkg/p9/client_file.go
@@ -0,0 +1,686 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// Attach attaches to a server.
+//
+// Note that authentication is not currently supported.
+func (c *Client) Attach(name string) (File, error) {
+ fid, ok := c.fidPool.Get()
+ if !ok {
+ return nil, ErrOutOfFIDs
+ }
+
+ rattach := Rattach{}
+ if err := c.sendRecv(&Tattach{FID: FID(fid), Auth: Tauth{AttachName: name, AuthenticationFID: NoFID, UID: NoUID}}, &rattach); err != nil {
+ c.fidPool.Put(fid)
+ return nil, err
+ }
+
+ return c.newFile(FID(fid)), nil
+}
+
+// newFile returns a new client file.
+func (c *Client) newFile(fid FID) *clientFile {
+ return &clientFile{
+ client: c,
+ fid: fid,
+ }
+}
+
+// clientFile is provided to clients.
+//
+// This proxies all of the interfaces found in file.go.
+type clientFile struct {
+ // client is the originating client.
+ client *Client
+
+ // fid is the FID for this file.
+ fid FID
+
+ // closed indicates whether this file has been closed.
+ closed uint32
+}
+
+// Walk implements File.Walk.
+func (c *clientFile) Walk(names []string) ([]QID, File, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, syscall.EBADF
+ }
+
+ fid, ok := c.client.fidPool.Get()
+ if !ok {
+ return nil, nil, ErrOutOfFIDs
+ }
+
+ rwalk := Rwalk{}
+ if err := c.client.sendRecv(&Twalk{FID: c.fid, NewFID: FID(fid), Names: names}, &rwalk); err != nil {
+ c.client.fidPool.Put(fid)
+ return nil, nil, err
+ }
+
+ // Return a new client file.
+ return rwalk.QIDs, c.client.newFile(FID(fid)), nil
+}
+
+// WalkGetAttr implements File.WalkGetAttr.
+func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, Attr, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, AttrMask{}, Attr{}, syscall.EBADF
+ }
+
+ if !versionSupportsTwalkgetattr(c.client.version) {
+ qids, file, err := c.Walk(components)
+ if err != nil {
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ _, valid, attr, err := file.GetAttr(AttrMaskAll())
+ if err != nil {
+ file.Close()
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ return qids, file, valid, attr, nil
+ }
+
+ fid, ok := c.client.fidPool.Get()
+ if !ok {
+ return nil, nil, AttrMask{}, Attr{}, ErrOutOfFIDs
+ }
+
+ rwalkgetattr := Rwalkgetattr{}
+ if err := c.client.sendRecv(&Twalkgetattr{FID: c.fid, NewFID: FID(fid), Names: components}, &rwalkgetattr); err != nil {
+ c.client.fidPool.Put(fid)
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+
+ // Return a new client file.
+ return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil
+}
+
+// StatFS implements File.StatFS.
+func (c *clientFile) StatFS() (FSStat, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return FSStat{}, syscall.EBADF
+ }
+
+ rstatfs := Rstatfs{}
+ if err := c.client.sendRecv(&Tstatfs{FID: c.fid}, &rstatfs); err != nil {
+ return FSStat{}, err
+ }
+
+ return rstatfs.FSStat, nil
+}
+
+// FSync implements File.FSync.
+func (c *clientFile) FSync() error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tfsync{FID: c.fid}, &Rfsync{})
+}
+
+// GetAttr implements File.GetAttr.
+func (c *clientFile) GetAttr(req AttrMask) (QID, AttrMask, Attr, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, AttrMask{}, Attr{}, syscall.EBADF
+ }
+
+ rgetattr := Rgetattr{}
+ if err := c.client.sendRecv(&Tgetattr{FID: c.fid, AttrMask: req}, &rgetattr); err != nil {
+ return QID{}, AttrMask{}, Attr{}, err
+ }
+
+ return rgetattr.QID, rgetattr.Valid, rgetattr.Attr, nil
+}
+
+// SetAttr implements File.SetAttr.
+func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{})
+}
+
+// GetXattr implements File.GetXattr.
+func (c *clientFile) GetXattr(name string, size uint64) (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return "", syscall.EOPNOTSUPP
+ }
+
+ rgetxattr := Rgetxattr{}
+ if err := c.client.sendRecv(&Tgetxattr{FID: c.fid, Name: name, Size: size}, &rgetxattr); err != nil {
+ return "", err
+ }
+
+ return rgetxattr.Value, nil
+}
+
+// SetXattr implements File.SetXattr.
+func (c *clientFile) SetXattr(name, value string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{})
+}
+
+// ListXattr implements File.ListXattr.
+func (c *clientFile) ListXattr(size uint64) (map[string]struct{}, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return nil, syscall.EOPNOTSUPP
+ }
+
+ rlistxattr := Rlistxattr{}
+ if err := c.client.sendRecv(&Tlistxattr{FID: c.fid, Size: size}, &rlistxattr); err != nil {
+ return nil, err
+ }
+
+ xattrs := make(map[string]struct{}, len(rlistxattr.Xattrs))
+ for _, x := range rlistxattr.Xattrs {
+ xattrs[x] = struct{}{}
+ }
+ return xattrs, nil
+}
+
+// RemoveXattr implements File.RemoveXattr.
+func (c *clientFile) RemoveXattr(name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tremovexattr{FID: c.fid, Name: name}, &Rremovexattr{})
+}
+
+// Allocate implements File.Allocate.
+func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsTallocate(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tallocate{FID: c.fid, Mode: mode, Offset: offset, Length: length}, &Rallocate{})
+}
+
+// Remove implements File.Remove.
+//
+// N.B. This method is no longer part of the file interface and should be
+// considered deprecated.
+func (c *clientFile) Remove() error {
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+
+ // Send the remove message.
+ if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil {
+ return err
+ }
+
+ // "It is correct to consider remove to be a clunk with the side effect
+ // of removing the file if permissions allow."
+ // https://swtch.com/plan9port/man/man9/remove.html
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
+// Close implements File.Close.
+func (c *clientFile) Close() error {
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+
+ // Send the close message.
+ if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil {
+ // If an error occurred, we toss away the FID. This isn't ideal,
+ // but I'm not sure what else makes sense in this context.
+ log.Warningf("Tclunk failed, losing FID %v: %v", c.fid, err)
+ return err
+ }
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
+// Open implements File.Open.
+func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, QID{}, 0, syscall.EBADF
+ }
+
+ rlopen := Rlopen{}
+ if err := c.client.sendRecv(&Tlopen{FID: c.fid, Flags: flags}, &rlopen); err != nil {
+ return nil, QID{}, 0, err
+ }
+
+ return rlopen.File, rlopen.QID, rlopen.IoUnit, nil
+}
+
+// Connect implements File.Connect.
+func (c *clientFile) Connect(flags ConnectFlags) (*fd.FD, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+
+ if !VersionSupportsConnect(c.client.version) {
+ return nil, syscall.ECONNREFUSED
+ }
+
+ rlconnect := Rlconnect{}
+ if err := c.client.sendRecv(&Tlconnect{FID: c.fid, Flags: flags}, &rlconnect); err != nil {
+ return nil, err
+ }
+
+ return rlconnect.File, nil
+}
+
+// chunk applies fn to p in chunkSize-sized chunks until fn returns a partial result, p is
+// exhausted, or an error is encountered (which may be io.EOF).
+func chunk(chunkSize uint32, fn func([]byte, uint64) (int, error), p []byte, offset uint64) (int, error) {
+ // Some p9.Clients depend on executing fn on zero-byte buffers. Handle this
+ // as a special case (normally it is fine to short-circuit and return (0, nil)).
+ if len(p) == 0 {
+ return fn(p, offset)
+ }
+
+ // total is the cumulative bytes processed.
+ var total int
+ for {
+ var n int
+ var err error
+
+ // We're done, don't bother trying to do anything more.
+ if total == len(p) {
+ return total, nil
+ }
+
+ // Apply fn to a chunkSize-sized (or less) chunk of p.
+ if len(p) < total+int(chunkSize) {
+ n, err = fn(p[total:], offset)
+ } else {
+ n, err = fn(p[total:total+int(chunkSize)], offset)
+ }
+ total += n
+ offset += uint64(n)
+
+ // Return whatever we have processed if we encounter an error. This error
+ // could be io.EOF.
+ if err != nil {
+ return total, err
+ }
+
+ // Did we get a partial result? If so, return it immediately.
+ if n < int(chunkSize) {
+ return total, nil
+ }
+
+ // If we received more bytes than we ever requested, this is a problem.
+ if total > len(p) {
+ panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", total, len(p)))
+ }
+ }
+}
+
+// ReadAt proxies File.ReadAt.
+func (c *clientFile) ReadAt(p []byte, offset uint64) (int, error) {
+ return chunk(c.client.payloadSize, c.readAt, p, offset)
+}
+
+func (c *clientFile) readAt(p []byte, offset uint64) (int, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return 0, syscall.EBADF
+ }
+
+ rread := Rread{Data: p}
+ if err := c.client.sendRecv(&Tread{FID: c.fid, Offset: offset, Count: uint32(len(p))}, &rread); err != nil {
+ return 0, err
+ }
+
+ // The message may have been truncated, or for some reason a new buffer
+ // allocated. This isn't the common path, but we make sure that if the
+ // payload has changed we copy it. See transport.go for more information.
+ if len(p) > 0 && len(rread.Data) > 0 && &rread.Data[0] != &p[0] {
+ copy(p, rread.Data)
+ }
+
+ // io.EOF is not an error that a p9 server can return. Use POSIX semantics to
+ // return io.EOF manually: zero bytes were returned and a non-zero buffer was used.
+ if len(rread.Data) == 0 && len(p) > 0 {
+ return 0, io.EOF
+ }
+
+ return len(rread.Data), nil
+}
+
+// WriteAt proxies File.WriteAt.
+func (c *clientFile) WriteAt(p []byte, offset uint64) (int, error) {
+ return chunk(c.client.payloadSize, c.writeAt, p, offset)
+}
+
+func (c *clientFile) writeAt(p []byte, offset uint64) (int, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return 0, syscall.EBADF
+ }
+
+ rwrite := Rwrite{}
+ if err := c.client.sendRecv(&Twrite{FID: c.fid, Offset: offset, Data: p}, &rwrite); err != nil {
+ return 0, err
+ }
+
+ return int(rwrite.Count), nil
+}
+
+// ReadWriterFile wraps a File and implements io.ReadWriter, io.ReaderAt, and io.WriterAt.
+type ReadWriterFile struct {
+ File File
+ Offset uint64
+}
+
+// Read implements part of the io.ReadWriter interface.
+func (r *ReadWriterFile) Read(p []byte) (int, error) {
+ n, err := r.File.ReadAt(p, r.Offset)
+ r.Offset += uint64(n)
+ if err != nil {
+ return n, err
+ }
+ if n == 0 && len(p) > 0 {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// ReadAt implements the io.ReaderAt interface.
+func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) {
+ n, err := r.File.ReadAt(p, uint64(offset))
+ if err != nil {
+ return 0, err
+ }
+ if n == 0 && len(p) > 0 {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// Write implements part of the io.ReadWriter interface.
+func (r *ReadWriterFile) Write(p []byte) (int, error) {
+ n, err := r.File.WriteAt(p, r.Offset)
+ r.Offset += uint64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < len(p) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
+// WriteAt implements the io.WriteAt interface.
+func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) {
+ n, err := r.File.WriteAt(p, uint64(offset))
+ if err != nil {
+ return n, err
+ }
+ if n < len(p) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
+// Rename implements File.Rename.
+func (c *clientFile) Rename(dir File, name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ clientDir, ok := dir.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Trename{FID: c.fid, Directory: clientDir.fid, Name: name}, &Rrename{})
+}
+
+// Create implements File.Create.
+func (c *clientFile) Create(name string, openFlags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, QID{}, 0, syscall.EBADF
+ }
+
+ msg := Tlcreate{
+ FID: c.fid,
+ Name: name,
+ OpenFlags: openFlags,
+ Permissions: permissions,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rucreate := Rucreate{}
+ if err := c.client.sendRecv(&Tucreate{Tlcreate: msg, UID: uid}, &rucreate); err != nil {
+ return nil, nil, QID{}, 0, err
+ }
+ return rucreate.File, c, rucreate.QID, rucreate.IoUnit, nil
+ }
+
+ rlcreate := Rlcreate{}
+ if err := c.client.sendRecv(&msg, &rlcreate); err != nil {
+ return nil, nil, QID{}, 0, err
+ }
+
+ return rlcreate.File, c, rlcreate.QID, rlcreate.IoUnit, nil
+}
+
+// Mkdir implements File.Mkdir.
+func (c *clientFile) Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tmkdir{
+ Directory: c.fid,
+ Name: name,
+ Permissions: permissions,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rumkdir := Rumkdir{}
+ if err := c.client.sendRecv(&Tumkdir{Tmkdir: msg, UID: uid}, &rumkdir); err != nil {
+ return QID{}, err
+ }
+ return rumkdir.QID, nil
+ }
+
+ rmkdir := Rmkdir{}
+ if err := c.client.sendRecv(&msg, &rmkdir); err != nil {
+ return QID{}, err
+ }
+
+ return rmkdir.QID, nil
+}
+
+// Symlink implements File.Symlink.
+func (c *clientFile) Symlink(oldname string, newname string, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tsymlink{
+ Directory: c.fid,
+ Name: newname,
+ Target: oldname,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rusymlink := Rusymlink{}
+ if err := c.client.sendRecv(&Tusymlink{Tsymlink: msg, UID: uid}, &rusymlink); err != nil {
+ return QID{}, err
+ }
+ return rusymlink.QID, nil
+ }
+
+ rsymlink := Rsymlink{}
+ if err := c.client.sendRecv(&msg, &rsymlink); err != nil {
+ return QID{}, err
+ }
+
+ return rsymlink.QID, nil
+}
+
+// Link implements File.Link.
+func (c *clientFile) Link(target File, newname string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ targetFile, ok := target.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tlink{Directory: c.fid, Name: newname, Target: targetFile.fid}, &Rlink{})
+}
+
+// Mknod implements File.Mknod.
+func (c *clientFile) Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tmknod{
+ Directory: c.fid,
+ Name: name,
+ Mode: mode,
+ Major: major,
+ Minor: minor,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rumknod := Rumknod{}
+ if err := c.client.sendRecv(&Tumknod{Tmknod: msg, UID: uid}, &rumknod); err != nil {
+ return QID{}, err
+ }
+ return rumknod.QID, nil
+ }
+
+ rmknod := Rmknod{}
+ if err := c.client.sendRecv(&msg, &rmknod); err != nil {
+ return QID{}, err
+ }
+
+ return rmknod.QID, nil
+}
+
+// RenameAt implements File.RenameAt.
+func (c *clientFile) RenameAt(oldname string, newdir File, newname string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ clientNewDir, ok := newdir.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Trenameat{OldDirectory: c.fid, OldName: oldname, NewDirectory: clientNewDir.fid, NewName: newname}, &Rrenameat{})
+}
+
+// UnlinkAt implements File.UnlinkAt.
+func (c *clientFile) UnlinkAt(name string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tunlinkat{Directory: c.fid, Name: name, Flags: flags}, &Runlinkat{})
+}
+
+// Readdir implements File.Readdir.
+func (c *clientFile) Readdir(offset uint64, count uint32) ([]Dirent, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+
+ rreaddir := Rreaddir{}
+ if err := c.client.sendRecv(&Treaddir{Directory: c.fid, Offset: offset, Count: count}, &rreaddir); err != nil {
+ return nil, err
+ }
+
+ return rreaddir.Entries, nil
+}
+
+// Readlink implements File.Readlink.
+func (c *clientFile) Readlink() (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+
+ rreadlink := Rreadlink{}
+ if err := c.client.sendRecv(&Treadlink{FID: c.fid}, &rreadlink); err != nil {
+ return "", err
+ }
+
+ return rreadlink.Target, nil
+}
+
+// Flush implements File.Flush.
+func (c *clientFile) Flush() error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ if !VersionSupportsTflushf(c.client.version) {
+ return nil
+ }
+
+ return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{})
+}
+
+// Renamed implements File.Renamed.
+func (c *clientFile) Renamed(newDir File, newName string) {}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
new file mode 100644
index 000000000..c757583e0
--- /dev/null
+++ b/pkg/p9/client_test.go
@@ -0,0 +1,109 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// TestVersion tests the version negotiation.
+func TestVersion(t *testing.T) {
+ // First, create a new server and connection.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // Create a new server and client.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // NewClient does a Tversion exchange, so this is our test for success.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ t.Fatalf("got %v, expected nil", err)
+ }
+
+ // Check a bogus version string.
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
+ t.Errorf("got %v expected %v", err, syscall.EINVAL)
+ }
+
+ // Check a bogus version number.
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
+ t.Errorf("got %v expected %v", err, syscall.EINVAL)
+ }
+
+ // Check a too high version number.
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
+ t.Errorf("got %v expected %v", err, syscall.EAGAIN)
+ }
+
+ // Check an invalid MSize.
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion), MSize: 0}, &Rversion{}); err != syscall.EINVAL {
+ t.Errorf("got %v expected %v", err, syscall.EINVAL)
+ }
+}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error {
+ return func(t message, r message) error {
+ _, err := c.sendRecvLegacy(t, r)
+ return err
+ }
+ })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
new file mode 100644
index 000000000..cab35896f
--- /dev/null
+++ b/pkg/p9/file.go
@@ -0,0 +1,288 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+)
+
+// Attacher is provided by the server.
+type Attacher interface {
+ // Attach returns a new File.
+ //
+ // The client-side attach will be translate to a series of walks from
+ // the file returned by this Attach call.
+ Attach() (File, error)
+}
+
+// File is a set of operations corresponding to a single node.
+//
+// Note that on the server side, the server logic places constraints on
+// concurrent operations to make things easier. This may reduce the need for
+// complex, error-prone locking and logic in the backend. These are documented
+// for each method.
+//
+// There are three different types of guarantees provided:
+//
+// none: There is no concurrency guarantee. The method may be invoked
+// concurrently with any other method on any other file.
+//
+// read: The method is guaranteed to be exclusive of any write or global
+// operation that is mutating the state of the directory tree starting at this
+// node. For example, this means creating new files, symlinks, directories or
+// renaming a directory entry (or renaming in to this target), but the method
+// may be called concurrently with other read methods.
+//
+// write: The method is guaranteed to be exclusive of any read, write or global
+// operation that is mutating the state of the directory tree starting at this
+// node, as described in read above. There may however, be other write
+// operations executing concurrently on other components in the directory tree.
+//
+// global: The method is guaranteed to be exclusive of any read, write or
+// global operation.
+type File interface {
+ // Walk walks to the path components given in names.
+ //
+ // Walk returns QIDs in the same order that the names were passed in.
+ //
+ // An empty list of arguments should return a copy of the current file.
+ //
+ // On the server, Walk has a read concurrency guarantee.
+ Walk(names []string) ([]QID, File, error)
+
+ // WalkGetAttr walks to the next file and returns its maximal set of
+ // attributes.
+ //
+ // Server-side p9.Files may return syscall.ENOSYS to indicate that Walk
+ // and GetAttr should be used separately to satisfy this request.
+ //
+ // On the server, WalkGetAttr has a read concurrency guarantee.
+ WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error)
+
+ // StatFS returns information about the file system associated with
+ // this file.
+ //
+ // On the server, StatFS has no concurrency guarantee.
+ StatFS() (FSStat, error)
+
+ // GetAttr returns attributes of this node.
+ //
+ // On the server, GetAttr has a read concurrency guarantee.
+ GetAttr(req AttrMask) (QID, AttrMask, Attr, error)
+
+ // SetAttr sets attributes on this node.
+ //
+ // On the server, SetAttr has a write concurrency guarantee.
+ SetAttr(valid SetAttrMask, attr SetAttr) error
+
+ // GetXattr returns extended attributes of this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // On the server, GetXattr has a read concurrency guarantee.
+ GetXattr(name string, size uint64) (string, error)
+
+ // SetXattr sets extended attributes on this node.
+ //
+ // On the server, SetXattr has a write concurrency guarantee.
+ SetXattr(name, value string, flags uint32) error
+
+ // ListXattr lists the names of the extended attributes on this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // On the server, ListXattr has a read concurrency guarantee.
+ ListXattr(size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes extended attributes on this node.
+ //
+ // On the server, RemoveXattr has a write concurrency guarantee.
+ RemoveXattr(name string) error
+
+ // Allocate allows the caller to directly manipulate the allocated disk space
+ // for the file. See fallocate(2) for more details.
+ Allocate(mode AllocateMode, offset, length uint64) error
+
+ // Close is called when all references are dropped on the server side,
+ // and Close should be called by the client to drop all references.
+ //
+ // For server-side implementations of Close, the error is ignored.
+ //
+ // Close must be called even when Open has not been called.
+ //
+ // On the server, Close has no concurrency guarantee.
+ Close() error
+
+ // Open must be called prior to using Read, Write or Readdir. Once Open
+ // is called, some operations, such as Walk, will no longer work.
+ //
+ // On the client, Open should be called only once. The fd return is
+ // optional, and may be nil.
+ //
+ // On the server, Open has a read concurrency guarantee. If an *fd.FD
+ // is provided, ownership now belongs to the caller. Open is guaranteed
+ // to be called only once.
+ //
+ // N.B. The server must resolve any lazy paths when open is called.
+ // After this point, read and write may be called on files with no
+ // deletion check, so resolving in the data path is not viable.
+ Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
+
+ // Read reads from this file. Open must be called first.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, ReadAt has a read concurrency guarantee. See Open for
+ // additional requirements regarding lazy path resolution.
+ ReadAt(p []byte, offset uint64) (int, error)
+
+ // Write writes to this file. Open must be called first.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, WriteAt has a read concurrency guarantee. See Open
+ // for additional requirements regarding lazy path resolution.
+ WriteAt(p []byte, offset uint64) (int, error)
+
+ // FSync syncs this node. Open must be called first.
+ //
+ // On the server, FSync has a read concurrency guarantee.
+ FSync() error
+
+ // Create creates a new regular file and opens it according to the
+ // flags given. This file is already Open.
+ //
+ // N.B. On the client, the returned file is a reference to the current
+ // file, which now represents the created file. This is not the case on
+ // the server. These semantics are very subtle and can easily lead to
+ // bugs, but are a consequence of the 9P create operation.
+ //
+ // See p9.File.Open for a description of *fd.FD.
+ //
+ // On the server, Create has a write concurrency guarantee.
+ Create(name string, flags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error)
+
+ // Mkdir creates a subdirectory.
+ //
+ // On the server, Mkdir has a write concurrency guarantee.
+ Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error)
+
+ // Symlink makes a new symbolic link.
+ //
+ // On the server, Symlink has a write concurrency guarantee.
+ Symlink(oldName string, newName string, uid UID, gid GID) (QID, error)
+
+ // Link makes a new hard link.
+ //
+ // On the server, Link has a write concurrency guarantee.
+ Link(target File, newName string) error
+
+ // Mknod makes a new device node.
+ //
+ // On the server, Mknod has a write concurrency guarantee.
+ Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error)
+
+ // Rename renames the file.
+ //
+ // Rename will never be called on the server, and RenameAt will always
+ // be used instead.
+ Rename(newDir File, newName string) error
+
+ // RenameAt renames a given file to a new name in a potentially new
+ // directory.
+ //
+ // oldName must be a name relative to this file, which must be a
+ // directory. newName is a name relative to newDir.
+ //
+ // On the server, RenameAt has a global concurrency guarantee.
+ RenameAt(oldName string, newDir File, newName string) error
+
+ // UnlinkAt the given named file.
+ //
+ // name must be a file relative to this directory.
+ //
+ // Flags are implementation-specific (e.g. O_DIRECTORY), but are
+ // generally Linux unlinkat(2) flags.
+ //
+ // On the server, UnlinkAt has a write concurrency guarantee.
+ UnlinkAt(name string, flags uint32) error
+
+ // Readdir reads directory entries.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, Readdir has a read concurrency guarantee.
+ Readdir(offset uint64, count uint32) ([]Dirent, error)
+
+ // Readlink reads the link target.
+ //
+ // On the server, Readlink has a read concurrency guarantee.
+ Readlink() (string, error)
+
+ // Flush is called prior to Close.
+ //
+ // Whereas Close drops all references to the file, Flush cleans up the
+ // file state. Behavior is implementation-specific.
+ //
+ // Flush is not related to flush(9p). Flush is an extension to 9P2000.L,
+ // see version.go.
+ //
+ // On the server, Flush has a read concurrency guarantee.
+ Flush() error
+
+ // Connect establishes a new host-socket backed connection with a
+ // socket. A File does not need to be opened before it can be connected
+ // and it can be connected to multiple times resulting in a unique
+ // *fd.FD each time. In addition, the lifetime of the *fd.FD is
+ // independent from the lifetime of the p9.File and must be managed by
+ // the caller.
+ //
+ // The returned FD must be non-blocking.
+ //
+ // Flags indicates the requested type of socket.
+ //
+ // On the server, Connect has a read concurrency guarantee.
+ Connect(flags ConnectFlags) (*fd.FD, error)
+
+ // Renamed is called when this node is renamed.
+ //
+ // This may not fail. The file will hold a reference to its parent
+ // within the p9 package, and is therefore safe to use for the lifetime
+ // of this File (until Close is called).
+ //
+ // This method should not be called by clients, who should use the
+ // relevant Rename methods. (Although the method will be a no-op.)
+ //
+ // On the server, Renamed has a global concurrency guarantee.
+ Renamed(newDir File, newName string)
+}
+
+// DefaultWalkGetAttr implements File.WalkGetAttr to return ENOSYS for server-side Files.
+type DefaultWalkGetAttr struct{}
+
+// WalkGetAttr implements File.WalkGetAttr.
+func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) {
+ return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
new file mode 100644
index 000000000..1db5797dd
--- /dev/null
+++ b/pkg/p9/handlers.go
@@ -0,0 +1,1393 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// ExtractErrno extracts a syscall.Errno from a error, best effort.
+func ExtractErrno(err error) syscall.Errno {
+ switch err {
+ case os.ErrNotExist:
+ return syscall.ENOENT
+ case os.ErrExist:
+ return syscall.EEXIST
+ case os.ErrPermission:
+ return syscall.EACCES
+ case os.ErrInvalid:
+ return syscall.EINVAL
+ }
+
+ // Attempt to unwrap.
+ switch e := err.(type) {
+ case syscall.Errno:
+ return e
+ case *os.PathError:
+ return ExtractErrno(e.Err)
+ case *os.SyscallError:
+ return ExtractErrno(e.Err)
+ case *os.LinkError:
+ return ExtractErrno(e.Err)
+ }
+
+ // Default case.
+ log.Warningf("unknown error: %v", err)
+ return syscall.EIO
+}
+
+// newErr returns a new error message from an error.
+func newErr(err error) *Rlerror {
+ return &Rlerror{Error: uint32(ExtractErrno(err))}
+}
+
+// handler is implemented for server-handled messages.
+//
+// See server.go for call information.
+type handler interface {
+ // Handle handles the given message.
+ //
+ // This may modify the server state. The handle function must return a
+ // message which will be sent back to the client. It may be useful to
+ // use newErr to automatically extract an error message.
+ handle(cs *connState) message
+}
+
+// handle implements handler.handle.
+func (t *Tversion) handle(cs *connState) message {
+ if t.MSize == 0 {
+ return newErr(syscall.EINVAL)
+ }
+ if t.MSize > maximumLength {
+ return newErr(syscall.EINVAL)
+ }
+ atomic.StoreUint32(&cs.messageSize, t.MSize)
+ requested, ok := parseVersion(t.Version)
+ if !ok {
+ return newErr(syscall.EINVAL)
+ }
+ // The server cannot support newer versions that it doesn't know about. In this
+ // case we return EAGAIN to tell the client to try again with a lower version.
+ if requested > highestSupportedVersion {
+ return newErr(syscall.EAGAIN)
+ }
+ // From Tversion(9P): "The server may respond with the client’s version
+ // string, or a version string identifying an earlier defined protocol version".
+ atomic.StoreUint32(&cs.version, requested)
+ return &Rversion{
+ MSize: t.MSize,
+ Version: t.Version,
+ }
+}
+
+// handle implements handler.handle.
+func (t *Tflush) handle(cs *connState) message {
+ cs.WaitTag(t.OldTag)
+ return &Rflush{}
+}
+
+// checkSafeName validates the name and returns nil or returns an error.
+func checkSafeName(name string) error {
+ if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." {
+ return nil
+ }
+ return syscall.EINVAL
+}
+
+// handle implements handler.handle.
+func (t *Tclunk) handle(cs *connState) message {
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ return &Rclunk{}
+}
+
+// handle implements handler.handle.
+func (t *Tremove) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Frustratingly, because we can't be guaranteed that a rename is not
+ // occurring simultaneously with this removal, we need to acquire the
+ // global rename lock for this kind of remove operation to ensure that
+ // ref.parent does not change out from underneath us.
+ //
+ // This is why Tremove is a bad idea, and clients should generally use
+ // Tunlinkat. All p9 clients will use Tunlinkat.
+ err := ref.safelyGlobal(func() error {
+ // Is this a root? Can't remove that.
+ if ref.isRoot() {
+ return syscall.EINVAL
+ }
+
+ // N.B. this remove operation is permitted, even if the file is open.
+ // See also rename below for reasoning.
+
+ // Is this file already deleted?
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Retrieve the file's proper name.
+ name := ref.parent.pathNode.nameFor(ref)
+
+ // Attempt the removal.
+ if err := ref.parent.file.UnlinkAt(name, 0); err != nil {
+ return err
+ }
+
+ // Mark all relevant fids as deleted. We don't need to lock any
+ // individual nodes because we already hold the global lock.
+ ref.parent.markChildDeleted(name)
+ return nil
+ })
+
+ // "The remove request asks the file server both to remove the file
+ // represented by fid and to clunk the fid, even if the remove fails."
+ //
+ // "It is correct to consider remove to be a clunk with the side effect
+ // of removing the file if permissions allow."
+ // https://swtch.com/plan9port/man/man9/remove.html
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ if err != nil {
+ return newErr(err)
+ }
+
+ return &Rremove{}
+}
+
+// handle implements handler.handle.
+//
+// We don't support authentication, so this just returns ENOSYS.
+func (t *Tauth) handle(cs *connState) message {
+ return newErr(syscall.ENOSYS)
+}
+
+// handle implements handler.handle.
+func (t *Tattach) handle(cs *connState) message {
+ // Ensure no authentication FID is provided.
+ if t.Auth.AuthenticationFID != NoFID {
+ return newErr(syscall.EINVAL)
+ }
+
+ // Must provide an absolute path.
+ if path.IsAbs(t.Auth.AttachName) {
+ // Trim off the leading / if the path is absolute. We always
+ // treat attach paths as absolute and call attach with the root
+ // argument on the server file for clarity.
+ t.Auth.AttachName = t.Auth.AttachName[1:]
+ }
+
+ // Do the attach on the root.
+ sf, err := cs.server.attacher.Attach()
+ if err != nil {
+ return newErr(err)
+ }
+ qid, valid, attr, err := sf.GetAttr(AttrMaskAll())
+ if err != nil {
+ sf.Close() // Drop file.
+ return newErr(err)
+ }
+ if !valid.Mode {
+ sf.Close() // Drop file.
+ return newErr(syscall.EINVAL)
+ }
+
+ // Build a transient reference.
+ root := &fidRef{
+ server: cs.server,
+ parent: nil,
+ file: sf,
+ refs: 1,
+ mode: attr.Mode.FileType(),
+ pathNode: cs.server.pathTree,
+ }
+ defer root.DecRef()
+
+ // Attach the root?
+ if len(t.Auth.AttachName) == 0 {
+ cs.InsertFID(t.FID, root)
+ return &Rattach{QID: qid}
+ }
+
+ // We want the same traversal checks to apply on attach, so always
+ // attach at the root and use the regular walk paths.
+ names := strings.Split(t.Auth.AttachName, "/")
+ _, newRef, _, _, err := doWalk(cs, root, names, false)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Insert the FID.
+ cs.InsertFID(t.FID, newRef)
+ return &Rattach{QID: qid}
+}
+
+// CanOpen returns whether this file open can be opened, read and written to.
+//
+// This includes everything except symlinks and sockets.
+func CanOpen(mode FileMode) bool {
+ return mode.IsRegular() || mode.IsDir() || mode.IsNamedPipe() || mode.IsBlockDevice() || mode.IsCharacterDevice()
+}
+
+// handle implements handler.handle.
+func (t *Tlopen) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ ref.openedMu.Lock()
+ defer ref.openedMu.Unlock()
+
+ // Has it been opened already?
+ if ref.opened || !CanOpen(ref.mode) {
+ return newErr(syscall.EINVAL)
+ }
+
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return newErr(syscall.EISDIR)
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return newErr(syscall.EISDIR)
+ }
+ }
+
+ var (
+ qid QID
+ ioUnit uint32
+ osFile *fd.FD
+ )
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been deleted already?
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ // Mark file as opened and set open mode.
+ ref.opened = true
+ ref.openFlags = t.Flags
+
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
+}
+
+func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var (
+ osFile *fd.FD
+ nsf File
+ qid QID
+ ioUnit uint32
+ newRef *fidRef
+ )
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow creation from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the create.
+ osFile, nsf, qid, ioUnit, err = ref.file.Create(t.Name, t.OpenFlags, t.Permissions, uid, t.GID)
+ if err != nil {
+ return err
+ }
+
+ newRef = &fidRef{
+ server: cs.server,
+ parent: ref,
+ file: nsf,
+ opened: true,
+ openFlags: t.OpenFlags,
+ mode: ModeRegular,
+ pathNode: ref.pathNode.pathNodeFor(t.Name),
+ }
+ ref.pathNode.addChild(newRef, t.Name)
+ ref.IncRef() // Acquire parent reference.
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
+ // Replace the FID reference.
+ cs.InsertFID(t.FID, newRef)
+
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
+}
+
+// handle implements handler.handle.
+func (t *Tlcreate) handle(cs *connState) message {
+ rlcreate, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rlcreate
+}
+
+// handle implements handler.handle.
+func (t *Tsymlink) handle(cs *connState) message {
+ rsymlink, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rsymlink
+}
+
+func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow symlinks from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the symlink.
+ qid, err = ref.file.Symlink(t.Target, t.Name, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rsymlink{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tlink) handle(cs *connState) message {
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ refTarget, ok := cs.LookupFID(t.Target)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow create links from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the link.
+ return ref.file.Link(refTarget.file, t.Name)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rlink{}
+}
+
+// handle implements handler.handle.
+func (t *Trenameat) handle(cs *connState) message {
+ if err := checkSafeName(t.OldName); err != nil {
+ return newErr(err)
+ }
+ if err := checkSafeName(t.NewName); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.OldDirectory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ refTarget, ok := cs.LookupFID(t.NewDirectory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ // Perform the rename holding the global lock.
+ if err := ref.safelyGlobal(func() (err error) {
+ // Don't allow renaming across deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() || refTarget.isDeleted() || !refTarget.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Is this the same file? If yes, short-circuit and return success.
+ if ref.pathNode == refTarget.pathNode && t.OldName == t.NewName {
+ return nil
+ }
+
+ // Attempt the actual rename.
+ if err := ref.file.RenameAt(t.OldName, refTarget.file, t.NewName); err != nil {
+ return err
+ }
+
+ // Update the path tree.
+ ref.renameChildTo(t.OldName, refTarget, t.NewName)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rrenameat{}
+}
+
+// handle implements handler.handle.
+func (t *Tunlinkat) handle(cs *connState) message {
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow deletion from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Before we do the unlink itself, we need to ensure that there
+ // are no operations in flight on associated path node. The
+ // child's path node lock must be held to ensure that the
+ // unlinkat marking the child deleted below is atomic with
+ // respect to any other read or write operations.
+ //
+ // This is one case where we have a lock ordering issue, but
+ // since we always acquire deeper in the hierarchy, we know
+ // that we are free of lock cycles.
+ childPathNode := ref.pathNode.pathNodeFor(t.Name)
+ childPathNode.opMu.Lock()
+ defer childPathNode.opMu.Unlock()
+
+ // Do the unlink.
+ err = ref.file.UnlinkAt(t.Name, t.Flags)
+ if err != nil {
+ return err
+ }
+
+ // Mark the path as deleted.
+ ref.markChildDeleted(t.Name)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Runlinkat{}
+}
+
+// handle implements handler.handle.
+func (t *Trename) handle(cs *connState) message {
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ refTarget, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ if err := ref.safelyGlobal(func() (err error) {
+ // Don't allow a root rename.
+ if ref.isRoot() {
+ return syscall.EINVAL
+ }
+
+ // Don't allow renaming deleting entries, or target non-directories.
+ if ref.isDeleted() || refTarget.isDeleted() || !refTarget.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // If the parent is deleted, but we not, something is seriously wrong.
+ // It's fail to die at this point with an assertion failure.
+ if ref.parent.isDeleted() {
+ panic(fmt.Sprintf("parent %+v deleted, child %+v is not", ref.parent, ref))
+ }
+
+ // N.B. The rename operation is allowed to proceed on open files. It
+ // does impact the state of its parent, but this is merely a sanity
+ // check in any case, and the operation is safe. There may be other
+ // files corresponding to the same path that are renamed anyways.
+
+ // Check for the exact same file and short-circuit.
+ oldName := ref.parent.pathNode.nameFor(ref)
+ if ref.parent.pathNode == refTarget.pathNode && oldName == t.Name {
+ return nil
+ }
+
+ // Call the rename method on the parent.
+ if err := ref.parent.file.RenameAt(oldName, refTarget.file, t.Name); err != nil {
+ return err
+ }
+
+ // Update the path tree.
+ ref.parent.renameChildTo(oldName, refTarget, t.Name)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rrename{}
+}
+
+// handle implements handler.handle.
+func (t *Treadlink) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var target string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow readlink on deleted files. There is no need to
+ // check if this file is opened because symlinks cannot be
+ // opened.
+ if ref.isDeleted() || !ref.mode.IsSymlink() {
+ return syscall.EINVAL
+ }
+
+ // Do the read.
+ target, err = ref.file.Readlink()
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rreadlink{target}
+}
+
+// handle implements handler.handle.
+func (t *Tread) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Constrain the size of the read buffer.
+ if int(t.Count) > int(maximumLength) {
+ return newErr(syscall.ENOBUFS)
+ }
+
+ var (
+ data = make([]byte, t.Count)
+ n int
+ )
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be read? Check permissions.
+ if openFlags&OpenFlagsModeMask == WriteOnly {
+ return syscall.EPERM
+ }
+
+ n, err = ref.file.ReadAt(data, t.Offset)
+ return err
+ }); err != nil && err != io.EOF {
+ return newErr(err)
+ }
+
+ return &Rread{Data: data[:n]}
+}
+
+// handle implements handler.handle.
+func (t *Twrite) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var n int
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be written? Check permissions.
+ if openFlags&OpenFlagsModeMask == ReadOnly {
+ return syscall.EPERM
+ }
+
+ n, err = ref.file.WriteAt(t.Data, t.Offset)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rwrite{Count: uint32(n)}
+}
+
+// handle implements handler.handle.
+func (t *Tmknod) handle(cs *connState) message {
+ rmknod, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rmknod
+}
+
+func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow mknod on deleted files.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the mknod.
+ qid, err = ref.file.Mknod(t.Name, t.Mode, t.Major, t.Minor, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rmknod{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tmkdir) handle(cs *connState) message {
+ rmkdir, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rmkdir
+}
+
+func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow mkdir on deleted files.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the mkdir.
+ qid, err = ref.file.Mkdir(t.Name, t.Permissions, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rmkdir{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tgetattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We allow getattr on deleted files. Depending on the backing
+ // implementation, it's possible that races exist that might allow
+ // fetching attributes of other files. But we need to generally allow
+ // refreshing attributes and this is a minor leak, if at all.
+
+ var (
+ qid QID
+ valid AttrMask
+ attr Attr
+ )
+ if err := ref.safelyRead(func() (err error) {
+ qid, valid, attr, err = ref.file.GetAttr(t.AttrMask)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rgetattr{QID: qid, Valid: valid, Attr: attr}
+}
+
+// handle implements handler.handle.
+func (t *Tsetattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // We don't allow setattr on files that have been deleted.
+ // This might be technically incorrect, as it's possible that
+ // there were multiple links and you can still change the
+ // corresponding inode information.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Set the attributes.
+ return ref.file.SetAttr(t.Valid, t.SetAttr)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rsetattr{}
+}
+
+// handle implements handler.handle.
+func (t *Tallocate) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be written? Check permissions.
+ if openFlags&OpenFlagsModeMask == ReadOnly {
+ return syscall.EBADF
+ }
+
+ // We don't allow allocate on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ return ref.file.Allocate(t.Mode, t.Offset, t.Length)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rallocate{}
+}
+
+// handle implements handler.handle.
+func (t *Txattrwalk) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We don't support extended attributes.
+ return newErr(syscall.ENODATA)
+}
+
+// handle implements handler.handle.
+func (t *Txattrcreate) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We don't support extended attributes.
+ return newErr(syscall.ENOSYS)
+}
+
+// handle implements handler.handle.
+func (t *Tgetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var val string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow getxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ val, err = ref.file.GetXattr(t.Name, t.Size)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rgetxattr{Value: val}
+}
+
+// handle implements handler.handle.
+func (t *Tsetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Don't allow setxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.SetXattr(t.Name, t.Value, t.Flags)
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rsetxattr{}
+}
+
+// handle implements handler.handle.
+func (t *Tlistxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var xattrs map[string]struct{}
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow listxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ xattrs, err = ref.file.ListXattr(t.Size)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ xattrList := make([]string, 0, len(xattrs))
+ for x := range xattrs {
+ xattrList = append(xattrList, x)
+ }
+ return &Rlistxattr{Xattrs: xattrList}
+}
+
+// handle implements handler.handle.
+func (t *Tremovexattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Don't allow removexattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.RemoveXattr(t.Name)
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rremovexattr{}
+}
+
+// handle implements handler.handle.
+func (t *Treaddir) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var entries []Dirent
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow reading deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); !opened {
+ return syscall.EINVAL
+ }
+
+ // Read the entries.
+ entries, err = ref.file.Readdir(t.Offset, t.Count)
+ if err != nil && err != io.EOF {
+ return err
+ }
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rreaddir{Count: t.Count, Entries: entries}
+}
+
+// handle implements handler.handle.
+func (t *Tfsync) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); !opened {
+ return syscall.EINVAL
+ }
+
+ // Perform the sync.
+ return ref.file.FSync()
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rfsync{}
+}
+
+// handle implements handler.handle.
+func (t *Tstatfs) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ st, err := ref.file.StatFS()
+ if err != nil {
+ return newErr(err)
+ }
+
+ return &Rstatfs{st}
+}
+
+// handle implements handler.handle.
+func (t *Tflushf) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyRead(ref.file.Flush); err != nil {
+ return newErr(err)
+ }
+
+ return &Rflushf{}
+}
+
+// walkOne walks zero or one path elements.
+//
+// The slice passed as qids is append and returned.
+func walkOne(qids []QID, from File, names []string, getattr bool) ([]QID, File, AttrMask, Attr, error) {
+ if len(names) > 1 {
+ // We require exactly zero or one elements.
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+ var (
+ localQIDs []QID
+ sf File
+ valid AttrMask
+ attr Attr
+ err error
+ )
+ switch {
+ case getattr:
+ localQIDs, sf, valid, attr, err = from.WalkGetAttr(names)
+ // Can't put fallthrough in the if because Go.
+ if err != syscall.ENOSYS {
+ break
+ }
+ fallthrough
+ default:
+ localQIDs, sf, err = from.Walk(names)
+ if err != nil {
+ // No way to walk this element.
+ break
+ }
+ if getattr {
+ _, valid, attr, err = sf.GetAttr(AttrMaskAll())
+ if err != nil {
+ // Don't leak the file.
+ sf.Close()
+ }
+ }
+ }
+ if err != nil {
+ // Error walking, don't return anything.
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ if len(localQIDs) != 1 {
+ // Expected a single QID.
+ sf.Close()
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+ return append(qids, localQIDs...), sf, valid, attr, nil
+}
+
+// doWalk walks from a given fidRef.
+//
+// This enforces that all intermediate nodes are walkable (directories). The
+// fidRef returned (newRef) has a reference associated with it that is now
+// owned by the caller and must be handled appropriately.
+func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QID, newRef *fidRef, valid AttrMask, attr Attr, err error) {
+ // Check the names.
+ for _, name := range names {
+ err = checkSafeName(name)
+ if err != nil {
+ return
+ }
+ }
+
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); opened {
+ err = syscall.EBUSY
+ return
+ }
+
+ // Is this an empty list? Handle specially. We don't actually need to
+ // validate anything since this is always permitted.
+ if len(names) == 0 {
+ var sf File // Temporary.
+ if err := ref.maybeParent().safelyRead(func() (err error) {
+ // Clone the single element.
+ qids, sf, valid, attr, err = walkOne(nil, ref.file, nil, getattr)
+ if err != nil {
+ return err
+ }
+
+ newRef = &fidRef{
+ server: cs.server,
+ parent: ref.parent,
+ file: sf,
+ mode: ref.mode,
+ pathNode: ref.pathNode,
+
+ // For the clone case, the cloned fid must
+ // preserve the deleted property of the
+ // original FID.
+ deleted: ref.deleted,
+ }
+ if !ref.isRoot() {
+ if !newRef.isDeleted() {
+ // Add only if a non-root node; the same node.
+ ref.parent.pathNode.addChild(newRef, ref.parent.pathNode.nameFor(ref))
+ }
+ ref.parent.IncRef() // Acquire parent reference.
+ }
+ // doWalk returns a reference.
+ newRef.IncRef()
+ return nil
+ }); err != nil {
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ // Do not return the new QID.
+ return nil, newRef, valid, attr, nil
+ }
+
+ // Do the walk, one element at a time.
+ walkRef := ref
+ walkRef.IncRef()
+ for i := 0; i < len(names); i++ {
+ // We won't allow beyond past symlinks; stop here if this isn't
+ // a proper directory and we have additional paths to walk.
+ if !walkRef.mode.IsDir() {
+ walkRef.DecRef() // Drop walk reference; no lock required.
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+
+ var sf File // Temporary.
+ if err := walkRef.safelyRead(func() (err error) {
+ // Pass getattr = true to walkOne since we need the file type for
+ // newRef.
+ qids, sf, valid, attr, err = walkOne(qids, walkRef.file, names[i:i+1], true)
+ if err != nil {
+ return err
+ }
+
+ // Note that we don't need to acquire a lock on any of
+ // these individual instances. That's because they are
+ // not actually addressable via a FID. They are
+ // anonymous. They exist in the tree for tracking
+ // purposes.
+ newRef := &fidRef{
+ server: cs.server,
+ parent: walkRef,
+ file: sf,
+ mode: attr.Mode.FileType(),
+ pathNode: walkRef.pathNode.pathNodeFor(names[i]),
+ }
+ walkRef.pathNode.addChild(newRef, names[i])
+ // We allow our walk reference to become the new parent
+ // reference here and so we don't IncRef. Instead, just
+ // set walkRef to the newRef above and acquire a new
+ // walk reference.
+ walkRef = newRef
+ walkRef.IncRef()
+ return nil
+ }); err != nil {
+ walkRef.DecRef() // Drop the old walkRef.
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ }
+
+ // Success.
+ return qids, walkRef, valid, attr, nil
+}
+
+// handle implements handler.handle.
+func (t *Twalk) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Do the walk.
+ qids, newRef, _, _, err := doWalk(cs, ref, t.Names, false)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Install the new FID.
+ cs.InsertFID(t.NewFID, newRef)
+ return &Rwalk{QIDs: qids}
+}
+
+// handle implements handler.handle.
+func (t *Twalkgetattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Do the walk.
+ qids, newRef, valid, attr, err := doWalk(cs, ref, t.Names, true)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Install the new FID.
+ cs.InsertFID(t.NewFID, newRef)
+ return &Rwalkgetattr{QIDs: qids, Valid: valid, Attr: attr}
+}
+
+// handle implements handler.handle.
+func (t *Tucreate) handle(cs *connState) message {
+ rlcreate, err := t.Tlcreate.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rucreate{*rlcreate}
+}
+
+// handle implements handler.handle.
+func (t *Tumkdir) handle(cs *connState) message {
+ rmkdir, err := t.Tmkdir.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rumkdir{*rmkdir}
+}
+
+// handle implements handler.handle.
+func (t *Tusymlink) handle(cs *connState) message {
+ rsymlink, err := t.Tsymlink.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rusymlink{*rsymlink}
+}
+
+// handle implements handler.handle.
+func (t *Tumknod) handle(cs *connState) message {
+ rmknod, err := t.Tmknod.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rumknod{*rmknod}
+}
+
+// handle implements handler.handle.
+func (t *Tlconnect) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var osFile *fd.FD
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow connecting to deleted files.
+ if ref.isDeleted() || !ref.mode.IsSocket() {
+ return syscall.EINVAL
+ }
+
+ // Do the connect.
+ osFile, err = ref.file.Connect(t.Flags)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending is a
+ // destructive operation between sendRecvLegacy (and now the newer
+ // channel send operations). Same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
+}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
new file mode 100644
index 000000000..57b89ad7d
--- /dev/null
+++ b/pkg/p9/messages.go
@@ -0,0 +1,2662 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/fd"
+)
+
+// ErrInvalidMsgType is returned when an unsupported message type is found.
+type ErrInvalidMsgType struct {
+ MsgType
+}
+
+// Error returns a useful string.
+func (e *ErrInvalidMsgType) Error() string {
+ return fmt.Sprintf("invalid message type: %d", e.MsgType)
+}
+
+// message is a generic 9P message.
+type message interface {
+ encoder
+ fmt.Stringer
+
+ // Type returns the message type number.
+ Type() MsgType
+}
+
+// payloader is a special message which may include an inline payload.
+type payloader interface {
+ // FixedSize returns the size of the fixed portion of this message.
+ FixedSize() uint32
+
+ // Payload returns the payload for sending.
+ Payload() []byte
+
+ // SetPayload returns the decoded message.
+ //
+ // This is going to be total message size - FixedSize. But this should
+ // be validated during decode, which will be called after SetPayload.
+ SetPayload([]byte)
+}
+
+// filer is a message capable of passing a file.
+type filer interface {
+ // FilePayload returns the file payload.
+ FilePayload() *fd.FD
+
+ // SetFilePayload sets the file payload.
+ SetFilePayload(*fd.FD)
+}
+
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
+// Tversion is a version request.
+type Tversion struct {
+ // MSize is the message size to use.
+ MSize uint32
+
+ // Version is the version string.
+ //
+ // For this implementation, this must be 9P2000.L.
+ Version string
+}
+
+// decode implements encoder.decode.
+func (t *Tversion) decode(b *buffer) {
+ t.MSize = b.Read32()
+ t.Version = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tversion) encode(b *buffer) {
+ b.Write32(t.MSize)
+ b.WriteString(t.Version)
+}
+
+// Type implements message.Type.
+func (*Tversion) Type() MsgType {
+ return MsgTversion
+}
+
+// String implements fmt.Stringer.
+func (t *Tversion) String() string {
+ return fmt.Sprintf("Tversion{MSize: %d, Version: %s}", t.MSize, t.Version)
+}
+
+// Rversion is a version response.
+type Rversion struct {
+ // MSize is the negotiated size.
+ MSize uint32
+
+ // Version is the negotiated version.
+ Version string
+}
+
+// decode implements encoder.decode.
+func (r *Rversion) decode(b *buffer) {
+ r.MSize = b.Read32()
+ r.Version = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rversion) encode(b *buffer) {
+ b.Write32(r.MSize)
+ b.WriteString(r.Version)
+}
+
+// Type implements message.Type.
+func (*Rversion) Type() MsgType {
+ return MsgRversion
+}
+
+// String implements fmt.Stringer.
+func (r *Rversion) String() string {
+ return fmt.Sprintf("Rversion{MSize: %d, Version: %s}", r.MSize, r.Version)
+}
+
+// Tflush is a flush request.
+type Tflush struct {
+ // OldTag is the tag to wait on.
+ OldTag Tag
+}
+
+// decode implements encoder.decode.
+func (t *Tflush) decode(b *buffer) {
+ t.OldTag = b.ReadTag()
+}
+
+// encode implements encoder.encode.
+func (t *Tflush) encode(b *buffer) {
+ b.WriteTag(t.OldTag)
+}
+
+// Type implements message.Type.
+func (*Tflush) Type() MsgType {
+ return MsgTflush
+}
+
+// String implements fmt.Stringer.
+func (t *Tflush) String() string {
+ return fmt.Sprintf("Tflush{OldTag: %d}", t.OldTag)
+}
+
+// Rflush is a flush response.
+type Rflush struct {
+}
+
+// decode implements encoder.decode.
+func (*Rflush) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rflush) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rflush) Type() MsgType {
+ return MsgRflush
+}
+
+// String implements fmt.Stringer.
+func (r *Rflush) String() string {
+ return "RFlush{}"
+}
+
+// Twalk is a walk request.
+type Twalk struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // NewFID is the resulting FID.
+ NewFID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// decode implements encoder.decode.
+func (t *Twalk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (t *Twalk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Twalk) Type() MsgType {
+ return MsgTwalk
+}
+
+// String implements fmt.Stringer.
+func (t *Twalk) String() string {
+ return fmt.Sprintf("Twalk{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names)
+}
+
+// Rwalk is a walk response.
+type Rwalk struct {
+ // QIDs are the set of QIDs returned.
+ QIDs []QID
+}
+
+// decode implements encoder.decode.
+func (r *Rwalk) decode(b *buffer) {
+ n := b.Read16()
+ r.QIDs = r.QIDs[:0]
+ for i := 0; i < int(n); i++ {
+ var q QID
+ q.decode(b)
+ r.QIDs = append(r.QIDs, q)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rwalk) encode(b *buffer) {
+ b.Write16(uint16(len(r.QIDs)))
+ for _, q := range r.QIDs {
+ q.encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rwalk) Type() MsgType {
+ return MsgRwalk
+}
+
+// String implements fmt.Stringer.
+func (r *Rwalk) String() string {
+ return fmt.Sprintf("Rwalk{QIDs: %v}", r.QIDs)
+}
+
+// Tclunk is a close request.
+type Tclunk struct {
+ // FID is the FID to be closed.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tclunk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tclunk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tclunk) Type() MsgType {
+ return MsgTclunk
+}
+
+// String implements fmt.Stringer.
+func (t *Tclunk) String() string {
+ return fmt.Sprintf("Tclunk{FID: %d}", t.FID)
+}
+
+// Rclunk is a close response.
+type Rclunk struct {
+}
+
+// decode implements encoder.decode.
+func (*Rclunk) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rclunk) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rclunk) Type() MsgType {
+ return MsgRclunk
+}
+
+// String implements fmt.Stringer.
+func (r *Rclunk) String() string {
+ return "Rclunk{}"
+}
+
+// Tremove is a remove request.
+//
+// This will eventually be replaced by Tunlinkat.
+type Tremove struct {
+ // FID is the FID to be removed.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tremove) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tremove) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tremove) Type() MsgType {
+ return MsgTremove
+}
+
+// String implements fmt.Stringer.
+func (t *Tremove) String() string {
+ return fmt.Sprintf("Tremove{FID: %d}", t.FID)
+}
+
+// Rremove is a remove response.
+type Rremove struct {
+}
+
+// decode implements encoder.decode.
+func (*Rremove) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rremove) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremove) Type() MsgType {
+ return MsgRremove
+}
+
+// String implements fmt.Stringer.
+func (r *Rremove) String() string {
+ return "Rremove{}"
+}
+
+// Rlerror is an error response.
+//
+// Note that this replaces the error code used in 9p.
+type Rlerror struct {
+ Error uint32
+}
+
+// decode implements encoder.decode.
+func (r *Rlerror) decode(b *buffer) {
+ r.Error = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (r *Rlerror) encode(b *buffer) {
+ b.Write32(r.Error)
+}
+
+// Type implements message.Type.
+func (*Rlerror) Type() MsgType {
+ return MsgRlerror
+}
+
+// String implements fmt.Stringer.
+func (r *Rlerror) String() string {
+ return fmt.Sprintf("Rlerror{Error: %d}", r.Error)
+}
+
+// Tauth is an authentication request.
+type Tauth struct {
+ // AuthenticationFID is the FID to attach the authentication result.
+ AuthenticationFID FID
+
+ // UserName is the user to attach.
+ UserName string
+
+ // AttachName is the attach name.
+ AttachName string
+
+ // UserID is the numeric identifier for UserName.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tauth) decode(b *buffer) {
+ t.AuthenticationFID = b.ReadFID()
+ t.UserName = b.ReadString()
+ t.AttachName = b.ReadString()
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tauth) encode(b *buffer) {
+ b.WriteFID(t.AuthenticationFID)
+ b.WriteString(t.UserName)
+ b.WriteString(t.AttachName)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (*Tauth) Type() MsgType {
+ return MsgTauth
+}
+
+// String implements fmt.Stringer.
+func (t *Tauth) String() string {
+ return fmt.Sprintf("Tauth{AuthFID: %d, UserName: %s, AttachName: %s, UID: %d", t.AuthenticationFID, t.UserName, t.AttachName, t.UID)
+}
+
+// Rauth is an authentication response.
+//
+// encode and decode are inherited directly from QID.
+type Rauth struct {
+ QID
+}
+
+// Type implements message.Type.
+func (*Rauth) Type() MsgType {
+ return MsgRauth
+}
+
+// String implements fmt.Stringer.
+func (r *Rauth) String() string {
+ return fmt.Sprintf("Rauth{QID: %s}", r.QID)
+}
+
+// Tattach is an attach request.
+type Tattach struct {
+ // FID is the FID to be attached.
+ FID FID
+
+ // Auth is the embedded authentication request.
+ //
+ // See client.Attach for information regarding authentication.
+ Auth Tauth
+}
+
+// decode implements encoder.decode.
+func (t *Tattach) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Auth.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tattach) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Auth.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tattach) Type() MsgType {
+ return MsgTattach
+}
+
+// String implements fmt.Stringer.
+func (t *Tattach) String() string {
+ return fmt.Sprintf("Tattach{FID: %d, AuthFID: %d, UserName: %s, AttachName: %s, UID: %d}", t.FID, t.Auth.AuthenticationFID, t.Auth.UserName, t.Auth.AttachName, t.Auth.UID)
+}
+
+// Rattach is an attach response.
+type Rattach struct {
+ QID
+}
+
+// Type implements message.Type.
+func (*Rattach) Type() MsgType {
+ return MsgRattach
+}
+
+// String implements fmt.Stringer.
+func (r *Rattach) String() string {
+ return fmt.Sprintf("Rattach{QID: %s}", r.QID)
+}
+
+// Tlopen is an open request.
+type Tlopen struct {
+ // FID is the FID to be opened.
+ FID FID
+
+ // Flags are the open flags.
+ Flags OpenFlags
+}
+
+// decode implements encoder.decode.
+func (t *Tlopen) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Flags = b.ReadOpenFlags()
+}
+
+// encode implements encoder.encode.
+func (t *Tlopen) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteOpenFlags(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tlopen) Type() MsgType {
+ return MsgTlopen
+}
+
+// String implements fmt.Stringer.
+func (t *Tlopen) String() string {
+ return fmt.Sprintf("Tlopen{FID: %d, Flags: %v}", t.FID, t.Flags)
+}
+
+// Rlopen is a open response.
+type Rlopen struct {
+ // QID is the file's QID.
+ QID QID
+
+ // IoUnit is the recommended I/O unit.
+ IoUnit uint32
+
+ filePayload
+}
+
+// decode implements encoder.decode.
+func (r *Rlopen) decode(b *buffer) {
+ r.QID.decode(b)
+ r.IoUnit = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (r *Rlopen) encode(b *buffer) {
+ r.QID.encode(b)
+ b.Write32(r.IoUnit)
+}
+
+// Type implements message.Type.
+func (*Rlopen) Type() MsgType {
+ return MsgRlopen
+}
+
+// String implements fmt.Stringer.
+func (r *Rlopen) String() string {
+ return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
+}
+
+// Tlcreate is a create request.
+type Tlcreate struct {
+ // FID is the parent FID.
+ //
+ // This becomes the new file.
+ FID FID
+
+ // Name is the file name to create.
+ Name string
+
+ // Mode is the open mode (O_RDWR, etc.).
+ //
+ // Note that flags like O_TRUNC are ignored, as is O_EXCL. All
+ // create operations are exclusive.
+ OpenFlags OpenFlags
+
+ // Permissions is the set of permission bits.
+ Permissions FileMode
+
+ // GID is the group ID to use for creating the file.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tlcreate) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.OpenFlags = b.ReadOpenFlags()
+ t.Permissions = b.ReadPermissions()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tlcreate) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteOpenFlags(t.OpenFlags)
+ b.WritePermissions(t.Permissions)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tlcreate) Type() MsgType {
+ return MsgTlcreate
+}
+
+// String implements fmt.Stringer.
+func (t *Tlcreate) String() string {
+ return fmt.Sprintf("Tlcreate{FID: %d, Name: %s, OpenFlags: %s, Permissions: 0o%o, GID: %d}", t.FID, t.Name, t.OpenFlags, t.Permissions, t.GID)
+}
+
+// Rlcreate is a create response.
+//
+// The encode, decode, etc. methods are inherited from Rlopen.
+type Rlcreate struct {
+ Rlopen
+}
+
+// Type implements message.Type.
+func (*Rlcreate) Type() MsgType {
+ return MsgRlcreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rlcreate) String() string {
+ return fmt.Sprintf("Rlcreate{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
+}
+
+// Tsymlink is a symlink request.
+type Tsymlink struct {
+ // Directory is the directory FID.
+ Directory FID
+
+ // Name is the new in the directory.
+ Name string
+
+ // Target is the symlink target.
+ Target string
+
+ // GID is the owning group.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tsymlink) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Target = b.ReadString()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tsymlink) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WriteString(t.Target)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tsymlink) Type() MsgType {
+ return MsgTsymlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tsymlink) String() string {
+ return fmt.Sprintf("Tsymlink{DirectoryFID: %d, Name: %s, Target: %s, GID: %d}", t.Directory, t.Name, t.Target, t.GID)
+}
+
+// Rsymlink is a symlink response.
+type Rsymlink struct {
+ // QID is the new symlink's QID.
+ QID QID
+}
+
+// decode implements encoder.decode.
+func (r *Rsymlink) decode(b *buffer) {
+ r.QID.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rsymlink) encode(b *buffer) {
+ r.QID.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rsymlink) Type() MsgType {
+ return MsgRsymlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rsymlink) String() string {
+ return fmt.Sprintf("Rsymlink{QID: %s}", r.QID)
+}
+
+// Tlink is a link request.
+type Tlink struct {
+ // Directory is the directory to contain the link.
+ Directory FID
+
+ // FID is the target.
+ Target FID
+
+ // Name is the new source name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Tlink) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Target = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tlink) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteFID(t.Target)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tlink) Type() MsgType {
+ return MsgTlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tlink) String() string {
+ return fmt.Sprintf("Tlink{DirectoryFID: %d, TargetFID: %d, Name: %s}", t.Directory, t.Target, t.Name)
+}
+
+// Rlink is a link response.
+type Rlink struct {
+}
+
+// Type implements message.Type.
+func (*Rlink) Type() MsgType {
+ return MsgRlink
+}
+
+// decode implements encoder.decode.
+func (*Rlink) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rlink) encode(*buffer) {
+}
+
+// String implements fmt.Stringer.
+func (r *Rlink) String() string {
+ return "Rlink{}"
+}
+
+// Trenameat is a rename request.
+type Trenameat struct {
+ // OldDirectory is the source directory.
+ OldDirectory FID
+
+ // OldName is the source file name.
+ OldName string
+
+ // NewDirectory is the target directory.
+ NewDirectory FID
+
+ // NewName is the new file name.
+ NewName string
+}
+
+// decode implements encoder.decode.
+func (t *Trenameat) decode(b *buffer) {
+ t.OldDirectory = b.ReadFID()
+ t.OldName = b.ReadString()
+ t.NewDirectory = b.ReadFID()
+ t.NewName = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Trenameat) encode(b *buffer) {
+ b.WriteFID(t.OldDirectory)
+ b.WriteString(t.OldName)
+ b.WriteFID(t.NewDirectory)
+ b.WriteString(t.NewName)
+}
+
+// Type implements message.Type.
+func (*Trenameat) Type() MsgType {
+ return MsgTrenameat
+}
+
+// String implements fmt.Stringer.
+func (t *Trenameat) String() string {
+ return fmt.Sprintf("TrenameAt{OldDirectoryFID: %d, OldName: %s, NewDirectoryFID: %d, NewName: %s}", t.OldDirectory, t.OldName, t.NewDirectory, t.NewName)
+}
+
+// Rrenameat is a rename response.
+type Rrenameat struct {
+}
+
+// decode implements encoder.decode.
+func (*Rrenameat) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rrenameat) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rrenameat) Type() MsgType {
+ return MsgRrenameat
+}
+
+// String implements fmt.Stringer.
+func (r *Rrenameat) String() string {
+ return "Rrenameat{}"
+}
+
+// Tunlinkat is an unlink request.
+type Tunlinkat struct {
+ // Directory is the originating directory.
+ Directory FID
+
+ // Name is the name of the entry to unlink.
+ Name string
+
+ // Flags are extra flags (e.g. O_DIRECTORY). These are not interpreted by p9.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tunlinkat) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tunlinkat) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tunlinkat) Type() MsgType {
+ return MsgTunlinkat
+}
+
+// String implements fmt.Stringer.
+func (t *Tunlinkat) String() string {
+ return fmt.Sprintf("Tunlinkat{DirectoryFID: %d, Name: %s, Flags: 0x%X}", t.Directory, t.Name, t.Flags)
+}
+
+// Runlinkat is an unlink response.
+type Runlinkat struct {
+}
+
+// decode implements encoder.decode.
+func (*Runlinkat) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Runlinkat) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Runlinkat) Type() MsgType {
+ return MsgRunlinkat
+}
+
+// String implements fmt.Stringer.
+func (r *Runlinkat) String() string {
+ return "Runlinkat{}"
+}
+
+// Trename is a rename request.
+//
+// Note that this generally isn't used anymore, and ideally all rename calls
+// should Trenameat below.
+type Trename struct {
+ // FID is the FID to rename.
+ FID FID
+
+ // Directory is the target directory.
+ Directory FID
+
+ // Name is the new file name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Trename) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Trename) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Trename) Type() MsgType {
+ return MsgTrename
+}
+
+// String implements fmt.Stringer.
+func (t *Trename) String() string {
+ return fmt.Sprintf("Trename{FID: %d, DirectoryFID: %d, Name: %s}", t.FID, t.Directory, t.Name)
+}
+
+// Rrename is a rename response.
+type Rrename struct {
+}
+
+// decode implements encoder.decode.
+func (*Rrename) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rrename) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rrename) Type() MsgType {
+ return MsgRrename
+}
+
+// String implements fmt.Stringer.
+func (r *Rrename) String() string {
+ return "Rrename{}"
+}
+
+// Treadlink is a readlink request.
+type Treadlink struct {
+ // FID is the symlink.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Treadlink) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Treadlink) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Treadlink) Type() MsgType {
+ return MsgTreadlink
+}
+
+// String implements fmt.Stringer.
+func (t *Treadlink) String() string {
+ return fmt.Sprintf("Treadlink{FID: %d}", t.FID)
+}
+
+// Rreadlink is a readlink response.
+type Rreadlink struct {
+ // Target is the symlink target.
+ Target string
+}
+
+// decode implements encoder.decode.
+func (r *Rreadlink) decode(b *buffer) {
+ r.Target = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rreadlink) encode(b *buffer) {
+ b.WriteString(r.Target)
+}
+
+// Type implements message.Type.
+func (*Rreadlink) Type() MsgType {
+ return MsgRreadlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rreadlink) String() string {
+ return fmt.Sprintf("Rreadlink{Target: %s}", r.Target)
+}
+
+// Tread is a read request.
+type Tread struct {
+ // FID is the FID to read.
+ FID FID
+
+ // Offset indicates the file offset.
+ Offset uint64
+
+ // Count indicates the number of bytes to read.
+ Count uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tread) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Offset = b.Read64()
+ t.Count = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tread) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Offset)
+ b.Write32(t.Count)
+}
+
+// Type implements message.Type.
+func (*Tread) Type() MsgType {
+ return MsgTread
+}
+
+// String implements fmt.Stringer.
+func (t *Tread) String() string {
+ return fmt.Sprintf("Tread{FID: %d, Offset: %d, Count: %d}", t.FID, t.Offset, t.Count)
+}
+
+// Rread is the response for a Tread.
+type Rread struct {
+ // Data is the resulting data.
+ Data []byte
+}
+
+// decode implements encoder.decode.
+//
+// Data is automatically decoded via Payload.
+func (r *Rread) decode(b *buffer) {
+ count := b.Read32()
+ if count != uint32(len(r.Data)) {
+ b.markOverrun()
+ }
+}
+
+// encode implements encoder.encode.
+//
+// Data is automatically encoded via Payload.
+func (r *Rread) encode(b *buffer) {
+ b.Write32(uint32(len(r.Data)))
+}
+
+// Type implements message.Type.
+func (*Rread) Type() MsgType {
+ return MsgRread
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Rread) FixedSize() uint32 {
+ return 4
+}
+
+// Payload implements payloader.Payload.
+func (r *Rread) Payload() []byte {
+ return r.Data
+}
+
+// SetPayload implements payloader.SetPayload.
+func (r *Rread) SetPayload(p []byte) {
+ r.Data = p
+}
+
+// String implements fmt.Stringer.
+func (r *Rread) String() string {
+ return fmt.Sprintf("Rread{len(Data): %d}", len(r.Data))
+}
+
+// Twrite is a write request.
+type Twrite struct {
+ // FID is the FID to read.
+ FID FID
+
+ // Offset indicates the file offset.
+ Offset uint64
+
+ // Data is the data to be written.
+ Data []byte
+}
+
+// decode implements encoder.decode.
+func (t *Twrite) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Offset = b.Read64()
+ count := b.Read32()
+ if count != uint32(len(t.Data)) {
+ b.markOverrun()
+ }
+}
+
+// encode implements encoder.encode.
+//
+// This uses the buffer payload to avoid a copy.
+func (t *Twrite) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Offset)
+ b.Write32(uint32(len(t.Data)))
+}
+
+// Type implements message.Type.
+func (*Twrite) Type() MsgType {
+ return MsgTwrite
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Twrite) FixedSize() uint32 {
+ return 16
+}
+
+// Payload implements payloader.Payload.
+func (t *Twrite) Payload() []byte {
+ return t.Data
+}
+
+// SetPayload implements payloader.SetPayload.
+func (t *Twrite) SetPayload(p []byte) {
+ t.Data = p
+}
+
+// String implements fmt.Stringer.
+func (t *Twrite) String() string {
+ return fmt.Sprintf("Twrite{FID: %v, Offset %d, len(Data): %d}", t.FID, t.Offset, len(t.Data))
+}
+
+// Rwrite is the response for a Twrite.
+type Rwrite struct {
+ // Count indicates the number of bytes successfully written.
+ Count uint32
+}
+
+// decode implements encoder.decode.
+func (r *Rwrite) decode(b *buffer) {
+ r.Count = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (r *Rwrite) encode(b *buffer) {
+ b.Write32(r.Count)
+}
+
+// Type implements message.Type.
+func (*Rwrite) Type() MsgType {
+ return MsgRwrite
+}
+
+// String implements fmt.Stringer.
+func (r *Rwrite) String() string {
+ return fmt.Sprintf("Rwrite{Count: %d}", r.Count)
+}
+
+// Tmknod is a mknod request.
+type Tmknod struct {
+ // Directory is the parent directory.
+ Directory FID
+
+ // Name is the device name.
+ Name string
+
+ // Mode is the device mode and permissions.
+ Mode FileMode
+
+ // Major is the device major number.
+ Major uint32
+
+ // Minor is the device minor number.
+ Minor uint32
+
+ // GID is the device GID.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tmknod) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Mode = b.ReadFileMode()
+ t.Major = b.Read32()
+ t.Minor = b.Read32()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tmknod) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WriteFileMode(t.Mode)
+ b.Write32(t.Major)
+ b.Write32(t.Minor)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tmknod) Type() MsgType {
+ return MsgTmknod
+}
+
+// String implements fmt.Stringer.
+func (t *Tmknod) String() string {
+ return fmt.Sprintf("Tmknod{DirectoryFID: %d, Name: %s, Mode: 0o%o, Major: %d, Minor: %d, GID: %d}", t.Directory, t.Name, t.Mode, t.Major, t.Minor, t.GID)
+}
+
+// Rmknod is a mknod response.
+type Rmknod struct {
+ // QID is the resulting QID.
+ QID QID
+}
+
+// decode implements encoder.decode.
+func (r *Rmknod) decode(b *buffer) {
+ r.QID.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rmknod) encode(b *buffer) {
+ r.QID.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rmknod) Type() MsgType {
+ return MsgRmknod
+}
+
+// String implements fmt.Stringer.
+func (r *Rmknod) String() string {
+ return fmt.Sprintf("Rmknod{QID: %s}", r.QID)
+}
+
+// Tmkdir is a mkdir request.
+type Tmkdir struct {
+ // Directory is the parent directory.
+ Directory FID
+
+ // Name is the new directory name.
+ Name string
+
+ // Permissions is the set of permission bits.
+ Permissions FileMode
+
+ // GID is the owning group.
+ GID GID
+}
+
+// decode implements encoder.decode.
+func (t *Tmkdir) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Permissions = b.ReadPermissions()
+ t.GID = b.ReadGID()
+}
+
+// encode implements encoder.encode.
+func (t *Tmkdir) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WritePermissions(t.Permissions)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tmkdir) Type() MsgType {
+ return MsgTmkdir
+}
+
+// String implements fmt.Stringer.
+func (t *Tmkdir) String() string {
+ return fmt.Sprintf("Tmkdir{DirectoryFID: %d, Name: %s, Permissions: 0o%o, GID: %d}", t.Directory, t.Name, t.Permissions, t.GID)
+}
+
+// Rmkdir is a mkdir response.
+type Rmkdir struct {
+ // QID is the resulting QID.
+ QID QID
+}
+
+// decode implements encoder.decode.
+func (r *Rmkdir) decode(b *buffer) {
+ r.QID.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rmkdir) encode(b *buffer) {
+ r.QID.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rmkdir) Type() MsgType {
+ return MsgRmkdir
+}
+
+// String implements fmt.Stringer.
+func (r *Rmkdir) String() string {
+ return fmt.Sprintf("Rmkdir{QID: %s}", r.QID)
+}
+
+// Tgetattr is a getattr request.
+type Tgetattr struct {
+ // FID is the FID to get attributes for.
+ FID FID
+
+ // AttrMask is the set of attributes to get.
+ AttrMask AttrMask
+}
+
+// decode implements encoder.decode.
+func (t *Tgetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.AttrMask.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tgetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.AttrMask.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tgetattr) Type() MsgType {
+ return MsgTgetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetattr) String() string {
+ return fmt.Sprintf("Tgetattr{FID: %d, AttrMask: %s}", t.FID, t.AttrMask)
+}
+
+// Rgetattr is a getattr response.
+type Rgetattr struct {
+ // Valid indicates which fields are valid.
+ Valid AttrMask
+
+ // QID is the QID for this file.
+ QID
+
+ // Attr is the set of attributes.
+ Attr Attr
+}
+
+// decode implements encoder.decode.
+func (r *Rgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.QID.decode(b)
+ r.Attr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.QID.encode(b)
+ r.Attr.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rgetattr) Type() MsgType {
+ return MsgRgetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetattr) String() string {
+ return fmt.Sprintf("Rgetattr{Valid: %v, QID: %s, Attr: %s}", r.Valid, r.QID, r.Attr)
+}
+
+// Tsetattr is a setattr request.
+type Tsetattr struct {
+ // FID is the FID to change.
+ FID FID
+
+ // Valid is the set of bits which will be used.
+ Valid SetAttrMask
+
+ // SetAttr is the set request.
+ SetAttr SetAttr
+}
+
+// decode implements encoder.decode.
+func (t *Tsetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Valid.decode(b)
+ t.SetAttr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tsetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Valid.encode(b)
+ t.SetAttr.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tsetattr) Type() MsgType {
+ return MsgTsetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetattr) String() string {
+ return fmt.Sprintf("Tsetattr{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr)
+}
+
+// Rsetattr is a setattr response.
+type Rsetattr struct {
+}
+
+// decode implements encoder.decode.
+func (*Rsetattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rsetattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetattr) Type() MsgType {
+ return MsgRsetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetattr) String() string {
+ return "Rsetattr{}"
+}
+
+// Tallocate is an allocate request. This is an extension to 9P protocol, not
+// present in the 9P2000.L standard.
+type Tallocate struct {
+ FID FID
+ Mode AllocateMode
+ Offset uint64
+ Length uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tallocate) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Mode.decode(b)
+ t.Offset = b.Read64()
+ t.Length = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tallocate) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Mode.encode(b)
+ b.Write64(t.Offset)
+ b.Write64(t.Length)
+}
+
+// Type implements message.Type.
+func (*Tallocate) Type() MsgType {
+ return MsgTallocate
+}
+
+// String implements fmt.Stringer.
+func (t *Tallocate) String() string {
+ return fmt.Sprintf("Tallocate{FID: %d, Offset: %d, Length: %d}", t.FID, t.Offset, t.Length)
+}
+
+// Rallocate is an allocate response.
+type Rallocate struct {
+}
+
+// decode implements encoder.decode.
+func (*Rallocate) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rallocate) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rallocate) Type() MsgType {
+ return MsgRallocate
+}
+
+// String implements fmt.Stringer.
+func (r *Rallocate) String() string {
+ return "Rallocate{}"
+}
+
+// Tlistxattr is a listxattr request.
+type Tlistxattr struct {
+ // FID refers to the file on which to list xattrs.
+ FID FID
+
+ // Size is the buffer size for the xattr list.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tlistxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tlistxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tlistxattr) Type() MsgType {
+ return MsgTlistxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tlistxattr) String() string {
+ return fmt.Sprintf("Tlistxattr{FID: %d, Size: %d}", t.FID, t.Size)
+}
+
+// Rlistxattr is a listxattr response.
+type Rlistxattr struct {
+ // Xattrs is a list of extended attribute names.
+ Xattrs []string
+}
+
+// decode implements encoder.decode.
+func (r *Rlistxattr) decode(b *buffer) {
+ n := b.Read16()
+ r.Xattrs = r.Xattrs[:0]
+ for i := 0; i < int(n); i++ {
+ r.Xattrs = append(r.Xattrs, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rlistxattr) encode(b *buffer) {
+ b.Write16(uint16(len(r.Xattrs)))
+ for _, x := range r.Xattrs {
+ b.WriteString(x)
+ }
+}
+
+// Type implements message.Type.
+func (*Rlistxattr) Type() MsgType {
+ return MsgRlistxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rlistxattr) String() string {
+ return fmt.Sprintf("Rlistxattr{Xattrs: %v}", r.Xattrs)
+}
+
+// Txattrwalk walks extended attributes.
+type Txattrwalk struct {
+ // FID is the FID to check for attributes.
+ FID FID
+
+ // NewFID is the new FID associated with the attributes.
+ NewFID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Txattrwalk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Txattrwalk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Txattrwalk) Type() MsgType {
+ return MsgTxattrwalk
+}
+
+// String implements fmt.Stringer.
+func (t *Txattrwalk) String() string {
+ return fmt.Sprintf("Txattrwalk{FID: %d, NewFID: %d, Name: %s}", t.FID, t.NewFID, t.Name)
+}
+
+// Rxattrwalk is a xattrwalk response.
+type Rxattrwalk struct {
+ // Size is the size of the extended attribute.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (r *Rxattrwalk) decode(b *buffer) {
+ r.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (r *Rxattrwalk) encode(b *buffer) {
+ b.Write64(r.Size)
+}
+
+// Type implements message.Type.
+func (*Rxattrwalk) Type() MsgType {
+ return MsgRxattrwalk
+}
+
+// String implements fmt.Stringer.
+func (r *Rxattrwalk) String() string {
+ return fmt.Sprintf("Rxattrwalk{Size: %d}", r.Size)
+}
+
+// Txattrcreate prepare to set extended attributes.
+type Txattrcreate struct {
+ // FID is input/output parameter, it identifies the file on which
+ // extended attributes will be set but after successful Rxattrcreate
+ // it is used to write the extended attribute value.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Size of the attribute value. When the FID is clunked it has to match
+ // the number of bytes written to the FID.
+ AttrSize uint64
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Txattrcreate) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.AttrSize = b.Read64()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Txattrcreate) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.AttrSize)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Txattrcreate) Type() MsgType {
+ return MsgTxattrcreate
+}
+
+// String implements fmt.Stringer.
+func (t *Txattrcreate) String() string {
+ return fmt.Sprintf("Txattrcreate{FID: %d, Name: %s, AttrSize: %d, Flags: %d}", t.FID, t.Name, t.AttrSize, t.Flags)
+}
+
+// Rxattrcreate is a xattrcreate response.
+type Rxattrcreate struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rxattrcreate) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rxattrcreate) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rxattrcreate) Type() MsgType {
+ return MsgRxattrcreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rxattrcreate) String() string {
+ return "Rxattrcreate{}"
+}
+
+// Tgetxattr is a getxattr request.
+type Tgetxattr struct {
+ // FID refers to the file for which to get xattrs.
+ FID FID
+
+ // Name is the xattr to get.
+ Name string
+
+ // Size is the buffer size for the xattr to get.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tgetxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tgetxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tgetxattr) Type() MsgType {
+ return MsgTgetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetxattr) String() string {
+ return fmt.Sprintf("Tgetxattr{FID: %d, Name: %s, Size: %d}", t.FID, t.Name, t.Size)
+}
+
+// Rgetxattr is a getxattr response.
+type Rgetxattr struct {
+ // Value is the extended attribute value.
+ Value string
+}
+
+// decode implements encoder.decode.
+func (r *Rgetxattr) decode(b *buffer) {
+ r.Value = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rgetxattr) encode(b *buffer) {
+ b.WriteString(r.Value)
+}
+
+// Type implements message.Type.
+func (*Rgetxattr) Type() MsgType {
+ return MsgRgetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetxattr) String() string {
+ return fmt.Sprintf("Rgetxattr{Value: %s}", r.Value)
+}
+
+// Tsetxattr sets extended attributes.
+type Tsetxattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Value is the attribute value.
+ Value string
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tsetxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Value = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tsetxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteString(t.Value)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tsetxattr) Type() MsgType {
+ return MsgTsetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetxattr) String() string {
+ return fmt.Sprintf("Tsetxattr{FID: %d, Name: %s, Value: %s, Flags: %d}", t.FID, t.Name, t.Value, t.Flags)
+}
+
+// Rsetxattr is a setxattr response.
+type Rsetxattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rsetxattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rsetxattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetxattr) Type() MsgType {
+ return MsgRsetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetxattr) String() string {
+ return "Rsetxattr{}"
+}
+
+// Tremovexattr is a removexattr request.
+type Tremovexattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Tremovexattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tremovexattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tremovexattr) Type() MsgType {
+ return MsgTremovexattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tremovexattr) String() string {
+ return fmt.Sprintf("Tremovexattr{FID: %d, Name: %s}", t.FID, t.Name)
+}
+
+// Rremovexattr is a removexattr response.
+type Rremovexattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rremovexattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rremovexattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremovexattr) Type() MsgType {
+ return MsgRremovexattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rremovexattr) String() string {
+ return "Rremovexattr{}"
+}
+
+// Treaddir is a readdir request.
+type Treaddir struct {
+ // Directory is the directory FID to read.
+ Directory FID
+
+ // Offset is the offset to read at.
+ Offset uint64
+
+ // Count is the number of bytes to read.
+ Count uint32
+}
+
+// decode implements encoder.decode.
+func (t *Treaddir) decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Offset = b.Read64()
+ t.Count = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Treaddir) encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.Write64(t.Offset)
+ b.Write32(t.Count)
+}
+
+// Type implements message.Type.
+func (*Treaddir) Type() MsgType {
+ return MsgTreaddir
+}
+
+// String implements fmt.Stringer.
+func (t *Treaddir) String() string {
+ return fmt.Sprintf("Treaddir{DirectoryFID: %d, Offset: %d, Count: %d}", t.Directory, t.Offset, t.Count)
+}
+
+// Rreaddir is a readdir response.
+type Rreaddir struct {
+ // Count is the byte limit.
+ //
+ // This should always be set from the Treaddir request.
+ Count uint32
+
+ // Entries are the resulting entries.
+ //
+ // This may be constructed in decode.
+ Entries []Dirent
+
+ // payload is the encoded payload.
+ //
+ // This is constructed by encode.
+ payload []byte
+}
+
+// decode implements encoder.decode.
+func (r *Rreaddir) decode(b *buffer) {
+ r.Count = b.Read32()
+ entriesBuf := buffer{data: r.payload}
+ r.Entries = r.Entries[:0]
+ for {
+ var d Dirent
+ d.decode(&entriesBuf)
+ if entriesBuf.isOverrun() {
+ // Couldn't decode a complete entry.
+ break
+ }
+ r.Entries = append(r.Entries, d)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rreaddir) encode(b *buffer) {
+ entriesBuf := buffer{}
+ payloadSize := 0
+ for _, d := range r.Entries {
+ d.encode(&entriesBuf)
+ if len(entriesBuf.data) > int(r.Count) {
+ break
+ }
+ payloadSize = len(entriesBuf.data)
+ }
+ r.Count = uint32(payloadSize)
+ r.payload = entriesBuf.data[:payloadSize]
+ b.Write32(r.Count)
+}
+
+// Type implements message.Type.
+func (*Rreaddir) Type() MsgType {
+ return MsgRreaddir
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Rreaddir) FixedSize() uint32 {
+ return 4
+}
+
+// Payload implements payloader.Payload.
+func (r *Rreaddir) Payload() []byte {
+ return r.payload
+}
+
+// SetPayload implements payloader.SetPayload.
+func (r *Rreaddir) SetPayload(p []byte) {
+ r.payload = p
+}
+
+// String implements fmt.Stringer.
+func (r *Rreaddir) String() string {
+ return fmt.Sprintf("Rreaddir{Count: %d, Entries: %s}", r.Count, r.Entries)
+}
+
+// Tfsync is an fsync request.
+type Tfsync struct {
+ // FID is the fid to sync.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tfsync) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tfsync) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tfsync) Type() MsgType {
+ return MsgTfsync
+}
+
+// String implements fmt.Stringer.
+func (t *Tfsync) String() string {
+ return fmt.Sprintf("Tfsync{FID: %d}", t.FID)
+}
+
+// Rfsync is an fsync response.
+type Rfsync struct {
+}
+
+// decode implements encoder.decode.
+func (*Rfsync) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rfsync) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rfsync) Type() MsgType {
+ return MsgRfsync
+}
+
+// String implements fmt.Stringer.
+func (r *Rfsync) String() string {
+ return "Rfsync{}"
+}
+
+// Tstatfs is a stat request.
+type Tstatfs struct {
+ // FID is the root.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tstatfs) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tstatfs) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tstatfs) Type() MsgType {
+ return MsgTstatfs
+}
+
+// String implements fmt.Stringer.
+func (t *Tstatfs) String() string {
+ return fmt.Sprintf("Tstatfs{FID: %d}", t.FID)
+}
+
+// Rstatfs is the response for a Tstatfs.
+type Rstatfs struct {
+ // FSStat is the stat result.
+ FSStat FSStat
+}
+
+// decode implements encoder.decode.
+func (r *Rstatfs) decode(b *buffer) {
+ r.FSStat.decode(b)
+}
+
+// encode implements encoder.encode.
+func (r *Rstatfs) encode(b *buffer) {
+ r.FSStat.encode(b)
+}
+
+// Type implements message.Type.
+func (*Rstatfs) Type() MsgType {
+ return MsgRstatfs
+}
+
+// String implements fmt.Stringer.
+func (r *Rstatfs) String() string {
+ return fmt.Sprintf("Rstatfs{FSStat: %v}", r.FSStat)
+}
+
+// Tflushf is a flush file request, not to be confused with Tflush.
+type Tflushf struct {
+ // FID is the FID to be flushed.
+ FID FID
+}
+
+// decode implements encoder.decode.
+func (t *Tflushf) decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// encode implements encoder.encode.
+func (t *Tflushf) encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tflushf) Type() MsgType {
+ return MsgTflushf
+}
+
+// String implements fmt.Stringer.
+func (t *Tflushf) String() string {
+ return fmt.Sprintf("Tflushf{FID: %d}", t.FID)
+}
+
+// Rflushf is a flush file response.
+type Rflushf struct {
+}
+
+// decode implements encoder.decode.
+func (*Rflushf) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rflushf) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rflushf) Type() MsgType {
+ return MsgRflushf
+}
+
+// String implements fmt.Stringer.
+func (*Rflushf) String() string {
+ return "Rflushf{}"
+}
+
+// Twalkgetattr is a walk request.
+type Twalkgetattr struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // NewFID is the resulting FID.
+ NewFID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// decode implements encoder.decode.
+func (t *Twalkgetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (t *Twalkgetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Twalkgetattr) Type() MsgType {
+ return MsgTwalkgetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Twalkgetattr) String() string {
+ return fmt.Sprintf("Twalkgetattr{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names)
+}
+
+// Rwalkgetattr is a walk response.
+type Rwalkgetattr struct {
+ // Valid indicates which fields are valid in the Attr below.
+ Valid AttrMask
+
+ // Attr is the set of attributes for the last QID (the file walked to).
+ Attr Attr
+
+ // QIDs are the set of QIDs returned.
+ QIDs []QID
+}
+
+// decode implements encoder.decode.
+func (r *Rwalkgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.Attr.decode(b)
+ n := b.Read16()
+ r.QIDs = r.QIDs[:0]
+ for i := 0; i < int(n); i++ {
+ var q QID
+ q.decode(b)
+ r.QIDs = append(r.QIDs, q)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rwalkgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.Attr.encode(b)
+ b.Write16(uint16(len(r.QIDs)))
+ for _, q := range r.QIDs {
+ q.encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rwalkgetattr) Type() MsgType {
+ return MsgRwalkgetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rwalkgetattr) String() string {
+ return fmt.Sprintf("Rwalkgetattr{Valid: %s, Attr: %s, QIDs: %v}", r.Valid, r.Attr, r.QIDs)
+}
+
+// Tucreate is a Tlcreate message that includes a UID.
+type Tucreate struct {
+ Tlcreate
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tucreate) decode(b *buffer) {
+ t.Tlcreate.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tucreate) encode(b *buffer) {
+ t.Tlcreate.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tucreate) Type() MsgType {
+ return MsgTucreate
+}
+
+// String implements fmt.Stringer.
+func (t *Tucreate) String() string {
+ return fmt.Sprintf("Tucreate{Tlcreate: %v, UID: %d}", &t.Tlcreate, t.UID)
+}
+
+// Rucreate is a file creation response.
+type Rucreate struct {
+ Rlcreate
+}
+
+// Type implements message.Type.
+func (*Rucreate) Type() MsgType {
+ return MsgRucreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rucreate) String() string {
+ return fmt.Sprintf("Rucreate{%v}", &r.Rlcreate)
+}
+
+// Tumkdir is a Tmkdir message that includes a UID.
+type Tumkdir struct {
+ Tmkdir
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tumkdir) decode(b *buffer) {
+ t.Tmkdir.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tumkdir) encode(b *buffer) {
+ t.Tmkdir.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tumkdir) Type() MsgType {
+ return MsgTumkdir
+}
+
+// String implements fmt.Stringer.
+func (t *Tumkdir) String() string {
+ return fmt.Sprintf("Tumkdir{Tmkdir: %v, UID: %d}", &t.Tmkdir, t.UID)
+}
+
+// Rumkdir is a umkdir response.
+type Rumkdir struct {
+ Rmkdir
+}
+
+// Type implements message.Type.
+func (*Rumkdir) Type() MsgType {
+ return MsgRumkdir
+}
+
+// String implements fmt.Stringer.
+func (r *Rumkdir) String() string {
+ return fmt.Sprintf("Rumkdir{%v}", &r.Rmkdir)
+}
+
+// Tumknod is a Tmknod message that includes a UID.
+type Tumknod struct {
+ Tmknod
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tumknod) decode(b *buffer) {
+ t.Tmknod.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tumknod) encode(b *buffer) {
+ t.Tmknod.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tumknod) Type() MsgType {
+ return MsgTumknod
+}
+
+// String implements fmt.Stringer.
+func (t *Tumknod) String() string {
+ return fmt.Sprintf("Tumknod{Tmknod: %v, UID: %d}", &t.Tmknod, t.UID)
+}
+
+// Rumknod is a umknod response.
+type Rumknod struct {
+ Rmknod
+}
+
+// Type implements message.Type.
+func (*Rumknod) Type() MsgType {
+ return MsgRumknod
+}
+
+// String implements fmt.Stringer.
+func (r *Rumknod) String() string {
+ return fmt.Sprintf("Rumknod{%v}", &r.Rmknod)
+}
+
+// Tusymlink is a Tsymlink message that includes a UID.
+type Tusymlink struct {
+ Tsymlink
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// decode implements encoder.decode.
+func (t *Tusymlink) decode(b *buffer) {
+ t.Tsymlink.decode(b)
+ t.UID = b.ReadUID()
+}
+
+// encode implements encoder.encode.
+func (t *Tusymlink) encode(b *buffer) {
+ t.Tsymlink.encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tusymlink) Type() MsgType {
+ return MsgTusymlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tusymlink) String() string {
+ return fmt.Sprintf("Tusymlink{Tsymlink: %v, UID: %d}", &t.Tsymlink, t.UID)
+}
+
+// Rusymlink is a usymlink response.
+type Rusymlink struct {
+ Rsymlink
+}
+
+// Type implements message.Type.
+func (*Rusymlink) Type() MsgType {
+ return MsgRusymlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rusymlink) String() string {
+ return fmt.Sprintf("Rusymlink{%v}", &r.Rsymlink)
+}
+
+// Tlconnect is a connect request.
+type Tlconnect struct {
+ // FID is the FID to be connected.
+ FID FID
+
+ // Flags are the connect flags.
+ Flags ConnectFlags
+}
+
+// decode implements encoder.decode.
+func (t *Tlconnect) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Flags = b.ReadConnectFlags()
+}
+
+// encode implements encoder.encode.
+func (t *Tlconnect) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteConnectFlags(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tlconnect) Type() MsgType {
+ return MsgTlconnect
+}
+
+// String implements fmt.Stringer.
+func (t *Tlconnect) String() string {
+ return fmt.Sprintf("Tlconnect{FID: %d, Flags: %v}", t.FID, t.Flags)
+}
+
+// Rlconnect is a connect response.
+type Rlconnect struct {
+ filePayload
+}
+
+// decode implements encoder.decode.
+func (r *Rlconnect) decode(*buffer) {}
+
+// encode implements encoder.encode.
+func (r *Rlconnect) encode(*buffer) {}
+
+// Type implements message.Type.
+func (*Rlconnect) Type() MsgType {
+ return MsgRlconnect
+}
+
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+}
+
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tchannel) decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tchannel) encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
+}
+
+// String implements fmt.Stringer.
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+type Rchannel struct {
+ Offset uint64
+ Length uint64
+ filePayload
+}
+
+// decode implements encoder.decode.
+func (r *Rchannel) decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (r *Rchannel) encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
+}
+
+const maxCacheSize = 3
+
+// msgFactory is used to reduce allocations by caching messages for reuse.
+type msgFactory struct {
+ create func() message
+ cache chan message
+}
+
+// msgRegistry indexes all message factories by type.
+var msgRegistry registry
+
+type registry struct {
+ factories [math.MaxUint8]msgFactory
+
+ // largestFixedSize is computed so that given some message size M, you can
+ // compute the maximum payload size (e.g. for Twrite, Rread) with
+ // M-largestFixedSize. You could do this individual on a per-message basis,
+ // but it's easier to compute a single maximum safe payload.
+ largestFixedSize uint32
+}
+
+// get returns a new message by type.
+//
+// An error is returned in the case of an unknown message.
+//
+// This takes, and ignores, a message tag so that it may be used directly as a
+// lookupTagAndType function for recv (by design).
+func (r *registry) get(_ Tag, t MsgType) (message, error) {
+ entry := &r.factories[t]
+ if entry.create == nil {
+ return nil, &ErrInvalidMsgType{t}
+ }
+
+ select {
+ case msg := <-entry.cache:
+ return msg, nil
+ default:
+ return entry.create(), nil
+ }
+}
+
+func (r *registry) put(msg message) {
+ if p, ok := msg.(payloader); ok {
+ p.SetPayload(nil)
+ }
+ if f, ok := msg.(filer); ok {
+ f.SetFilePayload(nil)
+ }
+
+ entry := &r.factories[msg.Type()]
+ select {
+ case entry.cache <- msg:
+ default:
+ }
+}
+
+// register registers the given message type.
+//
+// This may cause panic on failure and should only be used from init.
+func (r *registry) register(t MsgType, fn func() message) {
+ if int(t) >= len(r.factories) {
+ panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories)))
+ }
+ if r.factories[t].create != nil {
+ panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn()))
+ }
+ r.factories[t] = msgFactory{
+ create: fn,
+ cache: make(chan message, maxCacheSize),
+ }
+
+ if size := calculateSize(fn()); size > r.largestFixedSize {
+ r.largestFixedSize = size
+ }
+}
+
+func calculateSize(m message) uint32 {
+ if p, ok := m.(payloader); ok {
+ return p.FixedSize()
+ }
+ var dataBuf buffer
+ m.encode(&dataBuf)
+ return uint32(len(dataBuf.data))
+}
+
+func init() {
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} })
+ msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} })
+ msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} })
+ msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} })
+ msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} })
+ msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} })
+ msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} })
+ msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} })
+ msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} })
+ msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} })
+ msgRegistry.register(MsgTrename, func() message { return &Trename{} })
+ msgRegistry.register(MsgRrename, func() message { return &Rrename{} })
+ msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} })
+ msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} })
+ msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} })
+ msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
+ msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
+ msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTlistxattr, func() message { return &Tlistxattr{} })
+ msgRegistry.register(MsgRlistxattr, func() message { return &Rlistxattr{} })
+ msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
+ msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
+ msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
+ msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTgetxattr, func() message { return &Tgetxattr{} })
+ msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} })
+ msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} })
+ msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} })
+ msgRegistry.register(MsgTremovexattr, func() message { return &Tremovexattr{} })
+ msgRegistry.register(MsgRremovexattr, func() message { return &Rremovexattr{} })
+ msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
+ msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
+ msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
+ msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} })
+ msgRegistry.register(MsgTlink, func() message { return &Tlink{} })
+ msgRegistry.register(MsgRlink, func() message { return &Rlink{} })
+ msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} })
+ msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} })
+ msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} })
+ msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} })
+ msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} })
+ msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} })
+ msgRegistry.register(MsgTversion, func() message { return &Tversion{} })
+ msgRegistry.register(MsgRversion, func() message { return &Rversion{} })
+ msgRegistry.register(MsgTauth, func() message { return &Tauth{} })
+ msgRegistry.register(MsgRauth, func() message { return &Rauth{} })
+ msgRegistry.register(MsgTattach, func() message { return &Tattach{} })
+ msgRegistry.register(MsgRattach, func() message { return &Rattach{} })
+ msgRegistry.register(MsgTflush, func() message { return &Tflush{} })
+ msgRegistry.register(MsgRflush, func() message { return &Rflush{} })
+ msgRegistry.register(MsgTwalk, func() message { return &Twalk{} })
+ msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} })
+ msgRegistry.register(MsgTread, func() message { return &Tread{} })
+ msgRegistry.register(MsgRread, func() message { return &Rread{} })
+ msgRegistry.register(MsgTwrite, func() message { return &Twrite{} })
+ msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} })
+ msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} })
+ msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} })
+ msgRegistry.register(MsgTremove, func() message { return &Tremove{} })
+ msgRegistry.register(MsgRremove, func() message { return &Rremove{} })
+ msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} })
+ msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} })
+ msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
+ msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
+ msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} })
+ msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} })
+ msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} })
+ msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} })
+ msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} })
+ msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} })
+ msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} })
+ msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} })
+ msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} })
+ msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
+ msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
+ msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
+}
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
new file mode 100644
index 000000000..7facc9f5e
--- /dev/null
+++ b/pkg/p9/messages_test.go
@@ -0,0 +1,483 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestEncodeDecode(t *testing.T) {
+ objs := []encoder{
+ &QID{
+ Type: 1,
+ Version: 2,
+ Path: 3,
+ },
+ &FSStat{
+ Type: 1,
+ BlockSize: 2,
+ Blocks: 3,
+ BlocksFree: 4,
+ BlocksAvailable: 5,
+ Files: 6,
+ FilesFree: 7,
+ FSID: 8,
+ NameLength: 9,
+ },
+ &AttrMask{
+ Mode: true,
+ NLink: true,
+ UID: true,
+ GID: true,
+ RDev: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ INo: true,
+ Size: true,
+ Blocks: true,
+ BTime: true,
+ Gen: true,
+ DataVersion: true,
+ },
+ &Attr{
+ Mode: Exec,
+ UID: 2,
+ GID: 3,
+ NLink: 4,
+ RDev: 5,
+ Size: 6,
+ BlockSize: 7,
+ Blocks: 8,
+ ATimeSeconds: 9,
+ ATimeNanoSeconds: 10,
+ MTimeSeconds: 11,
+ MTimeNanoSeconds: 12,
+ CTimeSeconds: 13,
+ CTimeNanoSeconds: 14,
+ BTimeSeconds: 15,
+ BTimeNanoSeconds: 16,
+ Gen: 17,
+ DataVersion: 18,
+ },
+ &SetAttrMask{
+ Permissions: true,
+ UID: true,
+ GID: true,
+ Size: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ ATimeNotSystemTime: true,
+ MTimeNotSystemTime: true,
+ },
+ &SetAttr{
+ Permissions: 1,
+ UID: 2,
+ GID: 3,
+ Size: 4,
+ ATimeSeconds: 5,
+ ATimeNanoSeconds: 6,
+ MTimeSeconds: 7,
+ MTimeNanoSeconds: 8,
+ },
+ &Dirent{
+ QID: QID{Type: 1},
+ Offset: 2,
+ Type: 3,
+ Name: "a",
+ },
+ &Rlerror{
+ Error: 1,
+ },
+ &Tstatfs{
+ FID: 1,
+ },
+ &Rstatfs{
+ FSStat: FSStat{Type: 1},
+ },
+ &Tlopen{
+ FID: 1,
+ Flags: WriteOnly,
+ },
+ &Rlopen{
+ QID: QID{Type: 1},
+ IoUnit: 2,
+ },
+ &Tlconnect{
+ FID: 1,
+ },
+ &Rlconnect{},
+ &Tlcreate{
+ FID: 1,
+ Name: "a",
+ OpenFlags: 2,
+ Permissions: 3,
+ GID: 4,
+ },
+ &Rlcreate{
+ Rlopen{QID: QID{Type: 1}},
+ },
+ &Tsymlink{
+ Directory: 1,
+ Name: "a",
+ Target: "b",
+ GID: 2,
+ },
+ &Rsymlink{
+ QID: QID{Type: 1},
+ },
+ &Tmknod{
+ Directory: 1,
+ Name: "a",
+ Mode: 2,
+ Major: 3,
+ Minor: 4,
+ GID: 5,
+ },
+ &Rmknod{
+ QID: QID{Type: 1},
+ },
+ &Trename{
+ FID: 1,
+ Directory: 2,
+ Name: "a",
+ },
+ &Rrename{},
+ &Treadlink{
+ FID: 1,
+ },
+ &Rreadlink{
+ Target: "a",
+ },
+ &Tgetattr{
+ FID: 1,
+ AttrMask: AttrMask{Mode: true},
+ },
+ &Rgetattr{
+ Valid: AttrMask{Mode: true},
+ QID: QID{Type: 1},
+ Attr: Attr{Mode: Write},
+ },
+ &Tsetattr{
+ FID: 1,
+ Valid: SetAttrMask{Permissions: true},
+ SetAttr: SetAttr{Permissions: Write},
+ },
+ &Rsetattr{},
+ &Txattrwalk{
+ FID: 1,
+ NewFID: 2,
+ Name: "a",
+ },
+ &Rxattrwalk{
+ Size: 1,
+ },
+ &Txattrcreate{
+ FID: 1,
+ Name: "a",
+ AttrSize: 2,
+ Flags: 3,
+ },
+ &Rxattrcreate{},
+ &Tgetxattr{
+ FID: 1,
+ Name: "abc",
+ Size: 2,
+ },
+ &Rgetxattr{
+ Value: "xyz",
+ },
+ &Tsetxattr{
+ FID: 1,
+ Name: "abc",
+ Value: "xyz",
+ Flags: 2,
+ },
+ &Rsetxattr{},
+ &Treaddir{
+ Directory: 1,
+ Offset: 2,
+ Count: 3,
+ },
+ &Rreaddir{
+ // Count must be sufficient to encode a dirent.
+ Count: 0x1a,
+ Entries: []Dirent{{QID: QID{Type: 2}}},
+ },
+ &Tfsync{
+ FID: 1,
+ },
+ &Rfsync{},
+ &Tlink{
+ Directory: 1,
+ Target: 2,
+ Name: "a",
+ },
+ &Rlink{},
+ &Tmkdir{
+ Directory: 1,
+ Name: "a",
+ Permissions: 2,
+ GID: 3,
+ },
+ &Rmkdir{
+ QID: QID{Type: 1},
+ },
+ &Trenameat{
+ OldDirectory: 1,
+ OldName: "a",
+ NewDirectory: 2,
+ NewName: "b",
+ },
+ &Rrenameat{},
+ &Tunlinkat{
+ Directory: 1,
+ Name: "a",
+ Flags: 2,
+ },
+ &Runlinkat{},
+ &Tversion{
+ MSize: 1,
+ Version: "a",
+ },
+ &Rversion{
+ MSize: 1,
+ Version: "a",
+ },
+ &Tauth{
+ AuthenticationFID: 1,
+ UserName: "a",
+ AttachName: "b",
+ UID: 2,
+ },
+ &Rauth{
+ QID: QID{Type: 1},
+ },
+ &Tattach{
+ FID: 1,
+ Auth: Tauth{AuthenticationFID: 2},
+ },
+ &Rattach{
+ QID: QID{Type: 1},
+ },
+ &Tflush{
+ OldTag: 1,
+ },
+ &Rflush{},
+ &Twalk{
+ FID: 1,
+ NewFID: 2,
+ Names: []string{"a"},
+ },
+ &Rwalk{
+ QIDs: []QID{{Type: 1}},
+ },
+ &Tread{
+ FID: 1,
+ Offset: 2,
+ Count: 3,
+ },
+ &Rread{
+ Data: []byte{'a'},
+ },
+ &Twrite{
+ FID: 1,
+ Offset: 2,
+ Data: []byte{'a'},
+ },
+ &Rwrite{
+ Count: 1,
+ },
+ &Tclunk{
+ FID: 1,
+ },
+ &Rclunk{},
+ &Tremove{
+ FID: 1,
+ },
+ &Rremove{},
+ &Tflushf{
+ FID: 1,
+ },
+ &Rflushf{},
+ &Twalkgetattr{
+ FID: 1,
+ NewFID: 2,
+ Names: []string{"a"},
+ },
+ &Rwalkgetattr{
+ QIDs: []QID{{Type: 1}},
+ Valid: AttrMask{Mode: true},
+ Attr: Attr{Mode: Write},
+ },
+ &Tucreate{
+ Tlcreate: Tlcreate{
+ FID: 1,
+ Name: "a",
+ OpenFlags: 2,
+ Permissions: 3,
+ GID: 4,
+ },
+ UID: 5,
+ },
+ &Rucreate{
+ Rlcreate{Rlopen{QID: QID{Type: 1}}},
+ },
+ &Tumkdir{
+ Tmkdir: Tmkdir{
+ Directory: 1,
+ Name: "a",
+ Permissions: 2,
+ GID: 3,
+ },
+ UID: 4,
+ },
+ &Rumkdir{
+ Rmkdir{QID: QID{Type: 1}},
+ },
+ &Tusymlink{
+ Tsymlink: Tsymlink{
+ Directory: 1,
+ Name: "a",
+ Target: "b",
+ GID: 2,
+ },
+ UID: 3,
+ },
+ &Rusymlink{
+ Rsymlink{QID: QID{Type: 1}},
+ },
+ &Tumknod{
+ Tmknod: Tmknod{
+ Directory: 1,
+ Name: "a",
+ Mode: 2,
+ Major: 3,
+ Minor: 4,
+ GID: 5,
+ },
+ UID: 6,
+ },
+ &Rumknod{
+ Rmknod{QID: QID{Type: 1}},
+ },
+ }
+
+ for _, enc := range objs {
+ // Encode the original.
+ data := make([]byte, initialBufferLength)
+ buf := buffer{data: data[:0]}
+ enc.encode(&buf)
+
+ // Create a new object, same as the first.
+ enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder)
+ buf2 := buffer{data: buf.data}
+
+ // To be fair, we need to add any payloads (directly).
+ if pl, ok := enc.(payloader); ok {
+ enc2.(payloader).SetPayload(pl.Payload())
+ }
+
+ // And any file payloads (directly).
+ if fl, ok := enc.(filer); ok {
+ enc2.(filer).SetFilePayload(fl.FilePayload())
+ }
+
+ // Mark sure it was okay.
+ enc2.decode(&buf2)
+ if buf2.isOverrun() {
+ t.Errorf("object %#v->%#v got overrun on decode", enc, enc2)
+ continue
+ }
+
+ // Check that they are equal.
+ if !reflect.DeepEqual(enc, enc2) {
+ t.Errorf("object %#v and %#v differ", enc, enc2)
+ continue
+ }
+ }
+}
+
+func TestMessageStrings(t *testing.T) {
+ for typ := range msgRegistry.factories {
+ entry := &msgRegistry.factories[typ]
+ if entry.create != nil {
+ name := fmt.Sprintf("%+v", typ)
+ t.Run(name, func(t *testing.T) {
+ defer func() { // Ensure no panic.
+ if r := recover(); r != nil {
+ t.Errorf("printing %s failed: %v", name, r)
+ }
+ }()
+ m := entry.create()
+ _ = fmt.Sprintf("%v", m)
+ err := ErrInvalidMsgType{MsgType(typ)}
+ _ = err.Error()
+ })
+ }
+ }
+}
+
+func TestRegisterDuplicate(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ // We expect a panic.
+ t.FailNow()
+ }
+ }()
+
+ // Register a duplicate.
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+}
+
+func TestMsgCache(t *testing.T) {
+ // Cache starts empty.
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Message can be created with an empty cache.
+ msg, err := msgRegistry.get(0, MsgRlerror)
+ if err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that message is added to the cache when returned.
+ msgRegistry.put(msg)
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that returned message is reused.
+ if got, err := msgRegistry.get(0, MsgRlerror); err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ } else if msg != got {
+ t.Errorf("Message not reused, got: %d, want: %d", got, msg)
+ }
+
+ // Check that cache doesn't grow beyond max size.
+ for i := 0; i < maxCacheSize+1; i++ {
+ msgRegistry.put(&Rlerror{})
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
new file mode 100644
index 000000000..28d851ff5
--- /dev/null
+++ b/pkg/p9/p9.go
@@ -0,0 +1,1158 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package p9 is a 9P2000.L implementation.
+package p9
+
+import (
+ "fmt"
+ "math"
+ "os"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+// OpenFlags is the mode passed to Open and Create operations.
+//
+// These correspond to bits sent over the wire.
+type OpenFlags uint32
+
+const (
+ // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode.
+ ReadOnly OpenFlags = 0
+
+ // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode.
+ WriteOnly OpenFlags = 1
+
+ // ReadWrite is a Tlopen flag indicates read-write mode.
+ ReadWrite OpenFlags = 2
+
+ // OpenFlagsModeMask is a mask of valid OpenFlags mode bits.
+ OpenFlagsModeMask OpenFlags = 3
+
+ // OpenTruncate is a Tlopen flag indicating that the opened file should be
+ // truncated.
+ OpenTruncate OpenFlags = 01000
+)
+
+// ConnectFlags is the mode passed to Connect operations.
+//
+// These correspond to bits sent over the wire.
+type ConnectFlags uint32
+
+const (
+ // StreamSocket is a Tlconnect flag indicating SOCK_STREAM mode.
+ StreamSocket ConnectFlags = 0
+
+ // DgramSocket is a Tlconnect flag indicating SOCK_DGRAM mode.
+ DgramSocket ConnectFlags = 1
+
+ // SeqpacketSocket is a Tlconnect flag indicating SOCK_SEQPACKET mode.
+ SeqpacketSocket ConnectFlags = 2
+
+ // AnonymousSocket is a Tlconnect flag indicating that the mode does not
+ // matter and that the requester will accept any socket type.
+ AnonymousSocket ConnectFlags = 3
+)
+
+// OSFlags converts a p9.OpenFlags to an int compatible with open(2).
+func (o OpenFlags) OSFlags() int {
+ // "flags contains Linux open(2) flags bits" - 9P2000.L
+ return int(o)
+}
+
+// String implements fmt.Stringer.
+func (o OpenFlags) String() string {
+ var buf strings.Builder
+ switch mode := o & OpenFlagsModeMask; mode {
+ case ReadOnly:
+ buf.WriteString("ReadOnly")
+ case WriteOnly:
+ buf.WriteString("WriteOnly")
+ case ReadWrite:
+ buf.WriteString("ReadWrite")
+ default:
+ fmt.Fprintf(&buf, "%#o", mode)
+ }
+ otherFlags := o &^ OpenFlagsModeMask
+ if otherFlags&OpenTruncate != 0 {
+ buf.WriteString("|OpenTruncate")
+ otherFlags &^= OpenTruncate
+ }
+ if otherFlags != 0 {
+ fmt.Fprintf(&buf, "|%#o", otherFlags)
+ }
+ return buf.String()
+}
+
+// Tag is a message tag.
+type Tag uint16
+
+// FID is a file identifier.
+type FID uint64
+
+// FileMode are flags corresponding to file modes.
+//
+// These correspond to bits sent over the wire.
+// These also correspond to mode_t bits.
+type FileMode uint32
+
+const (
+ // FileModeMask is a mask of all the file mode bits of FileMode.
+ FileModeMask FileMode = 0170000
+
+ // ModeSocket is an (unused) mode bit for a socket.
+ ModeSocket FileMode = 0140000
+
+ // ModeSymlink is a mode bit for a symlink.
+ ModeSymlink FileMode = 0120000
+
+ // ModeRegular is a mode bit for regular files.
+ ModeRegular FileMode = 0100000
+
+ // ModeBlockDevice is a mode bit for block devices.
+ ModeBlockDevice FileMode = 060000
+
+ // ModeDirectory is a mode bit for directories.
+ ModeDirectory FileMode = 040000
+
+ // ModeCharacterDevice is a mode bit for a character device.
+ ModeCharacterDevice FileMode = 020000
+
+ // ModeNamedPipe is a mode bit for a named pipe.
+ ModeNamedPipe FileMode = 010000
+
+ // Read is a mode bit indicating read permission.
+ Read FileMode = 04
+
+ // Write is a mode bit indicating write permission.
+ Write FileMode = 02
+
+ // Exec is a mode bit indicating exec permission.
+ Exec FileMode = 01
+
+ // AllPermissions is a mask with rwx bits set for user, group and others.
+ AllPermissions FileMode = 0777
+
+ // Sticky is a mode bit indicating sticky directories.
+ Sticky FileMode = 01000
+
+ // permissionsMask is the mask to apply to FileModes for permissions. It
+ // includes rwx bits for user, group and others, and sticky bit.
+ permissionsMask FileMode = 01777
+)
+
+// QIDType is the most significant byte of the FileMode word, to be used as the
+// Type field of p9.QID.
+func (m FileMode) QIDType() QIDType {
+ switch {
+ case m.IsDir():
+ return TypeDir
+ case m.IsSocket(), m.IsNamedPipe(), m.IsCharacterDevice():
+ // Best approximation.
+ return TypeAppendOnly
+ case m.IsSymlink():
+ return TypeSymlink
+ default:
+ return TypeRegular
+ }
+}
+
+// FileType returns the file mode without the permission bits.
+func (m FileMode) FileType() FileMode {
+ return m & FileModeMask
+}
+
+// Permissions returns just the permission bits of the mode.
+func (m FileMode) Permissions() FileMode {
+ return m & permissionsMask
+}
+
+// Writable returns the mode with write bits added.
+func (m FileMode) Writable() FileMode {
+ return m | 0222
+}
+
+// IsReadable returns true if m represents a file that can be read.
+func (m FileMode) IsReadable() bool {
+ return m&0444 != 0
+}
+
+// IsWritable returns true if m represents a file that can be written to.
+func (m FileMode) IsWritable() bool {
+ return m&0222 != 0
+}
+
+// IsExecutable returns true if m represents a file that can be executed.
+func (m FileMode) IsExecutable() bool {
+ return m&0111 != 0
+}
+
+// IsRegular returns true if m is a regular file.
+func (m FileMode) IsRegular() bool {
+ return m&FileModeMask == ModeRegular
+}
+
+// IsDir returns true if m represents a directory.
+func (m FileMode) IsDir() bool {
+ return m&FileModeMask == ModeDirectory
+}
+
+// IsNamedPipe returns true if m represents a named pipe.
+func (m FileMode) IsNamedPipe() bool {
+ return m&FileModeMask == ModeNamedPipe
+}
+
+// IsCharacterDevice returns true if m represents a character device.
+func (m FileMode) IsCharacterDevice() bool {
+ return m&FileModeMask == ModeCharacterDevice
+}
+
+// IsBlockDevice returns true if m represents a character device.
+func (m FileMode) IsBlockDevice() bool {
+ return m&FileModeMask == ModeBlockDevice
+}
+
+// IsSocket returns true if m represents a socket.
+func (m FileMode) IsSocket() bool {
+ return m&FileModeMask == ModeSocket
+}
+
+// IsSymlink returns true if m represents a symlink.
+func (m FileMode) IsSymlink() bool {
+ return m&FileModeMask == ModeSymlink
+}
+
+// ModeFromOS returns a FileMode from an os.FileMode.
+func ModeFromOS(mode os.FileMode) FileMode {
+ m := FileMode(mode.Perm())
+ switch {
+ case mode.IsDir():
+ m |= ModeDirectory
+ case mode&os.ModeSymlink != 0:
+ m |= ModeSymlink
+ case mode&os.ModeSocket != 0:
+ m |= ModeSocket
+ case mode&os.ModeNamedPipe != 0:
+ m |= ModeNamedPipe
+ case mode&os.ModeCharDevice != 0:
+ m |= ModeCharacterDevice
+ case mode&os.ModeDevice != 0:
+ m |= ModeBlockDevice
+ default:
+ m |= ModeRegular
+ }
+ return m
+}
+
+// OSMode converts a p9.FileMode to an os.FileMode.
+func (m FileMode) OSMode() os.FileMode {
+ var osMode os.FileMode
+ osMode |= os.FileMode(m.Permissions())
+ switch {
+ case m.IsDir():
+ osMode |= os.ModeDir
+ case m.IsSymlink():
+ osMode |= os.ModeSymlink
+ case m.IsSocket():
+ osMode |= os.ModeSocket
+ case m.IsNamedPipe():
+ osMode |= os.ModeNamedPipe
+ case m.IsCharacterDevice():
+ osMode |= os.ModeCharDevice | os.ModeDevice
+ case m.IsBlockDevice():
+ osMode |= os.ModeDevice
+ }
+ return osMode
+}
+
+// UID represents a user ID.
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+const (
+ // NoTag is a sentinel used to indicate no valid tag.
+ NoTag Tag = math.MaxUint16
+
+ // NoFID is a sentinel used to indicate no valid FID.
+ NoFID FID = math.MaxUint32
+
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MsgType is a type identifier.
+type MsgType uint8
+
+// MsgType declarations.
+const (
+ MsgTlerror MsgType = 6
+ MsgRlerror = 7
+ MsgTstatfs = 8
+ MsgRstatfs = 9
+ MsgTlopen = 12
+ MsgRlopen = 13
+ MsgTlcreate = 14
+ MsgRlcreate = 15
+ MsgTsymlink = 16
+ MsgRsymlink = 17
+ MsgTmknod = 18
+ MsgRmknod = 19
+ MsgTrename = 20
+ MsgRrename = 21
+ MsgTreadlink = 22
+ MsgRreadlink = 23
+ MsgTgetattr = 24
+ MsgRgetattr = 25
+ MsgTsetattr = 26
+ MsgRsetattr = 27
+ MsgTlistxattr = 28
+ MsgRlistxattr = 29
+ MsgTxattrwalk = 30
+ MsgRxattrwalk = 31
+ MsgTxattrcreate = 32
+ MsgRxattrcreate = 33
+ MsgTgetxattr = 34
+ MsgRgetxattr = 35
+ MsgTsetxattr = 36
+ MsgRsetxattr = 37
+ MsgTremovexattr = 38
+ MsgRremovexattr = 39
+ MsgTreaddir = 40
+ MsgRreaddir = 41
+ MsgTfsync = 50
+ MsgRfsync = 51
+ MsgTlink = 70
+ MsgRlink = 71
+ MsgTmkdir = 72
+ MsgRmkdir = 73
+ MsgTrenameat = 74
+ MsgRrenameat = 75
+ MsgTunlinkat = 76
+ MsgRunlinkat = 77
+ MsgTversion = 100
+ MsgRversion = 101
+ MsgTauth = 102
+ MsgRauth = 103
+ MsgTattach = 104
+ MsgRattach = 105
+ MsgTflush = 108
+ MsgRflush = 109
+ MsgTwalk = 110
+ MsgRwalk = 111
+ MsgTread = 116
+ MsgRread = 117
+ MsgTwrite = 118
+ MsgRwrite = 119
+ MsgTclunk = 120
+ MsgRclunk = 121
+ MsgTremove = 122
+ MsgRremove = 123
+ MsgTflushf = 124
+ MsgRflushf = 125
+ MsgTwalkgetattr = 126
+ MsgRwalkgetattr = 127
+ MsgTucreate = 128
+ MsgRucreate = 129
+ MsgTumkdir = 130
+ MsgRumkdir = 131
+ MsgTumknod = 132
+ MsgRumknod = 133
+ MsgTusymlink = 134
+ MsgRusymlink = 135
+ MsgTlconnect = 136
+ MsgRlconnect = 137
+ MsgTallocate = 138
+ MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
+)
+
+// QIDType represents the file type for QIDs.
+//
+// QIDType corresponds to the high 8 bits of a Plan 9 file mode.
+type QIDType uint8
+
+const (
+ // TypeDir represents a directory type.
+ TypeDir QIDType = 0x80
+
+ // TypeAppendOnly represents an append only file.
+ TypeAppendOnly QIDType = 0x40
+
+ // TypeExclusive represents an exclusive-use file.
+ TypeExclusive QIDType = 0x20
+
+ // TypeMount represents a mounted channel.
+ TypeMount QIDType = 0x10
+
+ // TypeAuth represents an authentication file.
+ TypeAuth QIDType = 0x08
+
+ // TypeTemporary represents a temporary file.
+ TypeTemporary QIDType = 0x04
+
+ // TypeSymlink represents a symlink.
+ TypeSymlink QIDType = 0x02
+
+ // TypeLink represents a hard link.
+ TypeLink QIDType = 0x01
+
+ // TypeRegular represents a regular file.
+ TypeRegular QIDType = 0x00
+)
+
+// QID is a unique file identifier.
+//
+// This may be embedded in other requests and responses.
+type QID struct {
+ // Type is the highest order byte of the file mode.
+ Type QIDType
+
+ // Version is an arbitrary server version number.
+ Version uint32
+
+ // Path is a unique server identifier for this path (e.g. inode).
+ Path uint64
+}
+
+// String implements fmt.Stringer.
+func (q QID) String() string {
+ return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path)
+}
+
+// decode implements encoder.decode.
+func (q *QID) decode(b *buffer) {
+ q.Type = b.ReadQIDType()
+ q.Version = b.Read32()
+ q.Path = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (q *QID) encode(b *buffer) {
+ b.WriteQIDType(q.Type)
+ b.Write32(q.Version)
+ b.Write64(q.Path)
+}
+
+// QIDGenerator is a simple generator for QIDs that atomically increments Path
+// values.
+type QIDGenerator struct {
+ // uids is an ever increasing value that can be atomically incremented
+ // to provide unique Path values for QIDs.
+ uids uint64
+}
+
+// Get returns a new 9P unique ID with a unique Path given a QID type.
+//
+// While the 9P spec allows Version to be incremented every time the file is
+// modified, we currently do not use the Version member for anything. Hence,
+// it is set to 0.
+func (q *QIDGenerator) Get(t QIDType) QID {
+ return QID{
+ Type: t,
+ Version: 0,
+ Path: atomic.AddUint64(&q.uids, 1),
+ }
+}
+
+// FSStat is used by statfs.
+type FSStat struct {
+ // Type is the filesystem type.
+ Type uint32
+
+ // BlockSize is the blocksize.
+ BlockSize uint32
+
+ // Blocks is the number of blocks.
+ Blocks uint64
+
+ // BlocksFree is the number of free blocks.
+ BlocksFree uint64
+
+ // BlocksAvailable is the number of blocks *available*.
+ BlocksAvailable uint64
+
+ // Files is the number of files available.
+ Files uint64
+
+ // FilesFree is the number of free file nodes.
+ FilesFree uint64
+
+ // FSID is the filesystem ID.
+ FSID uint64
+
+ // NameLength is the maximum name length.
+ NameLength uint32
+}
+
+// decode implements encoder.decode.
+func (f *FSStat) decode(b *buffer) {
+ f.Type = b.Read32()
+ f.BlockSize = b.Read32()
+ f.Blocks = b.Read64()
+ f.BlocksFree = b.Read64()
+ f.BlocksAvailable = b.Read64()
+ f.Files = b.Read64()
+ f.FilesFree = b.Read64()
+ f.FSID = b.Read64()
+ f.NameLength = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (f *FSStat) encode(b *buffer) {
+ b.Write32(f.Type)
+ b.Write32(f.BlockSize)
+ b.Write64(f.Blocks)
+ b.Write64(f.BlocksFree)
+ b.Write64(f.BlocksAvailable)
+ b.Write64(f.Files)
+ b.Write64(f.FilesFree)
+ b.Write64(f.FSID)
+ b.Write32(f.NameLength)
+}
+
+// AttrMask is a mask of attributes for getattr.
+type AttrMask struct {
+ Mode bool
+ NLink bool
+ UID bool
+ GID bool
+ RDev bool
+ ATime bool
+ MTime bool
+ CTime bool
+ INo bool
+ Size bool
+ Blocks bool
+ BTime bool
+ Gen bool
+ DataVersion bool
+}
+
+// Contains returns true if a contains all of the attributes masked as b.
+func (a AttrMask) Contains(b AttrMask) bool {
+ if b.Mode && !a.Mode {
+ return false
+ }
+ if b.NLink && !a.NLink {
+ return false
+ }
+ if b.UID && !a.UID {
+ return false
+ }
+ if b.GID && !a.GID {
+ return false
+ }
+ if b.RDev && !a.RDev {
+ return false
+ }
+ if b.ATime && !a.ATime {
+ return false
+ }
+ if b.MTime && !a.MTime {
+ return false
+ }
+ if b.CTime && !a.CTime {
+ return false
+ }
+ if b.INo && !a.INo {
+ return false
+ }
+ if b.Size && !a.Size {
+ return false
+ }
+ if b.Blocks && !a.Blocks {
+ return false
+ }
+ if b.BTime && !a.BTime {
+ return false
+ }
+ if b.Gen && !a.Gen {
+ return false
+ }
+ if b.DataVersion && !a.DataVersion {
+ return false
+ }
+ return true
+}
+
+// Empty returns true if no fields are masked.
+func (a AttrMask) Empty() bool {
+ return !a.Mode && !a.NLink && !a.UID && !a.GID && !a.RDev && !a.ATime && !a.MTime && !a.CTime && !a.INo && !a.Size && !a.Blocks && !a.BTime && !a.Gen && !a.DataVersion
+}
+
+// AttrMaskAll returns an AttrMask with all fields masked.
+func AttrMaskAll() AttrMask {
+ return AttrMask{
+ Mode: true,
+ NLink: true,
+ UID: true,
+ GID: true,
+ RDev: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ INo: true,
+ Size: true,
+ Blocks: true,
+ BTime: true,
+ Gen: true,
+ DataVersion: true,
+ }
+}
+
+// String implements fmt.Stringer.
+func (a AttrMask) String() string {
+ var masks []string
+ if a.Mode {
+ masks = append(masks, "Mode")
+ }
+ if a.NLink {
+ masks = append(masks, "NLink")
+ }
+ if a.UID {
+ masks = append(masks, "UID")
+ }
+ if a.GID {
+ masks = append(masks, "GID")
+ }
+ if a.RDev {
+ masks = append(masks, "RDev")
+ }
+ if a.ATime {
+ masks = append(masks, "ATime")
+ }
+ if a.MTime {
+ masks = append(masks, "MTime")
+ }
+ if a.CTime {
+ masks = append(masks, "CTime")
+ }
+ if a.INo {
+ masks = append(masks, "INo")
+ }
+ if a.Size {
+ masks = append(masks, "Size")
+ }
+ if a.Blocks {
+ masks = append(masks, "Blocks")
+ }
+ if a.BTime {
+ masks = append(masks, "BTime")
+ }
+ if a.Gen {
+ masks = append(masks, "Gen")
+ }
+ if a.DataVersion {
+ masks = append(masks, "DataVersion")
+ }
+ return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " "))
+}
+
+// decode implements encoder.decode.
+func (a *AttrMask) decode(b *buffer) {
+ mask := b.Read64()
+ a.Mode = mask&0x00000001 != 0
+ a.NLink = mask&0x00000002 != 0
+ a.UID = mask&0x00000004 != 0
+ a.GID = mask&0x00000008 != 0
+ a.RDev = mask&0x00000010 != 0
+ a.ATime = mask&0x00000020 != 0
+ a.MTime = mask&0x00000040 != 0
+ a.CTime = mask&0x00000080 != 0
+ a.INo = mask&0x00000100 != 0
+ a.Size = mask&0x00000200 != 0
+ a.Blocks = mask&0x00000400 != 0
+ a.BTime = mask&0x00000800 != 0
+ a.Gen = mask&0x00001000 != 0
+ a.DataVersion = mask&0x00002000 != 0
+}
+
+// encode implements encoder.encode.
+func (a *AttrMask) encode(b *buffer) {
+ var mask uint64
+ if a.Mode {
+ mask |= 0x00000001
+ }
+ if a.NLink {
+ mask |= 0x00000002
+ }
+ if a.UID {
+ mask |= 0x00000004
+ }
+ if a.GID {
+ mask |= 0x00000008
+ }
+ if a.RDev {
+ mask |= 0x00000010
+ }
+ if a.ATime {
+ mask |= 0x00000020
+ }
+ if a.MTime {
+ mask |= 0x00000040
+ }
+ if a.CTime {
+ mask |= 0x00000080
+ }
+ if a.INo {
+ mask |= 0x00000100
+ }
+ if a.Size {
+ mask |= 0x00000200
+ }
+ if a.Blocks {
+ mask |= 0x00000400
+ }
+ if a.BTime {
+ mask |= 0x00000800
+ }
+ if a.Gen {
+ mask |= 0x00001000
+ }
+ if a.DataVersion {
+ mask |= 0x00002000
+ }
+ b.Write64(mask)
+}
+
+// Attr is a set of attributes for getattr.
+type Attr struct {
+ Mode FileMode
+ UID UID
+ GID GID
+ NLink uint64
+ RDev uint64
+ Size uint64
+ BlockSize uint64
+ Blocks uint64
+ ATimeSeconds uint64
+ ATimeNanoSeconds uint64
+ MTimeSeconds uint64
+ MTimeNanoSeconds uint64
+ CTimeSeconds uint64
+ CTimeNanoSeconds uint64
+ BTimeSeconds uint64
+ BTimeNanoSeconds uint64
+ Gen uint64
+ DataVersion uint64
+}
+
+// String implements fmt.Stringer.
+func (a Attr) String() string {
+ return fmt.Sprintf("Attr{Mode: 0o%o, UID: %d, GID: %d, NLink: %d, RDev: %d, Size: %d, BlockSize: %d, Blocks: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}, CTime: {Sec: %d, NanoSec: %d}, BTime: {Sec: %d, NanoSec: %d}, Gen: %d, DataVersion: %d}",
+ a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion)
+}
+
+// encode implements encoder.encode.
+func (a *Attr) encode(b *buffer) {
+ b.WriteFileMode(a.Mode)
+ b.WriteUID(a.UID)
+ b.WriteGID(a.GID)
+ b.Write64(a.NLink)
+ b.Write64(a.RDev)
+ b.Write64(a.Size)
+ b.Write64(a.BlockSize)
+ b.Write64(a.Blocks)
+ b.Write64(a.ATimeSeconds)
+ b.Write64(a.ATimeNanoSeconds)
+ b.Write64(a.MTimeSeconds)
+ b.Write64(a.MTimeNanoSeconds)
+ b.Write64(a.CTimeSeconds)
+ b.Write64(a.CTimeNanoSeconds)
+ b.Write64(a.BTimeSeconds)
+ b.Write64(a.BTimeNanoSeconds)
+ b.Write64(a.Gen)
+ b.Write64(a.DataVersion)
+}
+
+// decode implements encoder.decode.
+func (a *Attr) decode(b *buffer) {
+ a.Mode = b.ReadFileMode()
+ a.UID = b.ReadUID()
+ a.GID = b.ReadGID()
+ a.NLink = b.Read64()
+ a.RDev = b.Read64()
+ a.Size = b.Read64()
+ a.BlockSize = b.Read64()
+ a.Blocks = b.Read64()
+ a.ATimeSeconds = b.Read64()
+ a.ATimeNanoSeconds = b.Read64()
+ a.MTimeSeconds = b.Read64()
+ a.MTimeNanoSeconds = b.Read64()
+ a.CTimeSeconds = b.Read64()
+ a.CTimeNanoSeconds = b.Read64()
+ a.BTimeSeconds = b.Read64()
+ a.BTimeNanoSeconds = b.Read64()
+ a.Gen = b.Read64()
+ a.DataVersion = b.Read64()
+}
+
+// StatToAttr converts a Linux syscall stat structure to an Attr.
+func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) {
+ attr := Attr{
+ UID: NoUID,
+ GID: NoGID,
+ }
+ if req.Mode {
+ // p9.FileMode corresponds to Linux mode_t.
+ attr.Mode = FileMode(s.Mode)
+ }
+ if req.NLink {
+ attr.NLink = uint64(s.Nlink)
+ }
+ if req.UID {
+ attr.UID = UID(s.Uid)
+ }
+ if req.GID {
+ attr.GID = GID(s.Gid)
+ }
+ if req.RDev {
+ attr.RDev = s.Dev
+ }
+ if req.ATime {
+ attr.ATimeSeconds = uint64(s.Atim.Sec)
+ attr.ATimeNanoSeconds = uint64(s.Atim.Nsec)
+ }
+ if req.MTime {
+ attr.MTimeSeconds = uint64(s.Mtim.Sec)
+ attr.MTimeNanoSeconds = uint64(s.Mtim.Nsec)
+ }
+ if req.CTime {
+ attr.CTimeSeconds = uint64(s.Ctim.Sec)
+ attr.CTimeNanoSeconds = uint64(s.Ctim.Nsec)
+ }
+ if req.Size {
+ attr.Size = uint64(s.Size)
+ }
+ if req.Blocks {
+ attr.BlockSize = uint64(s.Blksize)
+ attr.Blocks = uint64(s.Blocks)
+ }
+
+ // Use the req field because we already have it.
+ req.BTime = false
+ req.Gen = false
+ req.DataVersion = false
+
+ return attr, req
+}
+
+// SetAttrMask specifies a valid mask for setattr.
+type SetAttrMask struct {
+ Permissions bool
+ UID bool
+ GID bool
+ Size bool
+ ATime bool
+ MTime bool
+ CTime bool
+ ATimeNotSystemTime bool
+ MTimeNotSystemTime bool
+}
+
+// IsSubsetOf returns whether s is a subset of m.
+func (s SetAttrMask) IsSubsetOf(m SetAttrMask) bool {
+ sb := s.bitmask()
+ sm := m.bitmask()
+ return sm|sb == sm
+}
+
+// String implements fmt.Stringer.
+func (s SetAttrMask) String() string {
+ var masks []string
+ if s.Permissions {
+ masks = append(masks, "Permissions")
+ }
+ if s.UID {
+ masks = append(masks, "UID")
+ }
+ if s.GID {
+ masks = append(masks, "GID")
+ }
+ if s.Size {
+ masks = append(masks, "Size")
+ }
+ if s.ATime {
+ masks = append(masks, "ATime")
+ }
+ if s.MTime {
+ masks = append(masks, "MTime")
+ }
+ if s.CTime {
+ masks = append(masks, "CTime")
+ }
+ if s.ATimeNotSystemTime {
+ masks = append(masks, "ATimeNotSystemTime")
+ }
+ if s.MTimeNotSystemTime {
+ masks = append(masks, "MTimeNotSystemTime")
+ }
+ return fmt.Sprintf("SetAttrMask{with: %s}", strings.Join(masks, " "))
+}
+
+// Empty returns true if no fields are masked.
+func (s SetAttrMask) Empty() bool {
+ return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime
+}
+
+// decode implements encoder.decode.
+func (s *SetAttrMask) decode(b *buffer) {
+ mask := b.Read32()
+ s.Permissions = mask&0x00000001 != 0
+ s.UID = mask&0x00000002 != 0
+ s.GID = mask&0x00000004 != 0
+ s.Size = mask&0x00000008 != 0
+ s.ATime = mask&0x00000010 != 0
+ s.MTime = mask&0x00000020 != 0
+ s.CTime = mask&0x00000040 != 0
+ s.ATimeNotSystemTime = mask&0x00000080 != 0
+ s.MTimeNotSystemTime = mask&0x00000100 != 0
+}
+
+func (s SetAttrMask) bitmask() uint32 {
+ var mask uint32
+ if s.Permissions {
+ mask |= 0x00000001
+ }
+ if s.UID {
+ mask |= 0x00000002
+ }
+ if s.GID {
+ mask |= 0x00000004
+ }
+ if s.Size {
+ mask |= 0x00000008
+ }
+ if s.ATime {
+ mask |= 0x00000010
+ }
+ if s.MTime {
+ mask |= 0x00000020
+ }
+ if s.CTime {
+ mask |= 0x00000040
+ }
+ if s.ATimeNotSystemTime {
+ mask |= 0x00000080
+ }
+ if s.MTimeNotSystemTime {
+ mask |= 0x00000100
+ }
+ return mask
+}
+
+// encode implements encoder.encode.
+func (s *SetAttrMask) encode(b *buffer) {
+ b.Write32(s.bitmask())
+}
+
+// SetAttr specifies a set of attributes for a setattr.
+type SetAttr struct {
+ Permissions FileMode
+ UID UID
+ GID GID
+ Size uint64
+ ATimeSeconds uint64
+ ATimeNanoSeconds uint64
+ MTimeSeconds uint64
+ MTimeNanoSeconds uint64
+}
+
+// String implements fmt.Stringer.
+func (s SetAttr) String() string {
+ return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds)
+}
+
+// decode implements encoder.decode.
+func (s *SetAttr) decode(b *buffer) {
+ s.Permissions = b.ReadPermissions()
+ s.UID = b.ReadUID()
+ s.GID = b.ReadGID()
+ s.Size = b.Read64()
+ s.ATimeSeconds = b.Read64()
+ s.ATimeNanoSeconds = b.Read64()
+ s.MTimeSeconds = b.Read64()
+ s.MTimeNanoSeconds = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (s *SetAttr) encode(b *buffer) {
+ b.WritePermissions(s.Permissions)
+ b.WriteUID(s.UID)
+ b.WriteGID(s.GID)
+ b.Write64(s.Size)
+ b.Write64(s.ATimeSeconds)
+ b.Write64(s.ATimeNanoSeconds)
+ b.Write64(s.MTimeSeconds)
+ b.Write64(s.MTimeNanoSeconds)
+}
+
+// Apply applies this to the given Attr.
+func (a *Attr) Apply(mask SetAttrMask, attr SetAttr) {
+ if mask.Permissions {
+ a.Mode = a.Mode&^permissionsMask | (attr.Permissions & permissionsMask)
+ }
+ if mask.UID {
+ a.UID = attr.UID
+ }
+ if mask.GID {
+ a.GID = attr.GID
+ }
+ if mask.Size {
+ a.Size = attr.Size
+ }
+ if mask.ATime {
+ a.ATimeSeconds = attr.ATimeSeconds
+ a.ATimeNanoSeconds = attr.ATimeNanoSeconds
+ }
+ if mask.MTime {
+ a.MTimeSeconds = attr.MTimeSeconds
+ a.MTimeNanoSeconds = attr.MTimeNanoSeconds
+ }
+}
+
+// Dirent is used for readdir.
+type Dirent struct {
+ // QID is the entry QID.
+ QID QID
+
+ // Offset is the offset in the directory.
+ //
+ // This will be communicated back the original caller.
+ Offset uint64
+
+ // Type is the 9P type.
+ Type QIDType
+
+ // Name is the name of the entry (i.e. basename).
+ Name string
+}
+
+// String implements fmt.Stringer.
+func (d Dirent) String() string {
+ return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name)
+}
+
+// decode implements encoder.decode.
+func (d *Dirent) decode(b *buffer) {
+ d.QID.decode(b)
+ d.Offset = b.Read64()
+ d.Type = b.ReadQIDType()
+ d.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (d *Dirent) encode(b *buffer) {
+ d.QID.encode(b)
+ b.Write64(d.Offset)
+ b.WriteQIDType(d.Type)
+ b.WriteString(d.Name)
+}
+
+// AllocateMode are possible modes to p9.File.Allocate().
+type AllocateMode struct {
+ KeepSize bool
+ PunchHole bool
+ NoHideStale bool
+ CollapseRange bool
+ ZeroRange bool
+ InsertRange bool
+ Unshare bool
+}
+
+// ToLinux converts to a value compatible with fallocate(2)'s mode.
+func (a *AllocateMode) ToLinux() uint32 {
+ rv := uint32(0)
+ if a.KeepSize {
+ rv |= unix.FALLOC_FL_KEEP_SIZE
+ }
+ if a.PunchHole {
+ rv |= unix.FALLOC_FL_PUNCH_HOLE
+ }
+ if a.NoHideStale {
+ rv |= unix.FALLOC_FL_NO_HIDE_STALE
+ }
+ if a.CollapseRange {
+ rv |= unix.FALLOC_FL_COLLAPSE_RANGE
+ }
+ if a.ZeroRange {
+ rv |= unix.FALLOC_FL_ZERO_RANGE
+ }
+ if a.InsertRange {
+ rv |= unix.FALLOC_FL_INSERT_RANGE
+ }
+ if a.Unshare {
+ rv |= unix.FALLOC_FL_UNSHARE_RANGE
+ }
+ return rv
+}
+
+// decode implements encoder.decode.
+func (a *AllocateMode) decode(b *buffer) {
+ mask := b.Read32()
+ a.KeepSize = mask&0x01 != 0
+ a.PunchHole = mask&0x02 != 0
+ a.NoHideStale = mask&0x04 != 0
+ a.CollapseRange = mask&0x08 != 0
+ a.ZeroRange = mask&0x10 != 0
+ a.InsertRange = mask&0x20 != 0
+ a.Unshare = mask&0x40 != 0
+}
+
+// encode implements encoder.encode.
+func (a *AllocateMode) encode(b *buffer) {
+ mask := uint32(0)
+ if a.KeepSize {
+ mask |= 0x01
+ }
+ if a.PunchHole {
+ mask |= 0x02
+ }
+ if a.NoHideStale {
+ mask |= 0x04
+ }
+ if a.CollapseRange {
+ mask |= 0x08
+ }
+ if a.ZeroRange {
+ mask |= 0x10
+ }
+ if a.InsertRange {
+ mask |= 0x20
+ }
+ if a.Unshare {
+ mask |= 0x40
+ }
+ b.Write32(mask)
+}
diff --git a/pkg/p9/p9_test.go b/pkg/p9/p9_test.go
new file mode 100644
index 000000000..8dda6cc64
--- /dev/null
+++ b/pkg/p9/p9_test.go
@@ -0,0 +1,188 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "os"
+ "testing"
+)
+
+func TestFileModeHelpers(t *testing.T) {
+ fns := map[FileMode]struct {
+ // name identifies the file mode.
+ name string
+
+ // function is the function that should return true given the
+ // right FileMode.
+ function func(m FileMode) bool
+ }{
+ ModeRegular: {
+ name: "regular",
+ function: FileMode.IsRegular,
+ },
+ ModeDirectory: {
+ name: "directory",
+ function: FileMode.IsDir,
+ },
+ ModeNamedPipe: {
+ name: "named pipe",
+ function: FileMode.IsNamedPipe,
+ },
+ ModeCharacterDevice: {
+ name: "character device",
+ function: FileMode.IsCharacterDevice,
+ },
+ ModeBlockDevice: {
+ name: "block device",
+ function: FileMode.IsBlockDevice,
+ },
+ ModeSymlink: {
+ name: "symlink",
+ function: FileMode.IsSymlink,
+ },
+ ModeSocket: {
+ name: "socket",
+ function: FileMode.IsSocket,
+ },
+ }
+ for mode, info := range fns {
+ // Make sure the mode doesn't identify as anything but itself.
+ for testMode, testfns := range fns {
+ if mode != testMode && testfns.function(mode) {
+ t.Errorf("Mode %s returned true when asked if it was mode %s", info.name, testfns.name)
+ }
+ }
+
+ // Make sure mode identifies as itself.
+ if !info.function(mode) {
+ t.Errorf("Mode %s returned false when asked if it was itself", info.name)
+ }
+ }
+}
+
+func TestFileModeToQID(t *testing.T) {
+ for _, test := range []struct {
+ // name identifies the test.
+ name string
+
+ // mode is the FileMode we start out with.
+ mode FileMode
+
+ // want is the corresponding QIDType we expect.
+ want QIDType
+ }{
+ {
+ name: "Directories are of type directory",
+ mode: ModeDirectory,
+ want: TypeDir,
+ },
+ {
+ name: "Sockets are append-only files",
+ mode: ModeSocket,
+ want: TypeAppendOnly,
+ },
+ {
+ name: "Named pipes are append-only files",
+ mode: ModeNamedPipe,
+ want: TypeAppendOnly,
+ },
+ {
+ name: "Character devices are append-only files",
+ mode: ModeCharacterDevice,
+ want: TypeAppendOnly,
+ },
+ {
+ name: "Symlinks are of type symlink",
+ mode: ModeSymlink,
+ want: TypeSymlink,
+ },
+ {
+ name: "Regular files are of type regular",
+ mode: ModeRegular,
+ want: TypeRegular,
+ },
+ {
+ name: "Block devices are regular files",
+ mode: ModeBlockDevice,
+ want: TypeRegular,
+ },
+ } {
+ if qidType := test.mode.QIDType(); qidType != test.want {
+ t.Errorf("ModeToQID test %s failed: got %o, wanted %o", test.name, qidType, test.want)
+ }
+ }
+}
+
+func TestP9ModeConverters(t *testing.T) {
+ for _, m := range []FileMode{
+ ModeRegular,
+ ModeDirectory,
+ ModeCharacterDevice,
+ ModeBlockDevice,
+ ModeSocket,
+ ModeSymlink,
+ ModeNamedPipe,
+ } {
+ if mb := ModeFromOS(m.OSMode()); mb != m {
+ t.Errorf("Converting %o to OS.FileMode gives %o and is converted back as %o", m, m.OSMode(), mb)
+ }
+ }
+}
+
+func TestOSModeConverters(t *testing.T) {
+ // Modes that can be converted back and forth.
+ for _, m := range []os.FileMode{
+ 0, // Regular file.
+ os.ModeDir,
+ os.ModeCharDevice | os.ModeDevice,
+ os.ModeDevice,
+ os.ModeSocket,
+ os.ModeSymlink,
+ os.ModeNamedPipe,
+ } {
+ if mb := ModeFromOS(m).OSMode(); mb != m {
+ t.Errorf("Converting %o to p9.FileMode gives %o and is converted back as %o", m, ModeFromOS(m), mb)
+ }
+ }
+
+ // Modes that will be converted to a regular file since p9 cannot
+ // express these.
+ for _, m := range []os.FileMode{
+ os.ModeAppend,
+ os.ModeExclusive,
+ os.ModeTemporary,
+ } {
+ if p9Mode := ModeFromOS(m); p9Mode != ModeRegular {
+ t.Errorf("Converting %o to p9.FileMode should have given ModeRegular, but yielded %o", m, p9Mode)
+ }
+ }
+}
+
+func TestAttrMaskContains(t *testing.T) {
+ req := AttrMask{Mode: true, Size: true}
+ have := AttrMask{}
+ if have.Contains(req) {
+ t.Fatalf("AttrMask %v should not be a superset of %v", have, req)
+ }
+ have.Mode = true
+ if have.Contains(req) {
+ t.Fatalf("AttrMask %v should not be a superset of %v", have, req)
+ }
+ have.Size = true
+ have.MTime = true
+ if !have.Contains(req) {
+ t.Fatalf("AttrMask %v should be a superset of %v", have, req)
+ }
+}
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
new file mode 100644
index 000000000..7ca67cb19
--- /dev/null
+++ b/pkg/p9/p9test/BUILD
@@ -0,0 +1,88 @@
+load("//tools:defs.bzl", "go_binary", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+alias(
+ name = "mockgen",
+ actual = "@com_github_golang_mock//mockgen:mockgen",
+)
+
+MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9"
+
+# mockgen_reflect is a source file that contains mock generation code that
+# imports the p9 package and generates a specification via reflection. The
+# usual generation path must be split into two distinct parts because the full
+# source tree is not available to all build targets. Only declared depencies
+# are available (and even then, not the Go source files).
+genrule(
+ name = "mockgen_reflect",
+ testonly = 1,
+ outs = ["mockgen_reflect.go"],
+ cmd = (
+ "$(location :mockgen) " +
+ "-package p9test " +
+ "-prog_only " + MOCK_SRC_PACKAGE + " " +
+ "Attacher,File > $@"
+ ),
+ tools = [":mockgen"],
+)
+
+# mockgen_exec is the binary that includes the above reflection generator.
+# Running this binary will emit an encoded version of the p9 Attacher and File
+# structures. This is consumed by the mocks genrule, below.
+go_binary(
+ name = "mockgen_exec",
+ testonly = 1,
+ srcs = ["mockgen_reflect.go"],
+ deps = [
+ "//pkg/p9",
+ "@com_github_golang_mock//mockgen/model:go_default_library",
+ ],
+)
+
+# mocks consumes the encoded output above, and generates the full source for a
+# set of mocks. These are included directly in the p9test library.
+genrule(
+ name = "mocks",
+ testonly = 1,
+ outs = ["mocks.go"],
+ cmd = (
+ "$(location :mockgen) " +
+ "-package p9test " +
+ "-exec_only $(location :mockgen_exec) " + MOCK_SRC_PACKAGE + " File > $@"
+ ),
+ tools = [
+ ":mockgen",
+ ":mockgen_exec",
+ ],
+)
+
+go_library(
+ name = "p9test",
+ srcs = [
+ "mocks.go",
+ "p9test.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@com_github_golang_mock//gomock:go_default_library",
+ ],
+)
+
+go_test(
+ name = "client_test",
+ size = "medium",
+ srcs = ["client_test.go"],
+ library = ":p9test",
+ deps = [
+ "//pkg/fd",
+ "//pkg/p9",
+ "//pkg/sync",
+ "@com_github_golang_mock//gomock:go_default_library",
+ ],
+)
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
new file mode 100644
index 000000000..6e7bb3db2
--- /dev/null
+++ b/pkg/p9/p9test/client_test.go
@@ -0,0 +1,2242 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "math/rand"
+ "os"
+ "reflect"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/golang/mock/gomock"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestPanic(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root.
+ d := h.NewDirectory(nil)(nil)
+ defer d.Close() // Needed manually.
+ h.Attacher.EXPECT().Attach().Return(d, nil).Do(func() {
+ // Panic here, and ensure that we get back EFAULT.
+ panic("handler")
+ })
+
+ // Attach to the client.
+ if _, err := c.Attach("/"); err != syscall.EFAULT {
+ t.Fatalf("got attach err %v, want EFAULT", err)
+ }
+}
+
+func TestAttachNoLeak(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root.
+ d := h.NewDirectory(nil)(nil)
+ h.Attacher.EXPECT().Attach().Return(d, nil).Times(1)
+
+ // Attach to the client.
+ f, err := c.Attach("/")
+ if err != nil {
+ t.Fatalf("got attach err %v, want nil", err)
+ }
+
+ // Don't close the file. This should be closed automatically when the
+ // client disconnects. The mock asserts that everything is closed
+ // exactly once. This statement just removes the unused variable error.
+ _ = f
+}
+
+func TestBadAttach(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Return an error on attach.
+ h.Attacher.EXPECT().Attach().Return(nil, syscall.EINVAL).Times(1)
+
+ // Attach to the client.
+ if _, err := c.Attach("/"); err != syscall.EINVAL {
+ t.Fatalf("got attach err %v, want syscall.EINVAL", err)
+ }
+}
+
+func TestWalkAttach(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ // Create a new root.
+ d := h.NewDirectory(map[string]Generator{
+ "a": h.NewDirectory(map[string]Generator{
+ "b": h.NewFile(),
+ }),
+ })(nil)
+ h.Attacher.EXPECT().Attach().Return(d, nil).Times(1)
+
+ // Attach to the client as a non-root, and ensure that the walk above
+ // occurs as expected. We should get back b, and all references should
+ // be dropped when the file is closed.
+ f, err := c.Attach("/a/b")
+ if err != nil {
+ t.Fatalf("got attach err %v, want nil", err)
+ }
+ defer f.Close()
+
+ // Check that's a regular file.
+ if _, _, attr, err := f.GetAttr(p9.AttrMaskAll()); err != nil {
+ t.Errorf("got err %v, want nil", err)
+ } else if !attr.Mode.IsRegular() {
+ t.Errorf("got mode %v, want regular file", err)
+ }
+}
+
+// newTypeMap returns a new type map dictionary.
+func newTypeMap(h *Harness) map[string]Generator {
+ return map[string]Generator{
+ "directory": h.NewDirectory(map[string]Generator{}),
+ "file": h.NewFile(),
+ "symlink": h.NewSymlink(),
+ "block-device": h.NewBlockDevice(),
+ "character-device": h.NewCharacterDevice(),
+ "named-pipe": h.NewNamedPipe(),
+ "socket": h.NewSocket(),
+ }
+}
+
+// newRoot returns a new root filesystem.
+//
+// This is set up in a deterministic way for testing most operations.
+//
+// The represented file system looks like:
+// - file
+// - symlink
+// - directory
+// ...
+// + one
+// - file
+// - symlink
+// - directory
+// ...
+// + two
+// - file
+// - symlink
+// - directory
+// ...
+// + three
+// - file
+// - symlink
+// - directory
+// ...
+func newRoot(h *Harness, c *p9.Client) (*Mock, p9.File) {
+ root := newTypeMap(h)
+ one := newTypeMap(h)
+ two := newTypeMap(h)
+ three := newTypeMap(h)
+ one["two"] = h.NewDirectory(two) // Will be nested in one.
+ root["one"] = h.NewDirectory(one) // Top level.
+ root["three"] = h.NewDirectory(three) // Alternate top-level.
+
+ // Create a new root.
+ rootBackend := h.NewDirectory(root)(nil)
+ h.Attacher.EXPECT().Attach().Return(rootBackend, nil)
+
+ // Attach to the client.
+ r, err := c.Attach("/")
+ if err != nil {
+ h.t.Fatalf("got attach err %v, want nil", err)
+ }
+
+ return rootBackend, r
+}
+
+func allInvalidNames(from string) []string {
+ return []string{
+ from + "/other",
+ from + "/..",
+ from + "/.",
+ from + "/",
+ "other/" + from,
+ "/" + from,
+ "./" + from,
+ "../" + from,
+ ".",
+ "..",
+ "/",
+ "",
+ }
+}
+
+func TestWalkInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Run relevant tests.
+ for name := range newTypeMap(h) {
+ // These are all the various ways that one might attempt to
+ // construct compound paths. They should all be rejected, as
+ // any compound that contains a / is not allowed, as well as
+ // the singular paths of '.' and '..'.
+ if _, _, err := root.Walk([]string{".", name}); err != syscall.EINVAL {
+ t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := root.Walk([]string{"..", name}); err != syscall.EINVAL {
+ t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := root.Walk([]string{name, "."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s . wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := root.Walk([]string{name, ".."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s .. wanted EINVAL, got %v", name, err)
+ }
+ for _, invalidName := range allInvalidNames(name) {
+ if _, _, err := root.Walk([]string{invalidName}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s wanted EINVAL, got %v", invalidName, err)
+ }
+ }
+ wantErr := syscall.EINVAL
+ if name == "directory" {
+ // We can attempt a walk through a directory. However,
+ // we should never see a file named "other", so we
+ // expect this to return ENOENT.
+ wantErr = syscall.ENOENT
+ }
+ if _, _, err := root.Walk([]string{name, "other"}); err != wantErr {
+ t.Errorf("Walk through %s/other wanted %v, got %v", name, wantErr, err)
+ }
+
+ // Do a successful walk.
+ _, f, err := root.Walk([]string{name})
+ if err != nil {
+ t.Errorf("Walk to %s wanted nil, got %v", name, err)
+ }
+ defer f.Close()
+ local := h.Pop(f)
+
+ // Check that the file matches.
+ _, localMask, localAttr, localErr := local.GetAttr(p9.AttrMaskAll())
+ if _, mask, attr, err := f.GetAttr(p9.AttrMaskAll()); mask != localMask || attr != localAttr || err != localErr {
+ t.Errorf("GetAttr got (%v, %v, %v), wanted (%v, %v, %v)",
+ mask, attr, err, localMask, localAttr, localErr)
+ }
+
+ // Ensure we can't walk backwards.
+ if _, _, err := f.Walk([]string{"."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s/. wanted EINVAL, got %v", name, err)
+ }
+ if _, _, err := f.Walk([]string{".."}); err != syscall.EINVAL {
+ t.Errorf("Walk through %s/.. wanted EINVAL, got %v", name, err)
+ }
+ }
+}
+
+// fileGenerator is a function to generate files via walk or create.
+//
+// Examples are:
+// - walkHelper
+// - walkAndOpenHelper
+// - createHelper
+type fileGenerator func(*Harness, string, p9.File) (*Mock, *Mock, p9.File)
+
+// walkHelper walks to the given file.
+//
+// The backends of the parent and walked file are returned, as well as the
+// walked client file.
+func walkHelper(h *Harness, name string, dir p9.File) (parentBackend *Mock, walkedBackend *Mock, walked p9.File) {
+ _, parent, err := dir.Walk(nil)
+ if err != nil {
+ h.t.Fatalf("Walk(nil) got err %v, want nil", err)
+ }
+ defer parent.Close()
+ parentBackend = h.Pop(parent)
+
+ _, walked, err = parent.Walk([]string{name})
+ if err != nil {
+ h.t.Fatalf("Walk(%s) got err %v, want nil", name, err)
+ }
+ walkedBackend = h.Pop(walked)
+
+ return parentBackend, walkedBackend, walked
+}
+
+// walkAndOpenHelper additionally opens the walked file, if possible.
+func walkAndOpenHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) {
+ parentBackend, walkedBackend, walked := walkHelper(h, name, dir)
+ if p9.CanOpen(walkedBackend.Attr.Mode) {
+ // Open for all file types that we can. We stick to a read-only
+ // open here because directories may not be opened otherwise.
+ walkedBackend.EXPECT().Open(p9.ReadOnly).Times(1)
+ if _, _, _, err := walked.Open(p9.ReadOnly); err != nil {
+ h.t.Errorf("got open err %v, want nil", err)
+ }
+ } else {
+ // ... or assert an error for others.
+ if _, _, _, err := walked.Open(p9.ReadOnly); err != syscall.EINVAL {
+ h.t.Errorf("got open err %v, want EINVAL", err)
+ }
+ }
+ return parentBackend, walkedBackend, walked
+}
+
+// createHelper creates the given file and returns the parent directory,
+// created file and client file, which must be closed when done.
+func createHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) {
+ // Clone the directory first, since Create replaces the existing file.
+ // We change the type after calling create.
+ _, dirThenFile, err := dir.Walk(nil)
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+
+ // Create a new server-side file. On the server-side, the a new file is
+ // returned from a create call. The client will reuse the same file,
+ // but we still expect the normal chain of closes. This complicates
+ // things a bit because the "parent" will always chain to the cloned
+ // dir above.
+ dirBackend := h.Pop(dirThenFile) // New backend directory.
+ newFile := h.NewFile()(dirBackend) // New file with backend parent.
+ dirBackend.EXPECT().Create(name, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, newFile, newFile.QID, uint32(0), nil)
+
+ // Create via the client.
+ _, dirThenFile, _, _, err = dirThenFile.Create(name, p9.ReadOnly, 0, 0, 0)
+ if err != nil {
+ h.t.Fatalf("got create err %v, want nil", err)
+ }
+
+ // Ensure subsequent walks succeed.
+ dirBackend.AddChild(name, h.NewFile())
+ return dirBackend, newFile, dirThenFile
+}
+
+// deprecatedRemover allows us to access the deprecated Remove operation within
+// the p9.File client object.
+type deprecatedRemover interface {
+ Remove() error
+}
+
+// checkDeleted asserts that relevant methods fail for an unlinked file.
+//
+// This function will close the file at the end.
+func checkDeleted(h *Harness, file p9.File) {
+ defer file.Close() // See doc.
+
+ if _, _, _, err := file.Open(p9.ReadOnly); err != syscall.EINVAL {
+ h.t.Errorf("open while deleted, got %v, want EINVAL", err)
+ }
+ if _, _, _, _, err := file.Create("created", p9.ReadOnly, 0, 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("create while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Symlink("old", "new", 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("symlink while deleted, got %v, want EINVAL", err)
+ }
+ // N.B. This link is technically invalid, but if a call to link is
+ // actually made in the backend then the mock will panic.
+ if err := file.Link(file, "new"); err != syscall.EINVAL {
+ h.t.Errorf("link while deleted, got %v, want EINVAL", err)
+ }
+ if err := file.RenameAt("src", file, "dst"); err != syscall.EINVAL {
+ h.t.Errorf("renameAt while deleted, got %v, want EINVAL", err)
+ }
+ if err := file.UnlinkAt("file", 0); err != syscall.EINVAL {
+ h.t.Errorf("unlinkAt while deleted, got %v, want EINVAL", err)
+ }
+ if err := file.Rename(file, "dst"); err != syscall.EINVAL {
+ h.t.Errorf("rename while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Readlink(); err != syscall.EINVAL {
+ h.t.Errorf("readlink while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Mkdir("dir", p9.ModeDirectory, 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("mkdir while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Mknod("dir", p9.ModeDirectory, 0, 0, 0, 0); err != syscall.EINVAL {
+ h.t.Errorf("mknod while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Readdir(0, 1); err != syscall.EINVAL {
+ h.t.Errorf("readdir while deleted, got %v, want EINVAL", err)
+ }
+ if _, err := file.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL {
+ h.t.Errorf("connect while deleted, got %v, want EINVAL", err)
+ }
+
+ // The remove method is technically deprecated, but we want to ensure
+ // that it still checks for deleted appropriately. We must first clone
+ // the file because remove is equivalent to close.
+ _, newFile, err := file.Walk(nil)
+ if err == syscall.EBUSY {
+ // We can't walk from here because this reference is open
+ // already. Okay, we will also have unopened cases through
+ // TestUnlink, just skip the remove operation for now.
+ return
+ } else if err != nil {
+ h.t.Fatalf("clone failed, got %v, want nil", err)
+ }
+ if err := newFile.(deprecatedRemover).Remove(); err != syscall.EINVAL {
+ h.t.Errorf("remove while deleted, got %v, want EINVAL", err)
+ }
+}
+
+// deleter is a function to remove a file.
+type deleter func(parent p9.File, name string) error
+
+// unlinkAt is a deleter.
+func unlinkAt(parent p9.File, name string) error {
+ // Call unlink. Note that a filesystem may normally impose additional
+ // constaints on unlinkat success, such as ensuring that a directory is
+ // empty, requiring AT_REMOVEDIR in flags to remove a directory, etc.
+ // None of that is required internally (entire trees can be marked
+ // deleted when this operation succeeds), so the mock will succeed.
+ return parent.UnlinkAt(name, 0)
+}
+
+// remove is a deleter.
+func remove(parent p9.File, name string) error {
+ // See notes above re: remove.
+ _, newFile, err := parent.Walk([]string{name})
+ if err != nil {
+ // Should not be expected.
+ return err
+ }
+
+ // Do the actual remove.
+ if err := newFile.(deprecatedRemover).Remove(); err != nil {
+ return err
+ }
+
+ // Ensure that the remove closed the file.
+ if err := newFile.(deprecatedRemover).Remove(); err != syscall.EBADF {
+ return syscall.EBADF // Propagate this code.
+ }
+
+ return nil
+}
+
+// unlinkHelper unlinks the noted path, and ensures that all relevant
+// operations on that path, acquired from multiple paths, start failing.
+func unlinkHelper(h *Harness, root p9.File, targetNames []string, targetGen fileGenerator, deleteFn deleter) {
+ // name is the file to be unlinked.
+ name := targetNames[len(targetNames)-1]
+
+ // Walk to the directory containing the target.
+ _, parent, err := root.Walk(targetNames[:len(targetNames)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer parent.Close()
+ parentBackend := h.Pop(parent)
+
+ // Walk to or generate the target file.
+ _, _, target := targetGen(h, name, parent)
+ defer checkDeleted(h, target)
+
+ // Walk to a second reference.
+ _, second, err := parent.Walk([]string{name})
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer checkDeleted(h, second)
+
+ // Walk to a third reference, from the start.
+ _, third, err := root.Walk(targetNames)
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer checkDeleted(h, third)
+
+ // This will be translated in the backend to an unlinkat.
+ parentBackend.EXPECT().UnlinkAt(name, uint32(0)).Return(nil)
+
+ // Actually perform the deletion.
+ if err := deleteFn(parent, name); err != nil {
+ h.t.Fatalf("got delete err %v, want nil", err)
+ }
+}
+
+func unlinkTest(t *testing.T, targetNames []string, targetGen fileGenerator) {
+ t.Run(fmt.Sprintf("unlinkAt(%s)", strings.Join(targetNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ unlinkHelper(h, root, targetNames, targetGen, unlinkAt)
+ })
+ t.Run(fmt.Sprintf("remove(%s)", strings.Join(targetNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ unlinkHelper(h, root, targetNames, targetGen, remove)
+ })
+}
+
+func TestUnlink(t *testing.T) {
+ // Unlink all files.
+ for name := range newTypeMap(nil) {
+ unlinkTest(t, []string{name}, walkHelper)
+ unlinkTest(t, []string{name}, walkAndOpenHelper)
+ unlinkTest(t, []string{"one", name}, walkHelper)
+ unlinkTest(t, []string{"one", name}, walkAndOpenHelper)
+ unlinkTest(t, []string{"one", "two", name}, walkHelper)
+ unlinkTest(t, []string{"one", "two", name}, walkAndOpenHelper)
+ }
+
+ // Unlink a directory.
+ unlinkTest(t, []string{"one"}, walkHelper)
+ unlinkTest(t, []string{"one"}, walkAndOpenHelper)
+ unlinkTest(t, []string{"one", "two"}, walkHelper)
+ unlinkTest(t, []string{"one", "two"}, walkAndOpenHelper)
+
+ // Unlink created files.
+ unlinkTest(t, []string{"created"}, createHelper)
+ unlinkTest(t, []string{"one", "created"}, createHelper)
+ unlinkTest(t, []string{"one", "two", "created"}, createHelper)
+}
+
+func TestUnlinkAtInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.UnlinkAt(invalidName, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+// expectRenamed asserts an ordered sequence of rename calls, based on all the
+// elements in elements being the source, and the first element therein
+// changing to dstName, parented at dstParent.
+func expectRenamed(file *Mock, elements []string, dstParent *Mock, dstName string) *gomock.Call {
+ if len(elements) > 0 {
+ // Recurse to the parent, if necessary.
+ call := expectRenamed(file.parent, elements[:len(elements)-1], dstParent, dstName)
+
+ // Recursive case: this element is unchanged, but should have
+ // it's hook called after the parent.
+ return file.EXPECT().Renamed(file.parent, elements[len(elements)-1]).Do(func(p p9.File, _ string) {
+ file.parent = p.(*Mock)
+ }).After(call)
+ }
+
+ // Base case: this is the changed element.
+ return file.EXPECT().Renamed(dstParent, dstName).Do(func(p p9.File, name string) {
+ file.parent = p.(*Mock)
+ })
+}
+
+// renamer is a rename function.
+type renamer func(h *Harness, srcParent, dstParent p9.File, origName, newName string, selfRename bool) error
+
+// renameAt is a renamer.
+func renameAt(_ *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error {
+ return srcParent.RenameAt(srcName, dstParent, dstName)
+}
+
+// rename is a renamer.
+func rename(h *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error {
+ _, f, err := srcParent.Walk([]string{srcName})
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ if !selfRename {
+ backend := h.Pop(f)
+ backend.EXPECT().Renamed(gomock.Any(), dstName).Do(func(p p9.File, name string) {
+ backend.parent = p.(*Mock) // Required for close ordering.
+ })
+ }
+ return f.Rename(dstParent, dstName)
+}
+
+// renameHelper executes a rename, and asserts that all relevant elements
+// receive expected notifications. If overwriting a file, this includes
+// ensuring that the target has been appropriately marked as unlinked.
+func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string, target fileGenerator, renameFn renamer) {
+ // Walk to the directory containing the target.
+ srcQID, targetParent, err := root.Walk(srcNames[:len(srcNames)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer targetParent.Close()
+ targetParentBackend := h.Pop(targetParent)
+
+ // Walk to or generate the target file.
+ _, targetBackend, src := target(h, srcNames[len(srcNames)-1], targetParent)
+ defer src.Close()
+
+ // Walk to a second reference.
+ _, second, err := targetParent.Walk([]string{srcNames[len(srcNames)-1]})
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer second.Close()
+ secondBackend := h.Pop(second)
+
+ // Walk to a third reference, from the start.
+ _, third, err := root.Walk(srcNames)
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer third.Close()
+ thirdBackend := h.Pop(third)
+
+ // Find the common suffix to identify the rename parent.
+ var (
+ renameDestPath []string
+ renameSrcPath []string
+ selfRename bool
+ )
+ for i := 1; i <= len(srcNames) && i <= len(dstNames); i++ {
+ if srcNames[len(srcNames)-i] != dstNames[len(dstNames)-i] {
+ // Take the full prefix of dstNames up until this
+ // point, including the first mismatched name. The
+ // first mismatch must be the renamed entry.
+ renameDestPath = dstNames[:len(dstNames)-i+1]
+ renameSrcPath = srcNames[:len(srcNames)-i+1]
+
+ // Does the renameDestPath fully contain the
+ // renameSrcPath here? If yes, then this is a mismatch.
+ // We can't rename the src to some subpath of itself.
+ if len(renameDestPath) > len(renameSrcPath) &&
+ reflect.DeepEqual(renameDestPath[:len(renameSrcPath)], renameSrcPath) {
+ renameDestPath = nil
+ renameSrcPath = nil
+ continue
+ }
+ break
+ }
+ }
+ if len(renameSrcPath) == 0 || len(renameDestPath) == 0 {
+ // This must be a rename to self, or a tricky look-alike. This
+ // happens iff we fail to find a suitable divergence in the two
+ // paths. It's a true self move if the path length is the same.
+ renameDestPath = dstNames
+ renameSrcPath = srcNames
+ selfRename = len(srcNames) == len(dstNames)
+ }
+
+ // Walk to the source parent.
+ _, srcParent, err := root.Walk(renameSrcPath[:len(renameSrcPath)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer srcParent.Close()
+ srcParentBackend := h.Pop(srcParent)
+
+ // Walk to the destination parent.
+ _, dstParent, err := root.Walk(renameDestPath[:len(renameDestPath)-1])
+ if err != nil {
+ h.t.Fatalf("got walk err %v, want nil", err)
+ }
+ defer dstParent.Close()
+ dstParentBackend := h.Pop(dstParent)
+
+ // expectedErr is the result of the rename operation.
+ var expectedErr error
+
+ // Walk to the target file, if one exists.
+ dstQID, dst, err := root.Walk(renameDestPath)
+ if err == nil {
+ if !selfRename && srcQID[0].Type == dstQID[0].Type {
+ // If there is a destination file, and is it of the
+ // same type as the source file, then we expect the
+ // rename to succeed. We expect the destination file to
+ // be deleted, so we run a deletion test on it in this
+ // case.
+ defer checkDeleted(h, dst)
+ } else {
+ if !selfRename {
+ // If the type is different than the
+ // destination, then we expect the rename to
+ // fail. We expect ensure that this is
+ // returned.
+ expectedErr = syscall.EINVAL
+ } else {
+ // This is the file being renamed to itself.
+ // This is technically allowed and a no-op, but
+ // all the triggers will fire.
+ }
+ dst.Close()
+ }
+ }
+ dstName := renameDestPath[len(renameDestPath)-1] // Renamed element.
+ srcName := renameSrcPath[len(renameSrcPath)-1] // Renamed element.
+ if expectedErr == nil && !selfRename {
+ // Expect all to be renamed appropriately. Note that if this is
+ // a final file being renamed, then we expect the file to be
+ // called with the new parent. If not, then we expect the
+ // rename hook to be called, but the parent will remain
+ // unchanged.
+ elements := srcNames[len(renameSrcPath):]
+ expectRenamed(targetBackend, elements, dstParentBackend, dstName)
+ expectRenamed(secondBackend, elements, dstParentBackend, dstName)
+ expectRenamed(thirdBackend, elements, dstParentBackend, dstName)
+
+ // The target parent has also been opened, and may be moved
+ // directly or indirectly.
+ if len(elements) > 1 {
+ expectRenamed(targetParentBackend, elements[:len(elements)-1], dstParentBackend, dstName)
+ }
+ }
+
+ // Expect the rename if it's not the same file. Note that like unlink,
+ // renames are always translated to the at variant in the backend.
+ if !selfRename {
+ srcParentBackend.EXPECT().RenameAt(srcName, dstParentBackend, dstName).Return(expectedErr)
+ }
+
+ // Perform the actual rename; everything has been lined up.
+ if err := renameFn(h, srcParent, dstParent, srcName, dstName, selfRename); err != expectedErr {
+ h.t.Fatalf("got rename err %v, want %v", err, expectedErr)
+ }
+}
+
+func renameTest(t *testing.T, srcNames []string, dstNames []string, target fileGenerator) {
+ t.Run(fmt.Sprintf("renameAt(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ renameHelper(h, root, srcNames, dstNames, target, renameAt)
+ })
+ t.Run(fmt.Sprintf("rename(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ renameHelper(h, root, srcNames, dstNames, target, rename)
+ })
+}
+
+func TestRename(t *testing.T) {
+ // In-directory rename, simple case.
+ for name := range newTypeMap(nil) {
+ // Within the root.
+ renameTest(t, []string{name}, []string{"renamed"}, walkHelper)
+ renameTest(t, []string{name}, []string{"renamed"}, walkAndOpenHelper)
+
+ // Within a subdirectory.
+ renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkAndOpenHelper)
+ }
+
+ // ... with created files.
+ renameTest(t, []string{"created"}, []string{"renamed"}, createHelper)
+ renameTest(t, []string{"one", "created"}, []string{"one", "renamed"}, createHelper)
+
+ // Across directories.
+ for name := range newTypeMap(nil) {
+ // Down one level.
+ renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkAndOpenHelper)
+
+ // Up one level.
+ renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkAndOpenHelper)
+
+ // Across at the same level.
+ renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkAndOpenHelper)
+ }
+
+ // ... with created files.
+ renameTest(t, []string{"one", "created"}, []string{"one", "two", "renamed"}, createHelper)
+ renameTest(t, []string{"one", "two", "created"}, []string{"one", "renamed"}, createHelper)
+ renameTest(t, []string{"one", "created"}, []string{"three", "renamed"}, createHelper)
+
+ // Renaming parents.
+ for name := range newTypeMap(nil) {
+ // Rename a parent.
+ renameTest(t, []string{"one", name}, []string{"renamed", name}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"renamed", name}, walkAndOpenHelper)
+
+ // Rename a super parent.
+ renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkHelper)
+ renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkAndOpenHelper)
+ }
+
+ // ... with created files.
+ renameTest(t, []string{"one", "created"}, []string{"renamed", "created"}, createHelper)
+ renameTest(t, []string{"one", "two", "created"}, []string{"renamed", "created"}, createHelper)
+
+ // Over existing files, including itself.
+ for name := range newTypeMap(nil) {
+ for other := range newTypeMap(nil) {
+ // Overwrite the noted file (may be itself).
+ renameTest(t, []string{"one", name}, []string{"one", other}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", other}, walkAndOpenHelper)
+
+ // Overwrite other files in another directory.
+ renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkHelper)
+ renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkAndOpenHelper)
+ }
+
+ // Overwrite by moving the parent.
+ renameTest(t, []string{"three", name}, []string{"one", name}, walkHelper)
+ renameTest(t, []string{"three", name}, []string{"one", name}, walkAndOpenHelper)
+
+ // Create over the types.
+ renameTest(t, []string{"one", "created"}, []string{"one", name}, createHelper)
+ renameTest(t, []string{"one", "created"}, []string{"one", "two", name}, createHelper)
+ renameTest(t, []string{"three", "created"}, []string{"one", name}, createHelper)
+ }
+}
+
+func TestRenameInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.Rename(root, invalidName); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestRenameAtInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.RenameAt(invalidName, root, "okay"); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ if err := root.RenameAt("okay", root, invalidName); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+// TestRenameSecondOrder tests that indirect rename targets continue to receive
+// Renamed calls after a rename of its renamed parent. i.e.,
+//
+// 1. Create /one/file
+// 2. Create /directory
+// 3. Rename /one -> /directory/one
+// 4. Rename /directory -> /three/foo
+// 5. file from (1) should still receive Renamed.
+//
+// This is a regression test for b/135219260.
+func TestRenameSecondOrder(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ rootBackend, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to /one.
+ _, oneBackend, oneFile := walkHelper(h, "one", root)
+ defer oneFile.Close()
+
+ // Walk to and generate /one/file.
+ //
+ // walkHelper re-walks to oneFile, so we need the second backend,
+ // which will also receive Renamed calls.
+ oneSecondBackend, fileBackend, fileFile := walkHelper(h, "file", oneFile)
+ defer fileFile.Close()
+
+ // Walk to and generate /directory.
+ _, directoryBackend, directoryFile := walkHelper(h, "directory", root)
+ defer directoryFile.Close()
+
+ // Rename /one to /directory/one.
+ rootBackend.EXPECT().RenameAt("one", directoryBackend, "one").Return(nil)
+ expectRenamed(oneBackend, []string{}, directoryBackend, "one")
+ expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one")
+ expectRenamed(fileBackend, []string{}, oneBackend, "file")
+ if err := renameAt(h, root, directoryFile, "one", "one", false); err != nil {
+ h.t.Fatalf("got rename err %v, want nil", err)
+ }
+
+ // Walk to /three.
+ _, threeBackend, threeFile := walkHelper(h, "three", root)
+ defer threeFile.Close()
+
+ // Rename /directory to /three/foo.
+ rootBackend.EXPECT().RenameAt("directory", threeBackend, "foo").Return(nil)
+ expectRenamed(directoryBackend, []string{}, threeBackend, "foo")
+ expectRenamed(oneBackend, []string{}, directoryBackend, "one")
+ expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one")
+ expectRenamed(fileBackend, []string{}, oneBackend, "file")
+ if err := renameAt(h, root, threeFile, "directory", "foo", false); err != nil {
+ h.t.Fatalf("got rename err %v, want nil", err)
+ }
+}
+
+func TestReadlink(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, f, err := root.Walk([]string{name})
+ if err != nil {
+ t.Fatalf("walk failed: got %v, wanted nil", err)
+ }
+ defer f.Close()
+ backend := h.Pop(f)
+
+ const symlinkTarget = "symlink-target"
+
+ if backend.Attr.Mode.IsSymlink() {
+ // This should only go through on symlinks.
+ backend.EXPECT().Readlink().Return(symlinkTarget, nil)
+ }
+
+ // Attempt a Readlink operation.
+ target, err := f.Readlink()
+ if err != nil && err != syscall.EINVAL {
+ t.Errorf("readlink got %v, wanted EINVAL", err)
+ } else if err == nil && target != symlinkTarget {
+ t.Errorf("readlink got %v, wanted %v", target, symlinkTarget)
+ }
+ })
+ }
+}
+
+// fdTest is a wrapper around operations that may send file descriptors. This
+// asserts that the file descriptors are working as intended.
+func fdTest(t *testing.T, sendFn func(*fd.FD) *fd.FD) {
+ // Create a pipe that we can read from.
+ r, w, err := os.Pipe()
+ if err != nil {
+ t.Fatalf("unable to create pipe: %v", err)
+ }
+ defer r.Close()
+ defer w.Close()
+
+ // Attempt to send the write end.
+ wFD, err := fd.NewFromFile(w)
+ if err != nil {
+ t.Fatalf("unable to convert file: %v", err)
+ }
+ defer wFD.Close() // This is a copy.
+
+ // Send wFD and receive newFD.
+ newFD := sendFn(wFD)
+ defer newFD.Close()
+
+ // Attempt to write.
+ const message = "hello"
+ if _, err := newFD.Write([]byte(message)); err != nil {
+ t.Fatalf("write got %v, wanted nil", err)
+ }
+
+ // Should see the message on our end.
+ buffer := []byte(message)
+ if _, err := io.ReadFull(r, buffer); err != nil {
+ t.Fatalf("read got %v, wanted nil", err)
+ }
+ if string(buffer) != message {
+ t.Errorf("got message %v, wanted %v", string(buffer), message)
+ }
+}
+
+func TestConnect(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Catch all the non-socket cases.
+ if !backend.Attr.Mode.IsSocket() {
+ // This has been set up to fail if Connect is called.
+ if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL {
+ t.Errorf("connect got %v, wanted EINVAL", err)
+ }
+ return
+ }
+
+ // Ensure the fd exchange works.
+ fdTest(t, func(send *fd.FD) *fd.FD {
+ backend.EXPECT().Connect(p9.ConnectFlags(0)).Return(send, nil)
+ recv, err := backend.Connect(p9.ConnectFlags(0))
+ if err != nil {
+ t.Fatalf("connect got %v, wanted nil", err)
+ }
+ return recv
+ })
+ })
+ }
+}
+
+func TestReaddir(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Catch all the non-directory cases.
+ if !backend.Attr.Mode.IsDir() {
+ // This has also been set up to fail if Readdir is called.
+ if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
+ t.Errorf("readdir got %v, wanted EINVAL", err)
+ }
+ return
+ }
+
+ // Ensure that readdir works for directories.
+ if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
+ t.Errorf("readdir got %v, wanted EINVAL", err)
+ }
+ if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR {
+ t.Errorf("readdir got %v, wanted EISDIR", err)
+ }
+ if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR {
+ t.Errorf("readdir got %v, wanted EISDIR", err)
+ }
+ backend.EXPECT().Open(p9.ReadOnly).Times(1)
+ if _, _, _, err := f.Open(p9.ReadOnly); err != nil {
+ t.Errorf("readdir got %v, wanted nil", err)
+ }
+ backend.EXPECT().Readdir(uint64(0), uint32(1)).Times(1)
+ if _, err := f.Readdir(0, 1); err != nil {
+ t.Errorf("readdir got %v, wanted nil", err)
+ }
+ })
+ }
+}
+
+func TestOpen(t *testing.T) {
+ type openTest struct {
+ name string
+ flags p9.OpenFlags
+ err error
+ match func(p9.FileMode) bool
+ }
+
+ cases := []openTest{
+ {
+ name: "not-openable-read-only",
+ flags: p9.ReadOnly,
+ err: syscall.EINVAL,
+ match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
+ },
+ {
+ name: "not-openable-write-only",
+ flags: p9.WriteOnly,
+ err: syscall.EINVAL,
+ match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
+ },
+ {
+ name: "not-openable-read-write",
+ flags: p9.ReadWrite,
+ err: syscall.EINVAL,
+ match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
+ },
+ {
+ name: "directory-read-only",
+ flags: p9.ReadOnly,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "directory-read-write",
+ flags: p9.ReadWrite,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "directory-write-only",
+ flags: p9.WriteOnly,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "read-only",
+ flags: p9.ReadOnly,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
+ },
+ {
+ name: "write-only",
+ flags: p9.WriteOnly,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "read-write",
+ flags: p9.ReadWrite,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "directory-read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "write-only-truncate",
+ flags: p9.WriteOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "read-write-truncate",
+ flags: p9.ReadWrite | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ }
+
+ // Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
+ // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice
+ // - returning a file works as expected
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Does this match the case?
+ if !tc.match(backend.Attr.Mode) {
+ t.SkipNow()
+ }
+
+ // Ensure open-required operations fail.
+ if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EINVAL {
+ t.Errorf("readAt got %v, wanted EINVAL", err)
+ }
+ if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EINVAL {
+ t.Errorf("writeAt got %v, wanted EINVAL", err)
+ }
+ if err := f.FSync(); err != syscall.EINVAL {
+ t.Errorf("fsync got %v, wanted EINVAL", err)
+ }
+ if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
+ t.Errorf("readdir got %v, wanted EINVAL", err)
+ }
+
+ // Attempt the given open.
+ if tc.err != nil {
+ // We expect an error, just test and return.
+ if _, _, _, err := f.Open(tc.flags); err != tc.err {
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
+ }
+ return
+ }
+
+ // Run an FD test, since we expect success.
+ fdTest(t, func(send *fd.FD) *fd.FD {
+ backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1)
+ recv, _, _, err := f.Open(tc.flags)
+ if err != tc.err {
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
+ }
+ return recv
+ })
+
+ // If the open was successful, attempt another one.
+ if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL {
+ t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err)
+ }
+
+ // Ensure that all illegal operations fail.
+ if _, _, err := f.Walk(nil); err != syscall.EINVAL && err != syscall.EBUSY {
+ t.Errorf("walk got %v, wanted EINVAL or EBUSY", err)
+ }
+ if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EINVAL && err != syscall.EBUSY {
+ t.Errorf("walkgetattr got %v, wanted EINVAL or EBUSY", err)
+ }
+ })
+ }
+ }
+}
+
+func TestClose(t *testing.T) {
+ type closeTest struct {
+ name string
+ closeFn func(backend *Mock, f p9.File)
+ }
+
+ cases := []closeTest{
+ {
+ name: "close",
+ closeFn: func(_ *Mock, f p9.File) {
+ f.Close()
+ },
+ },
+ {
+ name: "remove",
+ closeFn: func(backend *Mock, f p9.File) {
+ // Allow the rename call in the parent, automatically translated.
+ backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1)
+ f.(deprecatedRemover).Remove()
+ },
+ },
+ }
+
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s(%s)", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+
+ // Close via the prescribed method.
+ tc.closeFn(backend, f)
+
+ // Everything should fail with EBADF.
+ if _, _, err := f.Walk(nil); err != syscall.EBADF {
+ t.Errorf("walk got %v, wanted EBADF", err)
+ }
+ if _, err := f.StatFS(); err != syscall.EBADF {
+ t.Errorf("statfs got %v, wanted EBADF", err)
+ }
+ if _, _, _, err := f.GetAttr(p9.AttrMaskAll()); err != syscall.EBADF {
+ t.Errorf("getattr got %v, wanted EBADF", err)
+ }
+ if err := f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}); err != syscall.EBADF {
+ t.Errorf("setattrk got %v, wanted EBADF", err)
+ }
+ if err := f.Rename(root, "new-name"); err != syscall.EBADF {
+ t.Errorf("rename got %v, wanted EBADF", err)
+ }
+ if err := f.Close(); err != syscall.EBADF {
+ t.Errorf("close got %v, wanted EBADF", err)
+ }
+ if _, _, _, err := f.Open(p9.ReadOnly); err != syscall.EBADF {
+ t.Errorf("open got %v, wanted EBADF", err)
+ }
+ if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EBADF {
+ t.Errorf("readAt got %v, wanted EBADF", err)
+ }
+ if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EBADF {
+ t.Errorf("writeAt got %v, wanted EBADF", err)
+ }
+ if err := f.FSync(); err != syscall.EBADF {
+ t.Errorf("fsync got %v, wanted EBADF", err)
+ }
+ if _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 0, 0); err != syscall.EBADF {
+ t.Errorf("create got %v, wanted EBADF", err)
+ }
+ if _, err := f.Mkdir("new-directory", 0, 0, 0); err != syscall.EBADF {
+ t.Errorf("mkdir got %v, wanted EBADF", err)
+ }
+ if _, err := f.Symlink("old-name", "new-name", 0, 0); err != syscall.EBADF {
+ t.Errorf("symlink got %v, wanted EBADF", err)
+ }
+ if err := f.Link(root, "new-name"); err != syscall.EBADF {
+ t.Errorf("link got %v, wanted EBADF", err)
+ }
+ if _, err := f.Mknod("new-block-device", 0, 0, 0, 0, 0); err != syscall.EBADF {
+ t.Errorf("mknod got %v, wanted EBADF", err)
+ }
+ if err := f.RenameAt("old-name", root, "new-name"); err != syscall.EBADF {
+ t.Errorf("renameAt got %v, wanted EBADF", err)
+ }
+ if err := f.UnlinkAt("name", 0); err != syscall.EBADF {
+ t.Errorf("unlinkAt got %v, wanted EBADF", err)
+ }
+ if _, err := f.Readdir(0, 1); err != syscall.EBADF {
+ t.Errorf("readdir got %v, wanted EBADF", err)
+ }
+ if _, err := f.Readlink(); err != syscall.EBADF {
+ t.Errorf("readlink got %v, wanted EBADF", err)
+ }
+ if err := f.Flush(); err != syscall.EBADF {
+ t.Errorf("flush got %v, wanted EBADF", err)
+ }
+ if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EBADF {
+ t.Errorf("walkgetattr got %v, wanted EBADF", err)
+ }
+ if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EBADF {
+ t.Errorf("connect got %v, wanted EBADF", err)
+ }
+ })
+ }
+ }
+}
+
+// onlyWorksOnOpenThings is a helper test method for operations that should
+// only work on files that have been explicitly opened.
+func onlyWorksOnOpenThings(h *Harness, t *testing.T, name string, root p9.File, mode p9.OpenFlags, expectedErr error, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) {
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Does it work before opening?
+ if err := fn(backend, f, false); err != syscall.EINVAL {
+ t.Errorf("operation got %v, wanted EINVAL", err)
+ }
+
+ // Is this openable?
+ if !p9.CanOpen(backend.Attr.Mode) {
+ return // Nothing to do.
+ }
+
+ // If this is a directory, we can't handle writing.
+ if backend.Attr.Mode.IsDir() && (mode == p9.ReadWrite || mode == p9.WriteOnly) {
+ return // Skip.
+ }
+
+ // Open the file.
+ backend.EXPECT().Open(mode)
+ if _, _, _, err := f.Open(mode); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+
+ // Attempt the operation.
+ if err := fn(backend, f, expectedErr == nil); err != expectedErr {
+ t.Fatalf("operation got %v, wanted %v", err, expectedErr)
+ }
+}
+
+func TestRead(t *testing.T) {
+ type readTest struct {
+ name string
+ mode p9.OpenFlags
+ err error
+ }
+
+ cases := []readTest{
+ {
+ name: "read-only",
+ mode: p9.ReadOnly,
+ err: nil,
+ },
+ {
+ name: "read-write",
+ mode: p9.ReadWrite,
+ err: nil,
+ },
+ {
+ name: "write-only",
+ mode: p9.WriteOnly,
+ err: syscall.EPERM,
+ },
+ }
+
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const message = "hello"
+
+ onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if !shouldSucceed {
+ _, err := f.ReadAt([]byte(message), 0)
+ return err
+ }
+
+ // Prepare for the call to readAt in the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ copy(p, message)
+ }).Return(len(message), nil)
+
+ // Make the client call.
+ p := make([]byte, 2*len(message)) // Double size.
+ n, err := f.ReadAt(p, 0)
+
+ // Sanity check result.
+ if err != nil {
+ return err
+ }
+ if n != len(message) {
+ t.Fatalf("message length incorrect, got %d, want %d", n, len(message))
+ }
+ if !bytes.Equal(p[:n], []byte(message)) {
+ t.Fatalf("message incorrect, got %v, want %v", p, []byte(message))
+ }
+ return nil // Success.
+ })
+ })
+ }
+ }
+}
+
+func TestWrite(t *testing.T) {
+ type writeTest struct {
+ name string
+ mode p9.OpenFlags
+ err error
+ }
+
+ cases := []writeTest{
+ {
+ name: "read-only",
+ mode: p9.ReadOnly,
+ err: syscall.EPERM,
+ },
+ {
+ name: "read-write",
+ mode: p9.ReadWrite,
+ err: nil,
+ },
+ {
+ name: "write-only",
+ mode: p9.WriteOnly,
+ err: nil,
+ },
+ }
+
+ for name := range newTypeMap(nil) {
+ for _, tc := range cases {
+ t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const message = "hello"
+
+ onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if !shouldSucceed {
+ _, err := f.WriteAt([]byte(message), 0)
+ return err
+ }
+
+ // Prepare for the call to readAt in the backend.
+ var output []byte // Saved by Do below.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ output = p
+ }).Return(len(message), nil)
+
+ // Make the client call.
+ n, err := f.WriteAt([]byte(message), 0)
+
+ // Sanity check result.
+ if err != nil {
+ return err
+ }
+ if n != len(message) {
+ t.Fatalf("message length incorrect, got %d, want %d", n, len(message))
+ }
+ if !bytes.Equal(output, []byte(message)) {
+ t.Fatalf("message incorrect, got %v, want %v", output, []byte(message))
+ }
+ return nil // Success.
+ })
+ })
+ }
+ }
+}
+
+func TestFSync(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ for _, mode := range []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} {
+ t.Run(fmt.Sprintf("%s-%s", mode, name), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnOpenThings(h, t, name, root, mode, nil, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().FSync().Times(1)
+ }
+ return f.FSync()
+ })
+ })
+ }
+ }
+}
+
+func TestFlush(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ backend.EXPECT().Flush()
+ f.Flush()
+ })
+ }
+}
+
+// onlyWorksOnDirectories is a helper test method for operations that should
+// only work on unopened directories, such as create, mkdir and symlink.
+func onlyWorksOnDirectories(h *Harness, t *testing.T, name string, root p9.File, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) {
+ // Walk to the file normally.
+ _, backend, f := walkHelper(h, name, root)
+ defer f.Close()
+
+ // Only directories support mknod.
+ if !backend.Attr.Mode.IsDir() {
+ if err := fn(backend, f, false); err != syscall.EINVAL {
+ t.Errorf("operation got %v, wanted EINVAL", err)
+ }
+ return // Nothing else to do.
+ }
+
+ // Should succeed.
+ if err := fn(backend, f, true); err != nil {
+ t.Fatalf("operation got %v, wanted nil", err)
+ }
+
+ // Open the directory.
+ backend.EXPECT().Open(p9.ReadOnly).Times(1)
+ if _, _, _, err := f.Open(p9.ReadOnly); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+
+ // Should not work again.
+ if err := fn(backend, f, false); err != syscall.EINVAL {
+ t.Fatalf("operation got %v, wanted EINVAL", err)
+ }
+}
+
+func TestCreate(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if !shouldSucceed {
+ _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 1, 2)
+ return err
+ }
+
+ // If the create is going to succeed, then we
+ // need to create a new backend file, and we
+ // clone to ensure that we don't close the
+ // original.
+ _, newF, err := f.Walk(nil)
+ if err != nil {
+ t.Fatalf("clone got %v, wanted nil", err)
+ }
+ defer newF.Close()
+ newBackend := h.Pop(newF)
+
+ // Run a regular FD test to validate that path.
+ fdTest(t, func(send *fd.FD) *fd.FD {
+ // Return the send FD on success.
+ newFile := h.NewFile()(backend) // New file with the parent backend.
+ newBackend.EXPECT().Create("new-file", p9.ReadWrite, p9.FileMode(0), p9.UID(1), p9.GID(2)).Return(send, newFile, p9.QID{}, uint32(0), nil)
+
+ // Receive the fd back.
+ recv, _, _, _, err := newF.Create("new-file", p9.ReadWrite, 0, 1, 2)
+ if err != nil {
+ t.Fatalf("create got %v, wanted nil", err)
+ }
+ return recv
+ })
+
+ // The above will fail via normal test flow, so
+ // we can assume that it passed.
+ return nil
+ })
+ })
+ }
+}
+
+func TestCreateInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if _, _, _, _, err := root.Create(invalidName, p9.ReadWrite, 0, 0, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestMkdir(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Mkdir("new-directory", p9.FileMode(0), p9.UID(1), p9.GID(2))
+ }
+ _, err := f.Mkdir("new-directory", 0, 1, 2)
+ return err
+ })
+ })
+ }
+}
+
+func TestMkdirInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if _, err := root.Mkdir(invalidName, 0, 0, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestSymlink(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Symlink("old-name", "new-name", p9.UID(1), p9.GID(2))
+ }
+ _, err := f.Symlink("old-name", "new-name", 1, 2)
+ return err
+ })
+ })
+ }
+}
+
+func TestSyminkInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ // We need only test for invalid names in the new name,
+ // the target can be an arbitrary string and we don't
+ // need to sanity check it.
+ if _, err := root.Symlink("old-name", invalidName, 0, 0); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestLink(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Link(gomock.Any(), "new-link")
+ }
+ return f.Link(f, "new-link")
+ })
+ })
+ }
+}
+
+func TestLinkInvalid(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ for name := range newTypeMap(nil) {
+ for _, invalidName := range allInvalidNames(name) {
+ if err := root.Link(root, invalidName); err != syscall.EINVAL {
+ t.Errorf("got %v for name %q, want EINVAL", err, invalidName)
+ }
+ }
+ }
+}
+
+func TestMknod(t *testing.T) {
+ for name := range newTypeMap(nil) {
+ t.Run(name, func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error {
+ if shouldSucceed {
+ backend.EXPECT().Mknod("new-block-device", p9.FileMode(0), uint32(1), uint32(2), p9.UID(3), p9.GID(4)).Times(1)
+ }
+ _, err := f.Mknod("new-block-device", 0, 1, 2, 3, 4)
+ return err
+ })
+ })
+ }
+}
+
+// concurrentFn is a specification of a concurrent operation. This is used to
+// drive the concurrency tests below.
+type concurrentFn struct {
+ name string
+ match func(p9.FileMode) bool
+ op func(h *Harness, backend *Mock, f p9.File, callback func())
+}
+
+func concurrentTest(t *testing.T, name string, fn1, fn2 concurrentFn, sameDir, expectedOkay bool) {
+ var (
+ names1 []string
+ names2 []string
+ )
+ if sameDir {
+ // Use the same file one directory up.
+ names1, names2 = []string{"one", name}, []string{"one", name}
+ } else {
+ // For different directories, just use siblings.
+ names1, names2 = []string{"one", name}, []string{"three", name}
+ }
+
+ t.Run(fmt.Sprintf("%s(%v)+%s(%v)", fn1.name, names1, fn2.name, names2), func(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ // Walk to both files as given.
+ _, f1, err := root.Walk(names1)
+ if err != nil {
+ t.Fatalf("error walking, got %v, want nil", err)
+ }
+ defer f1.Close()
+ b1 := h.Pop(f1)
+ _, f2, err := root.Walk(names2)
+ if err != nil {
+ t.Fatalf("error walking, got %v, want nil", err)
+ }
+ defer f2.Close()
+ b2 := h.Pop(f2)
+
+ // Are these a good match for the current test case?
+ if !fn1.match(b1.Attr.Mode) {
+ t.SkipNow()
+ }
+ if !fn2.match(b2.Attr.Mode) {
+ t.SkipNow()
+ }
+
+ // Construct our "concurrency creator".
+ in1 := make(chan struct{}, 1)
+ in2 := make(chan struct{}, 1)
+ var top sync.WaitGroup
+ var fns sync.WaitGroup
+ defer top.Wait()
+ top.Add(2) // Accounting for below.
+ defer fns.Done()
+ fns.Add(1) // See line above; released before top.Wait.
+ go func() {
+ defer top.Done()
+ fn1.op(h, b1, f1, func() {
+ in1 <- struct{}{}
+ fns.Wait()
+ })
+ }()
+ go func() {
+ defer top.Done()
+ fn2.op(h, b2, f2, func() {
+ in2 <- struct{}{}
+ fns.Wait()
+ })
+ }()
+
+ // Compute a reasonable timeout. If we expect the operation to hang,
+ // give it 10 milliseconds before we assert that it's fine. After all,
+ // there will be a lot of these tests. If we don't expect it to hang,
+ // give it a full minute, since the machine could be slow.
+ timeout := 10 * time.Millisecond
+ if expectedOkay {
+ timeout = 1 * time.Minute
+ }
+
+ // Read the first channel.
+ var second chan struct{}
+ select {
+ case <-in1:
+ second = in2
+ case <-in2:
+ second = in1
+ }
+
+ // Catch concurrency.
+ select {
+ case <-second:
+ // We finished successful. Is this good? Depends on the
+ // expected result.
+ if !expectedOkay {
+ t.Errorf("%q and %q proceeded concurrently!", fn1.name, fn2.name)
+ }
+ case <-time.After(timeout):
+ // Great, things did not proceed concurrently. Is that what we
+ // expected?
+ if expectedOkay {
+ t.Errorf("%q and %q hung concurrently!", fn1.name, fn2.name)
+ }
+ }
+ })
+}
+
+func randomFileName() string {
+ return fmt.Sprintf("%x", rand.Int63())
+}
+
+func TestConcurrency(t *testing.T) {
+ readExclusive := []concurrentFn{
+ {
+ // N.B. We can't explicitly check WalkGetAttr behavior,
+ // but we rely on the fact that the internal code paths
+ // are the same.
+ name: "walk",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // See the documentation of WalkCallback.
+ // Because walk is actually implemented by the
+ // mock, we need a special place for this
+ // callback.
+ //
+ // Note that a clone actually locks the parent
+ // node. So we walk from this node to test
+ // concurrent operations appropriately.
+ backend.WalkCallback = func() error {
+ callback()
+ return nil
+ }
+ f.Walk([]string{randomFileName()}) // Won't exist.
+ },
+ },
+ {
+ name: "fsync",
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Open(gomock.Any())
+ backend.EXPECT().FSync().Do(func() {
+ callback()
+ })
+ f.Open(p9.ReadOnly) // Required.
+ f.FSync()
+ },
+ },
+ {
+ name: "readdir",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Open(gomock.Any())
+ backend.EXPECT().Readdir(gomock.Any(), gomock.Any()).Do(func(uint64, uint32) {
+ callback()
+ })
+ f.Open(p9.ReadOnly) // Required.
+ f.Readdir(0, 1)
+ },
+ },
+ {
+ name: "readlink",
+ match: func(mode p9.FileMode) bool { return mode.IsSymlink() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Readlink().Do(func() {
+ callback()
+ })
+ f.Readlink()
+ },
+ },
+ {
+ name: "connect",
+ match: func(mode p9.FileMode) bool { return mode.IsSocket() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Connect(gomock.Any()).Do(func(p9.ConnectFlags) {
+ callback()
+ })
+ f.Connect(0)
+ },
+ },
+ {
+ name: "open",
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Open(gomock.Any()).Do(func(p9.OpenFlags) {
+ callback()
+ })
+ f.Open(p9.ReadOnly)
+ },
+ },
+ {
+ name: "flush",
+ match: func(mode p9.FileMode) bool { return true },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Flush().Do(func() {
+ callback()
+ })
+ f.Flush()
+ },
+ },
+ }
+ writeExclusive := []concurrentFn{
+ {
+ // N.B. We can't really check getattr. But this is an
+ // extremely low-risk function, it seems likely that
+ // this check is paranoid anyways.
+ name: "setattr",
+ match: func(mode p9.FileMode) bool { return true },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().SetAttr(gomock.Any(), gomock.Any()).Do(func(p9.SetAttrMask, p9.SetAttr) {
+ callback()
+ })
+ f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{})
+ },
+ },
+ {
+ name: "unlinkAt",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) {
+ callback()
+ })
+ f.UnlinkAt(randomFileName(), 0)
+ },
+ },
+ {
+ name: "mknod",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Mknod(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, uint32, uint32, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Mknod(randomFileName(), 0, 0, 0, 0, 0)
+ },
+ },
+ {
+ name: "link",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Link(gomock.Any(), gomock.Any()).Do(func(p9.File, string) {
+ callback()
+ })
+ f.Link(f, randomFileName())
+ },
+ },
+ {
+ name: "symlink",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Symlink(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, string, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Symlink(randomFileName(), randomFileName(), 0, 0)
+ },
+ },
+ {
+ name: "mkdir",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().Mkdir(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Mkdir(randomFileName(), 0, 0, 0)
+ },
+ },
+ {
+ name: "create",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // Return an error for the creation operation, as this is the simplest.
+ backend.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, p9.QID{}, uint32(0), syscall.EINVAL).Do(func(string, p9.OpenFlags, p9.FileMode, p9.UID, p9.GID) {
+ callback()
+ })
+ f.Create(randomFileName(), p9.ReadOnly, 0, 0, 0)
+ },
+ },
+ }
+ globalExclusive := []concurrentFn{
+ {
+ name: "remove",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // Remove operates on a locked parent. So we
+ // add a child, walk to it and call remove.
+ // Note that because this operation can operate
+ // concurrently with itself, we need to
+ // generate a random file name.
+ randomFile := randomFileName()
+ backend.AddChild(randomFile, h.NewFile())
+ defer backend.RemoveChild(randomFile)
+ _, file, err := f.Walk([]string{randomFile})
+ if err != nil {
+ h.t.Fatalf("walk got %v, want nil", err)
+ }
+
+ // Remove is automatically translated to the parent.
+ backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) {
+ callback()
+ })
+
+ // Remove is also a close.
+ file.(deprecatedRemover).Remove()
+ },
+ },
+ {
+ name: "rename",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ // Similarly to remove, because we need to
+ // operate on a child, we allow a walk.
+ randomFile := randomFileName()
+ backend.AddChild(randomFile, h.NewFile())
+ defer backend.RemoveChild(randomFile)
+ _, file, err := f.Walk([]string{randomFile})
+ if err != nil {
+ h.t.Fatalf("walk got %v, want nil", err)
+ }
+ defer file.Close()
+ fileBackend := h.Pop(file)
+
+ // Rename is automatically translated to the parent.
+ backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) {
+ callback()
+ })
+
+ // Attempt the rename.
+ fileBackend.EXPECT().Renamed(gomock.Any(), gomock.Any())
+ file.Rename(f, randomFileName())
+ },
+ },
+ {
+ name: "renameAt",
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ op: func(h *Harness, backend *Mock, f p9.File, callback func()) {
+ backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) {
+ callback()
+ })
+
+ // Attempt the rename. There are no active fids
+ // with this name, so we don't need to expect
+ // Renamed hooks on anything.
+ f.RenameAt(randomFileName(), f, randomFileName())
+ },
+ },
+ }
+
+ for _, fn1 := range readExclusive {
+ for _, fn2 := range readExclusive {
+ for name := range newTypeMap(nil) {
+ // Everything should be able to proceed in parallel.
+ concurrentTest(t, name, fn1, fn2, true, true)
+ concurrentTest(t, name, fn1, fn2, false, true)
+ }
+ }
+ }
+
+ for _, fn1 := range append(readExclusive, writeExclusive...) {
+ for _, fn2 := range writeExclusive {
+ for name := range newTypeMap(nil) {
+ // Only cross-directory functions should proceed in parallel.
+ concurrentTest(t, name, fn1, fn2, true, false)
+ concurrentTest(t, name, fn1, fn2, false, true)
+ }
+ }
+ }
+
+ for _, fn1 := range append(append(readExclusive, writeExclusive...), globalExclusive...) {
+ for _, fn2 := range globalExclusive {
+ for name := range newTypeMap(nil) {
+ // Nothing should be able to run in parallel.
+ concurrentTest(t, name, fn1, fn2, true, false)
+ concurrentTest(t, name, fn1, fn2, false, false)
+ }
+ }
+ }
+}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
new file mode 100644
index 000000000..dd8b01b6d
--- /dev/null
+++ b/pkg/p9/p9test/p9test.go
@@ -0,0 +1,329 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package p9test provides standard mocks for p9.
+package p9test
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Harness is an attacher mock.
+type Harness struct {
+ t *testing.T
+ mockCtrl *gomock.Controller
+ Attacher *MockAttacher
+ wg sync.WaitGroup
+ clientSocket *unet.Socket
+ mu sync.Mutex
+ created []*Mock
+}
+
+// globalPath is a QID.Path Generator.
+var globalPath uint64
+
+// MakePath returns a globally unique path.
+func MakePath() uint64 {
+ return atomic.AddUint64(&globalPath, 1)
+}
+
+// Generator is a function that generates a new file.
+type Generator func(parent *Mock) *Mock
+
+// Mock is a common mock element.
+type Mock struct {
+ p9.DefaultWalkGetAttr
+ *MockFile
+ parent *Mock
+ closed bool
+ harness *Harness
+ QID p9.QID
+ Attr p9.Attr
+ children map[string]Generator
+
+ // WalkCallback is a special function that will be called from within
+ // the walk context. This is needed for the concurrent tests within
+ // this package.
+ WalkCallback func() error
+}
+
+// globalMu protects the children maps in all mocks. Note that this is not a
+// particularly elegant solution, but because the test has walks from the root
+// through to final nodes, we must share maps below, and it's easiest to simply
+// protect against concurrent access globally.
+var globalMu sync.RWMutex
+
+// AddChild adds a new child to the Mock.
+func (m *Mock) AddChild(name string, generator Generator) {
+ globalMu.Lock()
+ defer globalMu.Unlock()
+ m.children[name] = generator
+}
+
+// RemoveChild removes the child with the given name.
+func (m *Mock) RemoveChild(name string) {
+ globalMu.Lock()
+ defer globalMu.Unlock()
+ delete(m.children, name)
+}
+
+// Matches implements gomock.Matcher.Matches.
+func (m *Mock) Matches(x interface{}) bool {
+ if om, ok := x.(*Mock); ok {
+ return m.QID.Path == om.QID.Path
+ }
+ return false
+}
+
+// String implements gomock.Matcher.String.
+func (m *Mock) String() string {
+ return fmt.Sprintf("Mock{Mode: 0x%x, QID.Path: %d}", m.Attr.Mode, m.QID.Path)
+}
+
+// GetAttr returns the current attributes.
+func (m *Mock) GetAttr(mask p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ return m.QID, p9.AttrMaskAll(), m.Attr, nil
+}
+
+// Walk supports clone and walking in directories.
+func (m *Mock) Walk(names []string) ([]p9.QID, p9.File, error) {
+ if m.WalkCallback != nil {
+ if err := m.WalkCallback(); err != nil {
+ return nil, nil, err
+ }
+ }
+ if len(names) == 0 {
+ // Clone the file appropriately.
+ nm := m.harness.NewMock(m.parent, m.QID.Path, m.Attr)
+ nm.children = m.children // Inherit children.
+ return []p9.QID{nm.QID}, nm, nil
+ } else if len(names) != 1 {
+ m.harness.t.Fail() // Should not happen.
+ return nil, nil, syscall.EINVAL
+ }
+
+ if m.Attr.Mode.IsDir() {
+ globalMu.RLock()
+ defer globalMu.RUnlock()
+ if fn, ok := m.children[names[0]]; ok {
+ // Generate the child.
+ nm := fn(m)
+ return []p9.QID{nm.QID}, nm, nil
+ }
+ // No child found.
+ return nil, nil, syscall.ENOENT
+ }
+
+ // Call the underlying mock.
+ return m.MockFile.Walk(names)
+}
+
+// WalkGetAttr calls the default implementation; this is a client-side optimization.
+func (m *Mock) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) {
+ return m.DefaultWalkGetAttr.WalkGetAttr(names)
+}
+
+// Pop pops off the most recently created Mock and assert that this mock
+// represents the same file passed in. If nil is passed in, no check is
+// performed.
+//
+// Precondition: there must be at least one Mock or this will panic.
+func (h *Harness) Pop(clientFile p9.File) *Mock {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if clientFile == nil {
+ // If no clientFile is provided, then we always return the last
+ // created file. The caller can safely use this as long as
+ // there is no concurrency.
+ m := h.created[len(h.created)-1]
+ h.created = h.created[:len(h.created)-1]
+ return m
+ }
+
+ qid, _, _, err := clientFile.GetAttr(p9.AttrMaskAll())
+ if err != nil {
+ // We do not expect this to happen.
+ panic(fmt.Sprintf("err during Pop: %v", err))
+ }
+
+ // Find the relevant file in our created list. We must scan the last
+ // from back to front to ensure that we favor the most recently
+ // generated file.
+ for i := len(h.created) - 1; i >= 0; i-- {
+ m := h.created[i]
+ if qid.Path == m.QID.Path {
+ // Copy and truncate.
+ copy(h.created[i:], h.created[i+1:])
+ h.created = h.created[:len(h.created)-1]
+ return m
+ }
+ }
+
+ // Unable to find relevant file.
+ panic(fmt.Sprintf("unable to locate file with QID %+v", qid.Path))
+}
+
+// NewMock returns a new base file.
+func (h *Harness) NewMock(parent *Mock, path uint64, attr p9.Attr) *Mock {
+ m := &Mock{
+ MockFile: NewMockFile(h.mockCtrl),
+ parent: parent,
+ harness: h,
+ QID: p9.QID{
+ Type: p9.QIDType((attr.Mode & p9.FileModeMask) >> 12),
+ Path: path,
+ },
+ Attr: attr,
+ }
+
+ // Always ensure Close is after the parent's close. Note that this
+ // can't be done via a straight-forward After call, because the parent
+ // might change after initial creation. We ensure that this is true at
+ // close time.
+ m.EXPECT().Close().Return(nil).Times(1).Do(func() {
+ if m.parent != nil && m.parent.closed {
+ h.t.FailNow()
+ }
+ // Note that this should not be racy, as this operation should
+ // be protected by the Times(1) above first.
+ m.closed = true
+ })
+
+ // Remember what was created.
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.created = append(h.created, m)
+
+ return m
+}
+
+// NewFile returns a new file mock.
+//
+// Note that ReadAt and WriteAt must be mocked separately.
+func (h *Harness) NewFile() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeRegular})
+ }
+}
+
+// NewDirectory returns a new mock directory.
+//
+// Note that Mkdir, Link, Mknod, RenameAt, UnlinkAt and Readdir must be mocked
+// separately. Walk is provided and children may be manipulated via AddChild
+// and RemoveChild. After calling Walk remotely, one can use Pop to find the
+// corresponding backend mock on the server side.
+func (h *Harness) NewDirectory(contents map[string]Generator) Generator {
+ return func(parent *Mock) *Mock {
+ m := h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeDirectory})
+ m.children = contents // Save contents.
+ return m
+ }
+}
+
+// NewSymlink returns a new mock directory.
+//
+// Note that Readlink must be mocked separately.
+func (h *Harness) NewSymlink() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSymlink})
+ }
+}
+
+// NewBlockDevice returns a new mock block device.
+func (h *Harness) NewBlockDevice() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeBlockDevice})
+ }
+}
+
+// NewCharacterDevice returns a new mock character device.
+func (h *Harness) NewCharacterDevice() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeCharacterDevice})
+ }
+}
+
+// NewNamedPipe returns a new mock named pipe.
+func (h *Harness) NewNamedPipe() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeNamedPipe})
+ }
+}
+
+// NewSocket returns a new mock socket.
+func (h *Harness) NewSocket() Generator {
+ return func(parent *Mock) *Mock {
+ return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSocket})
+ }
+}
+
+// Finish completes all checks and shuts down the server.
+func (h *Harness) Finish() {
+ h.clientSocket.Shutdown()
+ h.wg.Wait()
+ h.mockCtrl.Finish()
+}
+
+// NewHarness creates and returns a new test server.
+//
+// It should always be used as:
+//
+// h, c := NewHarness(t)
+// defer h.Finish()
+//
+func NewHarness(t *testing.T) (*Harness, *p9.Client) {
+ // Create the mock.
+ mockCtrl := gomock.NewController(t)
+ h := &Harness{
+ t: t,
+ mockCtrl: mockCtrl,
+ Attacher: NewMockAttacher(mockCtrl),
+ }
+
+ // Make socket pair.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v wanted nil", err)
+ }
+
+ // Start the server, synchronized on exit.
+ server := p9.NewServer(h.Attacher)
+ h.wg.Add(1)
+ go func() {
+ defer h.wg.Done()
+ server.Handle(serverSocket)
+ }()
+
+ // Create the client.
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
+ if err != nil {
+ serverSocket.Close()
+ clientSocket.Close()
+ t.Fatalf("new client got %v, expected nil", err)
+ return nil, nil // Never hit.
+ }
+
+ // Capture the client socket.
+ h.clientSocket = clientSocket
+ return h, client
+}
diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go
new file mode 100644
index 000000000..72ef53313
--- /dev/null
+++ b/pkg/p9/path_tree.go
@@ -0,0 +1,222 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// pathNode is a single node in a path traversal.
+//
+// These are shared by all fidRefs that point to the same path.
+//
+// Lock ordering:
+// opMu
+// childMu
+//
+// Two different pathNodes may only be locked if Server.renameMu is held for
+// write, in which case they can be acquired in any order.
+type pathNode struct {
+ // opMu synchronizes high-level, sematic operations, such as the
+ // simultaneous creation and deletion of a file.
+ //
+ // opMu does not directly protect any fields in pathNode.
+ opMu sync.RWMutex
+
+ // childMu protects the fields below.
+ childMu sync.RWMutex
+
+ // childNodes maps child path component names to their pathNode.
+ childNodes map[string]*pathNode
+
+ // childRefs maps child path component names to all of the their
+ // references.
+ childRefs map[string]map[*fidRef]struct{}
+
+ // childRefNames maps child references back to their path component
+ // name.
+ childRefNames map[*fidRef]string
+}
+
+func newPathNode() *pathNode {
+ return &pathNode{
+ childNodes: make(map[string]*pathNode),
+ childRefs: make(map[string]map[*fidRef]struct{}),
+ childRefNames: make(map[*fidRef]string),
+ }
+}
+
+// forEachChildRef calls fn for each child reference.
+func (p *pathNode) forEachChildRef(fn func(ref *fidRef, name string)) {
+ p.childMu.RLock()
+ defer p.childMu.RUnlock()
+
+ for name, m := range p.childRefs {
+ for ref := range m {
+ fn(ref, name)
+ }
+ }
+}
+
+// forEachChildNode calls fn for each child pathNode.
+func (p *pathNode) forEachChildNode(fn func(pn *pathNode)) {
+ p.childMu.RLock()
+ defer p.childMu.RUnlock()
+
+ for _, pn := range p.childNodes {
+ fn(pn)
+ }
+}
+
+// pathNodeFor returns the path node for the given name, or a new one.
+func (p *pathNode) pathNodeFor(name string) *pathNode {
+ p.childMu.RLock()
+ // Fast path, node already exists.
+ if pn, ok := p.childNodes[name]; ok {
+ p.childMu.RUnlock()
+ return pn
+ }
+ p.childMu.RUnlock()
+
+ // Slow path, create a new pathNode for shared use.
+ p.childMu.Lock()
+
+ // Re-check after re-lock.
+ if pn, ok := p.childNodes[name]; ok {
+ p.childMu.Unlock()
+ return pn
+ }
+
+ pn := newPathNode()
+ p.childNodes[name] = pn
+ p.childMu.Unlock()
+ return pn
+}
+
+// nameFor returns the name for the given fidRef.
+//
+// Precondition: addChild is called for ref before nameFor.
+func (p *pathNode) nameFor(ref *fidRef) string {
+ p.childMu.RLock()
+ n, ok := p.childRefNames[ref]
+ p.childMu.RUnlock()
+
+ if !ok {
+ // This should not happen, don't proceed.
+ panic(fmt.Sprintf("expected name for %+v, none found", ref))
+ }
+
+ return n
+}
+
+// addChildLocked adds a child reference to p.
+//
+// Precondition: As addChild, plus childMu is locked for write.
+func (p *pathNode) addChildLocked(ref *fidRef, name string) {
+ if n, ok := p.childRefNames[ref]; ok {
+ // This should not happen, don't proceed.
+ panic(fmt.Sprintf("unexpected fidRef %+v with path %q, wanted %q", ref, n, name))
+ }
+
+ p.childRefNames[ref] = name
+
+ m, ok := p.childRefs[name]
+ if !ok {
+ m = make(map[*fidRef]struct{})
+ p.childRefs[name] = m
+ }
+
+ m[ref] = struct{}{}
+}
+
+// addChild adds a child reference to p.
+//
+// Precondition: ref may only be added once at a time.
+func (p *pathNode) addChild(ref *fidRef, name string) {
+ p.childMu.Lock()
+ p.addChildLocked(ref, name)
+ p.childMu.Unlock()
+}
+
+// removeChild removes the given child.
+//
+// This applies only to an individual fidRef, which is not required to exist.
+func (p *pathNode) removeChild(ref *fidRef) {
+ p.childMu.Lock()
+
+ // This ref may not exist anymore. This can occur, e.g., in unlink,
+ // where a removeWithName removes the ref, and then a DecRef on the ref
+ // attempts to remove again.
+ if name, ok := p.childRefNames[ref]; ok {
+ m, ok := p.childRefs[name]
+ if !ok {
+ // This should not happen, don't proceed.
+ p.childMu.Unlock()
+ panic(fmt.Sprintf("name %s missing from childfidRefs", name))
+ }
+
+ delete(m, ref)
+ if len(m) == 0 {
+ delete(p.childRefs, name)
+ }
+ }
+
+ delete(p.childRefNames, ref)
+
+ p.childMu.Unlock()
+}
+
+// addPathNodeFor adds an existing pathNode as the node for name.
+//
+// Preconditions: newName does not exist.
+func (p *pathNode) addPathNodeFor(name string, pn *pathNode) {
+ p.childMu.Lock()
+
+ if opn, ok := p.childNodes[name]; ok {
+ p.childMu.Unlock()
+ panic(fmt.Sprintf("unexpected pathNode %+v with path %q", opn, name))
+ }
+
+ p.childNodes[name] = pn
+ p.childMu.Unlock()
+}
+
+// removeWithName removes all references with the given name.
+//
+// The provided function is executed after reference removal. The only method
+// it may (transitively) call on this pathNode is addChildLocked.
+//
+// If a child pathNode for name exists, it is removed from this pathNode and
+// returned by this function. Any operations on the removed tree must use this
+// value.
+func (p *pathNode) removeWithName(name string, fn func(ref *fidRef)) *pathNode {
+ p.childMu.Lock()
+ defer p.childMu.Unlock()
+
+ if m, ok := p.childRefs[name]; ok {
+ for ref := range m {
+ delete(m, ref)
+ delete(p.childRefNames, ref)
+ fn(ref)
+ }
+ }
+
+ // Return the original path node, if it exists.
+ origPathNode := p.childNodes[name]
+ delete(p.childNodes, name)
+ return origPathNode
+}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
new file mode 100644
index 000000000..fdfa83648
--- /dev/null
+++ b/pkg/p9/server.go
@@ -0,0 +1,694 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "io"
+ "runtime/debug"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Server is a 9p2000.L server.
+type Server struct {
+ // attacher provides the attach function.
+ attacher Attacher
+
+ // pathTree is the full set of paths opened on this server.
+ //
+ // These may be across different connections, but rename operations
+ // must be serialized globally for safely. There is a single pathTree
+ // for the entire server, and not per connection.
+ pathTree *pathNode
+
+ // renameMu is a global lock protecting rename operations. With this
+ // lock, we can be certain that any given rename operation can safely
+ // acquire two path nodes in any order, as all other concurrent
+ // operations acquire at most a single node.
+ renameMu sync.RWMutex
+}
+
+// NewServer returns a new server.
+func NewServer(attacher Attacher) *Server {
+ return &Server{
+ attacher: attacher,
+ pathTree: newPathNode(),
+ }
+}
+
+// connState is the state for a single connection.
+type connState struct {
+ // server is the backing server.
+ server *Server
+
+ // sendMu is the send lock.
+ sendMu sync.Mutex
+
+ // conn is the connection.
+ conn *unet.Socket
+
+ // fids is the set of active FIDs.
+ //
+ // This is used to find FIDs for files.
+ fidMu sync.Mutex
+ fids map[FID]*fidRef
+
+ // tags is the set of active tags.
+ //
+ // The given channel is closed when the
+ // tag is finished with processing.
+ tagMu sync.Mutex
+ tags map[Tag]chan struct{}
+
+ // messageSize is the maximum message size. The server does not
+ // do automatic splitting of messages.
+ messageSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // -- below relates to the legacy handler --
+
+ // recvOkay indicates that a receive may start.
+ recvOkay chan bool
+
+ // recvDone is signalled when a message is received.
+ recvDone chan error
+
+ // sendDone is signalled when a send is finished.
+ sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
+}
+
+// fidRef wraps a node and tracks references.
+type fidRef struct {
+ // server is the associated server.
+ server *Server
+
+ // file is the associated File.
+ file File
+
+ // refs is an active refence count.
+ //
+ // The node above will be closed only when refs reaches zero.
+ refs int64
+
+ // openedMu protects opened and openFlags.
+ openedMu sync.Mutex
+
+ // opened indicates whether this has been opened already.
+ //
+ // This is updated in handlers.go.
+ opened bool
+
+ // mode is the fidRef's mode from the walk. Only the type bits are
+ // valid, the permissions may change. This is used to sanity check
+ // operations on this element, and prevent walks across
+ // non-directories.
+ mode FileMode
+
+ // openFlags is the mode used in the open.
+ //
+ // This is updated in handlers.go.
+ openFlags OpenFlags
+
+ // pathNode is the current pathNode for this FID.
+ pathNode *pathNode
+
+ // parent is the parent fidRef. We hold on to a parent reference to
+ // ensure that hooks, such as Renamed, can be executed safely by the
+ // server code.
+ //
+ // Note that parent cannot be changed without holding both the global
+ // rename lock and a writable lock on the associated pathNode for this
+ // fidRef. Holding either of these locks is sufficient to examine
+ // parent safely.
+ //
+ // The parent will be nil for root fidRefs, and non-nil otherwise. The
+ // method maybeParent can be used to return a cyclical reference, and
+ // isRoot should be used to check for root over looking at parent
+ // directly.
+ parent *fidRef
+
+ // deleted indicates that the backing file has been deleted. We stop
+ // many operations at the API level if they are incompatible with a
+ // file that has already been unlinked.
+ deleted uint32
+}
+
+// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously.
+func (f *fidRef) OpenFlags() (OpenFlags, bool) {
+ f.openedMu.Lock()
+ defer f.openedMu.Unlock()
+ return f.openFlags, f.opened
+}
+
+// IncRef increases the references on a fid.
+func (f *fidRef) IncRef() {
+ atomic.AddInt64(&f.refs, 1)
+}
+
+// DecRef should be called when you're finished with a fid.
+func (f *fidRef) DecRef() {
+ if atomic.AddInt64(&f.refs, -1) == 0 {
+ f.file.Close()
+
+ // Drop the parent reference.
+ //
+ // Since this fidRef is guaranteed to be non-discoverable when
+ // the references reach zero, we don't need to worry about
+ // clearing the parent.
+ if f.parent != nil {
+ // If we've been previously deleted, this removing this
+ // ref is a no-op. That's expected.
+ f.parent.pathNode.removeChild(f)
+ f.parent.DecRef()
+ }
+ }
+}
+
+// isDeleted returns true if this fidRef has been deleted.
+func (f *fidRef) isDeleted() bool {
+ return atomic.LoadUint32(&f.deleted) != 0
+}
+
+// isRoot indicates whether this is a root fid.
+func (f *fidRef) isRoot() bool {
+ return f.parent == nil
+}
+
+// maybeParent returns a cyclic reference for roots, and the parent otherwise.
+func (f *fidRef) maybeParent() *fidRef {
+ if f.parent != nil {
+ return f.parent
+ }
+ return f // Root has itself.
+}
+
+// notifyDelete marks all fidRefs as deleted.
+//
+// Precondition: this must be called via safelyWrite or safelyGlobal.
+func notifyDelete(pn *pathNode) {
+ // Call on all local references.
+ pn.forEachChildRef(func(ref *fidRef, _ string) {
+ atomic.StoreUint32(&ref.deleted, 1)
+ })
+
+ // Call on all subtrees.
+ pn.forEachChildNode(func(pn *pathNode) {
+ notifyDelete(pn)
+ })
+}
+
+// markChildDeleted marks all children below the given name as deleted.
+//
+// Precondition: this must be called via safelyWrite or safelyGlobal.
+func (f *fidRef) markChildDeleted(name string) {
+ origPathNode := f.pathNode.removeWithName(name, func(ref *fidRef) {
+ atomic.StoreUint32(&ref.deleted, 1)
+ })
+
+ if origPathNode != nil {
+ // Mark all children as deleted.
+ notifyDelete(origPathNode)
+ }
+}
+
+// notifyNameChange calls the relevant Renamed method on all nodes in the path,
+// recursively. Note that this applies only for subtrees, as these
+// notifications do not apply to the actual file whose name has changed.
+//
+// Precondition: this must be called via safelyGlobal.
+func notifyNameChange(pn *pathNode) {
+ // Call on all local references.
+ pn.forEachChildRef(func(ref *fidRef, name string) {
+ ref.file.Renamed(ref.parent.file, name)
+ })
+
+ // Call on all subtrees.
+ pn.forEachChildNode(func(pn *pathNode) {
+ notifyNameChange(pn)
+ })
+}
+
+// renameChildTo renames the given child to the target.
+//
+// Precondition: this must be called via safelyGlobal.
+func (f *fidRef) renameChildTo(oldName string, target *fidRef, newName string) {
+ target.markChildDeleted(newName)
+ origPathNode := f.pathNode.removeWithName(oldName, func(ref *fidRef) {
+ // N.B. DecRef can take f.pathNode's parent's childMu. This is
+ // allowed because renameMu is held for write via safelyGlobal.
+ ref.parent.DecRef() // Drop original reference.
+ ref.parent = target // Change parent.
+ ref.parent.IncRef() // Acquire new one.
+ if f.pathNode == target.pathNode {
+ target.pathNode.addChildLocked(ref, newName)
+ } else {
+ target.pathNode.addChild(ref, newName)
+ }
+ ref.file.Renamed(target.file, newName)
+ })
+
+ if origPathNode != nil {
+ // Replace the previous (now deleted) path node.
+ target.pathNode.addPathNodeFor(newName, origPathNode)
+ // Call Renamed on all children.
+ notifyNameChange(origPathNode)
+ }
+}
+
+// safelyRead executes the given operation with the local path node locked.
+// This implies that paths will not change during the operation.
+func (f *fidRef) safelyRead(fn func() error) (err error) {
+ f.server.renameMu.RLock()
+ defer f.server.renameMu.RUnlock()
+ f.pathNode.opMu.RLock()
+ defer f.pathNode.opMu.RUnlock()
+ return fn()
+}
+
+// safelyWrite executes the given operation with the local path node locked in
+// a writable fashion. This implies some paths may change.
+func (f *fidRef) safelyWrite(fn func() error) (err error) {
+ f.server.renameMu.RLock()
+ defer f.server.renameMu.RUnlock()
+ f.pathNode.opMu.Lock()
+ defer f.pathNode.opMu.Unlock()
+ return fn()
+}
+
+// safelyGlobal executes the given operation with the global path lock held.
+func (f *fidRef) safelyGlobal(fn func() error) (err error) {
+ f.server.renameMu.Lock()
+ defer f.server.renameMu.Unlock()
+ return fn()
+}
+
+// LookupFID finds the given FID.
+//
+// You should call fid.DecRef when you are finished using the fid.
+func (cs *connState) LookupFID(fid FID) (*fidRef, bool) {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ fidRef, ok := cs.fids[fid]
+ if ok {
+ fidRef.IncRef()
+ return fidRef, true
+ }
+ return nil, false
+}
+
+// InsertFID installs the given FID.
+//
+// This fid starts with a reference count of one. If a FID exists in
+// the slot already it is closed, per the specification.
+func (cs *connState) InsertFID(fid FID, newRef *fidRef) {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ origRef, ok := cs.fids[fid]
+ if ok {
+ defer origRef.DecRef()
+ }
+ newRef.IncRef()
+ cs.fids[fid] = newRef
+}
+
+// DeleteFID removes the given FID.
+//
+// This simply removes it from the map and drops a reference.
+func (cs *connState) DeleteFID(fid FID) bool {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ fidRef, ok := cs.fids[fid]
+ if !ok {
+ return false
+ }
+ delete(cs.fids, fid)
+ fidRef.DecRef()
+ return true
+}
+
+// StartTag starts handling the tag.
+//
+// False is returned if this tag is already active.
+func (cs *connState) StartTag(t Tag) bool {
+ cs.tagMu.Lock()
+ defer cs.tagMu.Unlock()
+ _, ok := cs.tags[t]
+ if ok {
+ return false
+ }
+ cs.tags[t] = make(chan struct{})
+ return true
+}
+
+// ClearTag finishes handling a tag.
+func (cs *connState) ClearTag(t Tag) {
+ cs.tagMu.Lock()
+ defer cs.tagMu.Unlock()
+ ch, ok := cs.tags[t]
+ if !ok {
+ // Should never happen.
+ panic("unused tag cleared")
+ }
+ delete(cs.tags, t)
+
+ // Notify.
+ close(ch)
+}
+
+// WaitTag waits for a tag to finish.
+func (cs *connState) WaitTag(t Tag) {
+ cs.tagMu.Lock()
+ ch, ok := cs.tags[t]
+ cs.tagMu.Unlock()
+ if !ok {
+ return
+ }
+
+ // Wait for close.
+ <-ch
+}
+
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ if err := res.service(cs); err != nil {
+ // Don't log flipcall.ShutdownErrors, which we expect to be
+ // returned during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("p9.channel.service: %v", err)
+ }
+ }
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel looks up the channel with given id.
+//
+// The function returns nil if no such channel is available.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+ if id >= uint32(len(cs.channels)) {
+ return nil
+ }
+ return cs.channels[id]
+}
+
+// handle handles a single message.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
+// handleRequest handles a single request.
+//
+// The recvDone channel is signaled when recv is done (with a error if
+// necessary). The sendDone channel is signaled with the result of the send.
+func (cs *connState) handleRequest() {
+ messageSize := atomic.LoadUint32(&cs.messageSize)
+ if messageSize == 0 {
+ // Default or not yet negotiated.
+ messageSize = maximumLength
+ }
+
+ // Receive a message.
+ tag, m, err := recv(cs.conn, messageSize, msgRegistry.get)
+ if errSocket, ok := err.(ErrSocket); ok {
+ // Connection problem; stop serving.
+ cs.recvDone <- errSocket.error
+ return
+ }
+
+ // Signal receive is done.
+ cs.recvDone <- nil
+
+ // Deal with other errors.
+ if err != nil && err != io.EOF {
+ // If it's not a connection error, but some other protocol error,
+ // we can send a response immediately.
+ cs.sendMu.Lock()
+ err := send(cs.conn, tag, newErr(err))
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
+ return
+ }
+
+ // Try to start the tag.
+ if !cs.StartTag(tag) {
+ // Nothing we can do at this point; client is bogus.
+ log.Debugf("no valid tag [%05d]", tag)
+ cs.sendDone <- ErrNoValidMessage
+ return
+ }
+
+ // Handle the message.
+ r := cs.handle(m)
+
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
+
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
+
+ // Return the message to the cache.
+ msgRegistry.put(m)
+}
+
+func (cs *connState) handleRequests() {
+ for range cs.recvOkay {
+ cs.handleRequest()
+ }
+}
+
+func (cs *connState) stop() {
+ // Close all channels.
+ close(cs.recvOkay)
+ close(cs.recvDone)
+ close(cs.sendDone)
+
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
+ // Drop final reference in the FID table. Note this should
+ // always close the file, since we've ensured that there are no
+ // handlers running via the wait for Pending => 0 below.
+ fidRef.DecRef()
+ }
+
+ // Ensure the connection is closed.
+ cs.conn.Close()
+}
+
+// service services requests concurrently.
+func (cs *connState) service() error {
+ // Pending is the number of handlers that have finished receiving but
+ // not finished processing requests. These must be waiting on properly
+ // below. See the next comment for an explanation of the loop.
+ pending := 0
+
+ // Start the first request handler.
+ go cs.handleRequests() // S/R-SAFE: Irrelevant.
+ cs.recvOkay <- true
+
+ // We loop and make sure there's always one goroutine waiting for a new
+ // request. We process all the data for a single request in one
+ // goroutine however, to ensure the best turnaround time possible.
+ for {
+ select {
+ case err := <-cs.recvDone:
+ if err != nil {
+ // Wait for pending handlers.
+ for i := 0; i < pending; i++ {
+ <-cs.sendDone
+ }
+ return nil
+ }
+
+ // This handler is now pending.
+ pending++
+
+ // Kick the next receiver, or start a new handler
+ // if no receiver is currently waiting.
+ select {
+ case cs.recvOkay <- true:
+ default:
+ go cs.handleRequests() // S/R-SAFE: Irrelevant.
+ cs.recvOkay <- true
+ }
+
+ case <-cs.sendDone:
+ // This handler is finished.
+ pending--
+
+ // Error sending a response? Nothing can be done.
+ //
+ // We don't terminate on a send error though, since
+ // we still have a pending receive. The error would
+ // have been logged above, we just ignore it here.
+ }
+ }
+}
+
+// Handle handles a single connection.
+func (s *Server) Handle(conn *unet.Socket) error {
+ cs := &connState{
+ server: s,
+ conn: conn,
+ fids: make(map[FID]*fidRef),
+ tags: make(map[Tag]chan struct{}),
+ recvOkay: make(chan bool),
+ recvDone: make(chan error, 10),
+ sendDone: make(chan error, 10),
+ }
+ defer cs.stop()
+ return cs.service()
+}
+
+// Serve handles requests from the bound socket.
+//
+// The passed serverSocket _must_ be created in packet mode.
+func (s *Server) Serve(serverSocket *unet.ServerSocket) error {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ for {
+ conn, err := serverSocket.Accept()
+ if err != nil {
+ // Something went wrong.
+ //
+ // Socket closed?
+ return err
+ }
+
+ wg.Add(1)
+ go func(conn *unet.Socket) { // S/R-SAFE: Irrelevant.
+ s.Handle(conn)
+ wg.Done()
+ }(conn)
+ }
+}
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
new file mode 100644
index 000000000..7cec0e86d
--- /dev/null
+++ b/pkg/p9/transport.go
@@ -0,0 +1,345 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// ErrSocket is returned in cases of a socket issue.
+//
+// This may be treated differently than other errors.
+type ErrSocket struct {
+ // error is the socket error.
+ error
+}
+
+// ErrMessageTooLarge indicates the size was larger than reasonable.
+type ErrMessageTooLarge struct {
+ size uint32
+ msize uint32
+}
+
+// Error returns a sensible error.
+func (e *ErrMessageTooLarge) Error() string {
+ return fmt.Sprintf("message too large for fixed buffer: size is %d, limit is %d", e.size, e.msize)
+}
+
+// ErrNoValidMessage indicates no valid message could be decoded.
+var ErrNoValidMessage = errors.New("buffer contained no valid message")
+
+const (
+ // headerLength is the number of bytes required for a header.
+ headerLength uint32 = 7
+
+ // maximumLength is the largest possible message.
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
+
+ // initialBufferLength is the initial data buffer we allocate.
+ initialBufferLength uint32 = 64
+)
+
+var dataPool = sync.Pool{
+ New: func() interface{} {
+ // These buffers are used for decoding without a payload.
+ return make([]byte, initialBufferLength)
+ },
+}
+
+// send sends the given message over the socket.
+func send(s *unet.Socket, tag Tag, m message) error {
+ data := dataPool.Get().([]byte)
+ dataBuf := buffer{data: data[:0]}
+
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
+ }
+
+ // Encode the message. The buffer will grow automatically.
+ m.encode(&dataBuf)
+
+ // Get our vectors to send.
+ var hdr [headerLength]byte
+ vecs := make([][]byte, 0, 3)
+ vecs = append(vecs, hdr[:])
+ if len(dataBuf.data) > 0 {
+ vecs = append(vecs, dataBuf.data)
+ }
+ totalLength := headerLength + uint32(len(dataBuf.data))
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ if len(p) > 0 {
+ vecs = append(vecs, p)
+ totalLength += uint32(len(p))
+ }
+ }
+
+ // Construct the header.
+ headerBuf := buffer{data: hdr[:0]}
+ headerBuf.Write32(totalLength)
+ headerBuf.WriteMsgType(m.Type())
+ headerBuf.WriteTag(tag)
+
+ // Pack any files if necessary.
+ w := s.Writer(true)
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ defer f.Close()
+ // Pack the file into the message.
+ w.PackFDs(f.FD())
+ }
+ }
+
+ for n := 0; n < int(totalLength); {
+ cur, err := w.WriteVec(vecs)
+ if err != nil {
+ return ErrSocket{err}
+ }
+ n += cur
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(vecs[0]) <= cur-consumed {
+ consumed += len(vecs[0])
+ vecs = vecs[1:]
+ } else {
+ vecs[0] = vecs[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && n < int(totalLength) {
+ // Don't resend any control message.
+ w.UnpackFDs()
+ }
+ }
+
+ // All set.
+ dataPool.Put(dataBuf.data)
+ return nil
+}
+
+// lookupTagAndType looks up an existing message or creates a new one.
+//
+// This is called by recv after decoding the header. Any error returned will be
+// propagating back to the caller. You may use messageByType directly as a
+// lookupTagAndType function (by design).
+type lookupTagAndType func(tag Tag, t MsgType) (message, error)
+
+// recv decodes a message from the socket.
+//
+// This is done in two parts, and is thus not safe for multiple callers.
+//
+// On a socket error, the special error type ErrSocket is returned.
+//
+// The tag value NoTag will always be returned if err is non-nil.
+func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) {
+ // Read a header.
+ //
+ // Since the send above is atomic, we must always receive control
+ // messages along with the header. This means we need to be careful
+ // about closing FDs during errors to prevent leaks.
+ var hdr [headerLength]byte
+ r := s.Reader(true)
+ r.EnableFDs(1)
+
+ n, err := r.ReadVec([][]byte{hdr[:]})
+ if err != nil && (n == 0 || err != io.EOF) {
+ r.CloseFDs()
+ return NoTag, nil, ErrSocket{err}
+ }
+
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ return NoTag, nil, ErrSocket{err}
+ }
+ defer func() {
+ // Close anything left open. The case where
+ // fds are caught and used is handled below,
+ // and the fds variable will be set to nil.
+ for _, fd := range fds {
+ syscall.Close(fd)
+ }
+ }()
+ r.EnableFDs(0)
+
+ // Continuing reading for a short header.
+ for n < int(headerLength) {
+ cur, err := r.ReadVec([][]byte{hdr[n:]})
+ if err != nil && (cur == 0 || err != io.EOF) {
+ return NoTag, nil, ErrSocket{err}
+ }
+ n += cur
+ }
+
+ // Decode the header.
+ headerBuf := buffer{data: hdr[:]}
+ size := headerBuf.Read32()
+ t := headerBuf.ReadMsgType()
+ tag := headerBuf.ReadTag()
+ if size < headerLength {
+ // The message is too small.
+ //
+ // See above: it's probably screwed.
+ return NoTag, nil, ErrSocket{ErrNoValidMessage}
+ }
+ if size > maximumLength || size > msize {
+ // The message is too big.
+ return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, msize}}
+ }
+ remaining := size - headerLength
+
+ // Find our message to decode.
+ m, err := lookup(tag, t)
+ if err != nil {
+ // Throw away the contents of this message.
+ if remaining > 0 {
+ io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
+ }
+ return tag, nil, err
+ }
+
+ // Not yet initialized.
+ var dataBuf buffer
+
+ // Read the rest of the payload.
+ //
+ // This requires some special care to ensure that the vectors all line
+ // up the way they should. We do this to minimize copying data around.
+ var vecs [][]byte
+ if payloader, ok := m.(payloader); ok {
+ fixedSize := payloader.FixedSize()
+
+ // Do we need more than there is?
+ if fixedSize > remaining {
+ // This is not a valid message.
+ if remaining > 0 {
+ io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
+ }
+ return NoTag, nil, ErrNoValidMessage
+ }
+
+ if fixedSize != 0 {
+ // Pull a data buffer from the pool.
+ data := dataPool.Get().([]byte)
+ if int(fixedSize) > len(data) {
+ // Create a larger data buffer, ensuring
+ // sufficient capicity for the message.
+ data = make([]byte, fixedSize)
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ } else {
+ // Limit the data buffer, and make sure it
+ // gets filled before the payload buffer.
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data[:fixedSize]}
+ vecs = append(vecs, data[:fixedSize])
+ }
+ }
+
+ // Include the payload.
+ p := payloader.Payload()
+ if p == nil || len(p) != int(remaining-fixedSize) {
+ p = make([]byte, remaining-fixedSize)
+ payloader.SetPayload(p)
+ }
+ if len(p) > 0 {
+ vecs = append(vecs, p)
+ }
+ } else if remaining != 0 {
+ // Pull a data buffer from the pool.
+ data := dataPool.Get().([]byte)
+ if int(remaining) > len(data) {
+ // Create a larger data buffer.
+ data = make([]byte, remaining)
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ } else {
+ // Limit the data buffer.
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data[:remaining]}
+ vecs = append(vecs, data[:remaining])
+ }
+ }
+
+ if len(vecs) > 0 {
+ // Read the rest of the message.
+ //
+ // No need to handle a control message.
+ r := s.Reader(true)
+ for n := 0; n < int(remaining); {
+ cur, err := r.ReadVec(vecs)
+ if err != nil && (cur == 0 || err != io.EOF) {
+ return NoTag, nil, ErrSocket{err}
+ }
+ n += cur
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(vecs[0]) <= cur-consumed {
+ consumed += len(vecs[0])
+ vecs = vecs[1:]
+ } else {
+ vecs[0] = vecs[0][cur-consumed:]
+ break
+ }
+ }
+ }
+ }
+
+ // Decode the message data.
+ m.decode(&dataBuf)
+ if dataBuf.isOverrun() {
+ // No need to drain the socket.
+ return NoTag, nil, ErrNoValidMessage
+ }
+
+ // Save the file, if any came out.
+ if filer, ok := m.(filer); ok && len(fds) > 0 {
+ // Set the file object.
+ filer.SetFilePayload(fd.New(fds[0]))
+
+ // Close the rest. We support only one.
+ for i := 1; i < len(fds); i++ {
+ syscall.Close(fds[i])
+ }
+
+ // Don't close in the defer.
+ fds = nil
+ }
+
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
+ }
+
+ // All set.
+ return tag, m, nil
+}
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..38038abdf
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,243 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2, and a
+// maximum of 4, and select the number of available processes as a tie-breaker.
+// Note that we don't want the number of channels to be too large, because each
+// will account for channelSize memory used, which can be large.
+var channelsPerClient = func() int {
+ n := runtime.NumCPU()
+ if n < 2 {
+ return 2
+ }
+ if n > 4 {
+ return 4
+ }
+ return n
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ desc flipcall.PacketWindowDescriptor
+ data flipcall.Endpoint
+ fds fdchannel.Endpoint
+ buf buffer
+
+ // -- client only --
+ connected bool
+ active bool
+
+ // -- server only --
+ client *fd.FD
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high-level, depending on
+// whether it is the client or server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Not that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, err
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ return ch.data.SendRecv(ssz)
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return r, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
new file mode 100644
index 000000000..3668fcad7
--- /dev/null
+++ b/pkg/p9/transport_test.go
@@ -0,0 +1,231 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "io/ioutil"
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ MsgTypeBadEncode = iota + 252
+ MsgTypeBadDecode
+ MsgTypeUnregistered
+)
+
+func TestSendRecv(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &Tlopen{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
+ if err != nil {
+ t.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ t.Fatalf("got tag %v expected 1", tag)
+ }
+ if _, ok := m.(*Tlopen); !ok {
+ t.Fatalf("got message %v expected *Tlopen", m)
+ }
+}
+
+// badDecode overruns on decode.
+type badDecode struct{}
+
+func (*badDecode) decode(b *buffer) { b.markOverrun() }
+func (*badDecode) encode(b *buffer) {}
+func (*badDecode) Type() MsgType { return MsgTypeBadDecode }
+func (*badDecode) String() string { return "badDecode{}" }
+
+func TestRecvOverrun(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &badDecode{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil {
+ t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err)
+ }
+}
+
+// unregistered is not registered on decode.
+type unregistered struct{}
+
+func (*unregistered) decode(b *buffer) {}
+func (*unregistered) encode(b *buffer) {}
+func (*unregistered) Type() MsgType { return MsgTypeUnregistered }
+func (*unregistered) String() string { return "unregistered{}" }
+
+func TestRecvInvalidType(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &unregistered{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
+ if _, ok := err.(*ErrInvalidMsgType); !ok {
+ t.Fatalf("recv got err %v expected ErrInvalidMsgType", err)
+ }
+}
+
+func TestSendRecvWithFile(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ // Create a tempfile.
+ osf, err := ioutil.TempFile("", "p9")
+ if err != nil {
+ t.Fatalf("tempfile got err %v expected nil", err)
+ }
+ os.Remove(osf.Name())
+ f, err := fd.NewFromFile(osf)
+ osf.Close()
+ if err != nil {
+ t.Fatalf("unable to create file: %v", err)
+ }
+
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ // Enable withFile.
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
+ if err != nil {
+ t.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ t.Fatalf("got tag %v expected 1", tag)
+ }
+ rlopen, ok := m.(*Rlopen)
+ if !ok {
+ t.Fatalf("got m %v expected *Rlopen", m)
+ }
+ if rlopen.File == nil {
+ t.Fatalf("got nil file expected non-nil")
+ }
+}
+
+func TestRecvClosed(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ client.Close()
+
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
+ if err == nil {
+ t.Fatalf("got err nil expected non-nil")
+ }
+ if _, ok := err.(ErrSocket); !ok {
+ t.Fatalf("got err %v expected ErrSocket", err)
+ }
+}
+
+func TestSendClosed(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ server.Close()
+ defer client.Close()
+
+ err = send(client, Tag(1), &Tlopen{})
+ if err == nil {
+ t.Fatalf("send got err nil expected non-nil")
+ }
+ if _, ok := err.(ErrSocket); !ok {
+ t.Fatalf("got err %v expected ErrSocket", err)
+ }
+}
+
+func BenchmarkSendRecv(b *testing.B) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ // Exchange Rflush messages since these contain no data and therefore incur
+ // no additional marshaling overhead.
+ go func() {
+ for i := 0; i < b.N; i++ {
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
+ if err != nil {
+ b.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ b.Fatalf("got tag %v expected 1", tag)
+ }
+ if _, ok := m.(*Rflush); !ok {
+ b.Fatalf("got message %T expected *Rflush", m)
+ }
+ if err := send(server, Tag(2), &Rflush{}); err != nil {
+ b.Fatalf("send got err %v expected nil", err)
+ }
+ }
+ }()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := send(client, Tag(1), &Rflush{}); err != nil {
+ b.Fatalf("send got err %v expected nil", err)
+ }
+ tag, m, err := recv(client, maximumLength, msgRegistry.get)
+ if err != nil {
+ b.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(2) {
+ b.Fatalf("got tag %v expected 2", tag)
+ }
+ if _, ok := m.(*Rflush); !ok {
+ b.Fatalf("got message %v expected *Rflush", m)
+ }
+ }
+}
+
+func init() {
+ msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} })
+}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
new file mode 100644
index 000000000..09cde9f5a
--- /dev/null
+++ b/pkg/p9/version.go
@@ -0,0 +1,175 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+const (
+ // highestSupportedVersion is the highest supported version X in a
+ // version string of the format 9P2000.L.Google.X.
+ //
+ // Clients are expected to start requesting this version number and
+ // to continuously decrement it until a Tversion request succeeds.
+ highestSupportedVersion uint32 = 11
+
+ // lowestSupportedVersion is the lowest supported version X in a
+ // version string of the format 9P2000.L.Google.X.
+ //
+ // Clients are free to send a Tversion request at a version below this
+ // value but are expected to encounter an Rlerror in response.
+ lowestSupportedVersion uint32 = 0
+
+ // baseVersion is the base version of 9P that this package must always
+ // support. It is equivalent to 9P2000.L.Google.0.
+ baseVersion = "9P2000.L"
+)
+
+// HighestVersionString returns the highest possible version string that a client
+// may request or a server may support.
+func HighestVersionString() string {
+ return versionString(highestSupportedVersion)
+}
+
+// parseVersion parses a Tversion version string into a numeric version number
+// if the version string is supported by p9. Otherwise returns (0, false).
+//
+// From Tversion(9P): "Version strings are defined such that, if the client string
+// contains one or more period characters, the initial substring up to but not
+// including any single period in the version string defines a version of the protocol."
+//
+// p9 intentionally diverges from this and always requires that the version string
+// start with 9P2000.L to express that it is always compatible with 9P2000.L. The
+// only supported versions extensions are of the format 9p2000.L.Google.X where X
+// is an ever increasing version counter.
+//
+// Version 9P2000.L.Google.0 implies 9P2000.L.
+//
+// New versions must always be a strict superset of 9P2000.L. A version increase must
+// define a predicate representing the feature extension introduced by that version. The
+// predicate must be commented and should take the format:
+//
+// // VersionSupportsX returns true if version v supports X and must be checked when ...
+// func VersionSupportsX(v int32) bool {
+// ...
+// )
+func parseVersion(str string) (uint32, bool) {
+ // Special case the base version which lacks the ".Google.X" suffix. This
+ // version always means version 0.
+ if str == baseVersion {
+ return 0, true
+ }
+ substr := strings.Split(str, ".")
+ if len(substr) != 4 {
+ return 0, false
+ }
+ if substr[0] != "9P2000" || substr[1] != "L" || substr[2] != "Google" || len(substr[3]) == 0 {
+ return 0, false
+ }
+ version, err := strconv.ParseUint(substr[3], 10, 32)
+ if err != nil {
+ return 0, false
+ }
+ return uint32(version), true
+}
+
+// versionString formats a p9 version number into a Tversion version string.
+func versionString(version uint32) string {
+ // Special case the base version so that clients expecting this string
+ // instead of the 9P2000.L.Google.0 equivalent get it. This is important
+ // for backwards compatibility with legacy servers that check for exactly
+ // the baseVersion and allow nothing else.
+ if version == 0 {
+ return baseVersion
+ }
+ return fmt.Sprintf("9P2000.L.Google.%d", version)
+}
+
+// VersionSupportsTflushf returns true if version v supports the Tflushf message.
+// This predicate must be checked by clients before attempting to make a Tflushf
+// request. If this predicate returns false, then clients may safely no-op.
+func VersionSupportsTflushf(v uint32) bool {
+ return v >= 1
+}
+
+// versionSupportsTwalkgetattr returns true if version v supports the
+// Twalkgetattr message. This predicate must be checked by clients before
+// attempting to make a Twalkgetattr request.
+func versionSupportsTwalkgetattr(v uint32) bool {
+ return v >= 2
+}
+
+// versionSupportsTucreation returns true if version v supports the Tucreation
+// messages (Tucreate, Tusymlink, Tumkdir, Tumknod). This predicate must be
+// checked by clients before attempting to make a Tucreation request.
+// If Tucreation messages are not supported, their non-UID supporting
+// counterparts (Tlcreate, Tsymlink, Tmkdir, Tmknod) should be used.
+func versionSupportsTucreation(v uint32) bool {
+ return v >= 3
+}
+
+// VersionSupportsConnect returns true if version v supports the Tlconnect
+// message. This predicate must be checked by clients
+// before attempting to make a Tlconnect request. If Tlconnect messages are not
+// supported, Tlopen should be used.
+func VersionSupportsConnect(v uint32) bool {
+ return v >= 4
+}
+
+// VersionSupportsAnonymous returns true if version v supports Tlconnect
+// with the AnonymousSocket mode. This predicate must be checked by clients
+// before attempting to use the AnonymousSocket Tlconnect mode.
+func VersionSupportsAnonymous(v uint32) bool {
+ return v >= 5
+}
+
+// VersionSupportsMultiUser returns true if version v supports multi-user fake
+// directory permissions and ID values.
+func VersionSupportsMultiUser(v uint32) bool {
+ return v >= 6
+}
+
+// versionSupportsTallocate returns true if version v supports Allocate().
+func versionSupportsTallocate(v uint32) bool {
+ return v >= 7
+}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}
+
+// VersionSupportsOpenTruncateFlag returns true if version v supports
+// passing the OpenTruncate flag to Tlopen.
+func VersionSupportsOpenTruncateFlag(v uint32) bool {
+ return v >= 9
+}
+
+// versionSupportsGetSetXattr returns true if version v supports
+// the Tgetxattr and Tsetxattr messages.
+func versionSupportsGetSetXattr(v uint32) bool {
+ return v >= 10
+}
+
+// versionSupportsListRemoveXattr returns true if version v supports
+// the Tlistxattr and Tremovexattr messages.
+func versionSupportsListRemoveXattr(v uint32) bool {
+ return v >= 11
+}
diff --git a/pkg/p9/version_test.go b/pkg/p9/version_test.go
new file mode 100644
index 000000000..291e8580e
--- /dev/null
+++ b/pkg/p9/version_test.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "testing"
+)
+
+func TestVersionNumberEquivalent(t *testing.T) {
+ for i := uint32(0); i < 1024; i++ {
+ str := versionString(i)
+ version, ok := parseVersion(str)
+ if !ok {
+ t.Errorf("#%d: parseVersion(%q) failed, want success", i, str)
+ continue
+ }
+ if i != version {
+ t.Errorf("#%d: got version %d, want %d", i, i, version)
+ }
+ }
+}
+
+func TestVersionStringEquivalent(t *testing.T) {
+ // There is one case where the version is not equivalent on purpose,
+ // that is 9P2000.L.Google.0. It is not equivalent because versionString
+ // must always return the more generic 9P2000.L for legacy servers that
+ // check for it. See net/9p/client.c.
+ str := "9P2000.L.Google.0"
+ version, ok := parseVersion(str)
+ if !ok {
+ t.Errorf("parseVersion(%q) failed, want success", str)
+ }
+ if got := versionString(version); got != "9P2000.L" {
+ t.Errorf("versionString(%d) got %q, want %q", version, got, "9P2000.L")
+ }
+
+ for _, test := range []struct {
+ versionString string
+ }{
+ {
+ versionString: "9P2000.L",
+ },
+ {
+ versionString: "9P2000.L.Google.1",
+ },
+ {
+ versionString: "9P2000.L.Google.347823894",
+ },
+ } {
+ version, ok := parseVersion(test.versionString)
+ if !ok {
+ t.Errorf("parseVersion(%q) failed, want success", test.versionString)
+ continue
+ }
+ if got := versionString(version); got != test.versionString {
+ t.Errorf("versionString(%d) got %q, want %q", version, got, test.versionString)
+ }
+ }
+}
+
+func TestParseVersion(t *testing.T) {
+ for _, test := range []struct {
+ versionString string
+ expectSuccess bool
+ expectedVersion uint32
+ }{
+ {
+ versionString: "9P",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P.L",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P200.L",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L.Google.-1",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L.Google.",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L.Google.3546343826724305832",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2001.L",
+ expectSuccess: false,
+ },
+ {
+ versionString: "9P2000.L",
+ expectSuccess: true,
+ expectedVersion: 0,
+ },
+ {
+ versionString: "9P2000.L.Google.0",
+ expectSuccess: true,
+ expectedVersion: 0,
+ },
+ {
+ versionString: "9P2000.L.Google.1",
+ expectSuccess: true,
+ expectedVersion: 1,
+ },
+ } {
+ version, ok := parseVersion(test.versionString)
+ if ok != test.expectSuccess {
+ t.Errorf("parseVersion(%q) got (_, %v), want (_, %v)", test.versionString, ok, test.expectSuccess)
+ continue
+ }
+ if !test.expectSuccess {
+ continue
+ }
+ if version != test.expectedVersion {
+ t.Errorf("parseVersion(%q) got (%d, _), want (%d, _)", test.versionString, version, test.expectedVersion)
+ }
+ }
+}
+
+func BenchmarkParseVersion(b *testing.B) {
+ for n := 0; n < b.N; n++ {
+ parseVersion("9P2000.L.Google.1")
+ }
+}