diff options
Diffstat (limited to 'pkg/p9')
-rw-r--r-- | pkg/p9/BUILD | 52 | ||||
-rw-r--r-- | pkg/p9/buffer.go | 263 | ||||
-rw-r--r-- | pkg/p9/buffer_test.go | 31 | ||||
-rw-r--r-- | pkg/p9/client.go | 575 | ||||
-rw-r--r-- | pkg/p9/client_file.go | 686 | ||||
-rw-r--r-- | pkg/p9/client_test.go | 109 | ||||
-rw-r--r-- | pkg/p9/file.go | 288 | ||||
-rw-r--r-- | pkg/p9/handlers.go | 1393 | ||||
-rw-r--r-- | pkg/p9/messages.go | 2662 | ||||
-rw-r--r-- | pkg/p9/messages_test.go | 483 | ||||
-rw-r--r-- | pkg/p9/p9.go | 1158 | ||||
-rw-r--r-- | pkg/p9/p9_test.go | 188 | ||||
-rw-r--r-- | pkg/p9/p9test/BUILD | 88 | ||||
-rw-r--r-- | pkg/p9/p9test/client_test.go | 2242 | ||||
-rw-r--r-- | pkg/p9/p9test/p9test.go | 329 | ||||
-rw-r--r-- | pkg/p9/path_tree.go | 222 | ||||
-rw-r--r-- | pkg/p9/server.go | 694 | ||||
-rw-r--r-- | pkg/p9/transport.go | 345 | ||||
-rw-r--r-- | pkg/p9/transport_flipcall.go | 243 | ||||
-rw-r--r-- | pkg/p9/transport_test.go | 231 | ||||
-rw-r--r-- | pkg/p9/version.go | 175 | ||||
-rw-r--r-- | pkg/p9/version_test.go | 145 |
22 files changed, 12602 insertions, 0 deletions
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD new file mode 100644 index 000000000..8904afad9 --- /dev/null +++ b/pkg/p9/BUILD @@ -0,0 +1,52 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "p9", + srcs = [ + "buffer.go", + "client.go", + "client_file.go", + "file.go", + "handlers.go", + "messages.go", + "p9.go", + "path_tree.go", + "server.go", + "transport.go", + "transport_flipcall.go", + "version.go", + ], + deps = [ + "//pkg/fd", + "//pkg/fdchannel", + "//pkg/flipcall", + "//pkg/log", + "//pkg/pool", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "p9_test", + size = "small", + srcs = [ + "buffer_test.go", + "client_test.go", + "messages_test.go", + "p9_test.go", + "transport_test.go", + "version_test.go", + ], + library = ":p9", + deps = [ + "//pkg/fd", + "//pkg/unet", + ], +) diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go new file mode 100644 index 000000000..6a4951821 --- /dev/null +++ b/pkg/p9/buffer.go @@ -0,0 +1,263 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "encoding/binary" +) + +// encoder is used for messages and 9P primitives. +type encoder interface { + // decode decodes from the given buffer. decode may be called more than once + // to reuse the instance. It must clear any previous state. + // + // This may not fail, exhaustion will be recorded in the buffer. + decode(b *buffer) + + // encode encodes to the given buffer. + // + // This may not fail. + encode(b *buffer) +} + +// order is the byte order used for encoding. +var order = binary.LittleEndian + +// buffer is a slice that is consumed. +// +// This is passed to the encoder methods. +type buffer struct { + // data is the underlying data. This may grow during encode. + data []byte + + // overflow indicates whether an overflow has occurred. + overflow bool +} + +// append appends n bytes to the buffer and returns a slice pointing to the +// newly appended bytes. +func (b *buffer) append(n int) []byte { + b.data = append(b.data, make([]byte, n)...) + return b.data[len(b.data)-n:] +} + +// consume consumes n bytes from the buffer. +func (b *buffer) consume(n int) ([]byte, bool) { + if !b.has(n) { + b.markOverrun() + return nil, false + } + rval := b.data[:n] + b.data = b.data[n:] + return rval, true +} + +// has returns true if n bytes are available. +func (b *buffer) has(n int) bool { + return len(b.data) >= n +} + +// markOverrun immediately marks this buffer as overrun. +// +// This is used by ReadString, since some invalid data implies the rest of the +// buffer is no longer valid either. +func (b *buffer) markOverrun() { + b.overflow = true +} + +// isOverrun returns true if this buffer has run past the end. +func (b *buffer) isOverrun() bool { + return b.overflow +} + +// Read8 reads a byte from the buffer. +func (b *buffer) Read8() uint8 { + v, ok := b.consume(1) + if !ok { + return 0 + } + return uint8(v[0]) +} + +// Read16 reads a 16-bit value from the buffer. +func (b *buffer) Read16() uint16 { + v, ok := b.consume(2) + if !ok { + return 0 + } + return order.Uint16(v) +} + +// Read32 reads a 32-bit value from the buffer. +func (b *buffer) Read32() uint32 { + v, ok := b.consume(4) + if !ok { + return 0 + } + return order.Uint32(v) +} + +// Read64 reads a 64-bit value from the buffer. +func (b *buffer) Read64() uint64 { + v, ok := b.consume(8) + if !ok { + return 0 + } + return order.Uint64(v) +} + +// ReadQIDType reads a QIDType value. +func (b *buffer) ReadQIDType() QIDType { + return QIDType(b.Read8()) +} + +// ReadTag reads a Tag value. +func (b *buffer) ReadTag() Tag { + return Tag(b.Read16()) +} + +// ReadFID reads a FID value. +func (b *buffer) ReadFID() FID { + return FID(b.Read32()) +} + +// ReadUID reads a UID value. +func (b *buffer) ReadUID() UID { + return UID(b.Read32()) +} + +// ReadGID reads a GID value. +func (b *buffer) ReadGID() GID { + return GID(b.Read32()) +} + +// ReadPermissions reads a file mode value and applies the mask for permissions. +func (b *buffer) ReadPermissions() FileMode { + return b.ReadFileMode() & permissionsMask +} + +// ReadFileMode reads a file mode value. +func (b *buffer) ReadFileMode() FileMode { + return FileMode(b.Read32()) +} + +// ReadOpenFlags reads an OpenFlags. +func (b *buffer) ReadOpenFlags() OpenFlags { + return OpenFlags(b.Read32()) +} + +// ReadConnectFlags reads a ConnectFlags. +func (b *buffer) ReadConnectFlags() ConnectFlags { + return ConnectFlags(b.Read32()) +} + +// ReadMsgType writes a MsgType. +func (b *buffer) ReadMsgType() MsgType { + return MsgType(b.Read8()) +} + +// ReadString deserializes a string. +func (b *buffer) ReadString() string { + l := b.Read16() + if !b.has(int(l)) { + // Mark the buffer as corrupted. + b.markOverrun() + return "" + } + + bs := make([]byte, l) + for i := 0; i < int(l); i++ { + bs[i] = byte(b.Read8()) + } + return string(bs) +} + +// Write8 writes a byte to the buffer. +func (b *buffer) Write8(v uint8) { + b.append(1)[0] = byte(v) +} + +// Write16 writes a 16-bit value to the buffer. +func (b *buffer) Write16(v uint16) { + order.PutUint16(b.append(2), v) +} + +// Write32 writes a 32-bit value to the buffer. +func (b *buffer) Write32(v uint32) { + order.PutUint32(b.append(4), v) +} + +// Write64 writes a 64-bit value to the buffer. +func (b *buffer) Write64(v uint64) { + order.PutUint64(b.append(8), v) +} + +// WriteQIDType writes a QIDType value. +func (b *buffer) WriteQIDType(qidType QIDType) { + b.Write8(uint8(qidType)) +} + +// WriteTag writes a Tag value. +func (b *buffer) WriteTag(tag Tag) { + b.Write16(uint16(tag)) +} + +// WriteFID writes a FID value. +func (b *buffer) WriteFID(fid FID) { + b.Write32(uint32(fid)) +} + +// WriteUID writes a UID value. +func (b *buffer) WriteUID(uid UID) { + b.Write32(uint32(uid)) +} + +// WriteGID writes a GID value. +func (b *buffer) WriteGID(gid GID) { + b.Write32(uint32(gid)) +} + +// WritePermissions applies a permissions mask and writes the FileMode. +func (b *buffer) WritePermissions(perm FileMode) { + b.WriteFileMode(perm & permissionsMask) +} + +// WriteFileMode writes a FileMode. +func (b *buffer) WriteFileMode(mode FileMode) { + b.Write32(uint32(mode)) +} + +// WriteOpenFlags writes an OpenFlags. +func (b *buffer) WriteOpenFlags(flags OpenFlags) { + b.Write32(uint32(flags)) +} + +// WriteConnectFlags writes a ConnectFlags. +func (b *buffer) WriteConnectFlags(flags ConnectFlags) { + b.Write32(uint32(flags)) +} + +// WriteMsgType writes a MsgType. +func (b *buffer) WriteMsgType(t MsgType) { + b.Write8(uint8(t)) +} + +// WriteString serializes the given string. +func (b *buffer) WriteString(s string) { + b.Write16(uint16(len(s))) + for i := 0; i < len(s); i++ { + b.Write8(byte(s[i])) + } +} diff --git a/pkg/p9/buffer_test.go b/pkg/p9/buffer_test.go new file mode 100644 index 000000000..a9c75f86b --- /dev/null +++ b/pkg/p9/buffer_test.go @@ -0,0 +1,31 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "testing" +) + +func TestBufferOverrun(t *testing.T) { + buf := &buffer{ + // This header indicates that a large string should follow, but + // it is only two bytes. Reading a string should cause an + // overrun. + data: []byte{0x0, 0x16}, + } + if s := buf.ReadString(); s != "" { + t.Errorf("overrun read got %s, want empty", s) + } +} diff --git a/pkg/p9/client.go b/pkg/p9/client.go new file mode 100644 index 000000000..71e944c30 --- /dev/null +++ b/pkg/p9/client.go @@ -0,0 +1,575 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "errors" + "fmt" + "syscall" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/pool" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// ErrOutOfTags indicates no tags are available. +var ErrOutOfTags = errors.New("out of tags -- messages lost?") + +// ErrOutOfFIDs indicates no more FIDs are available. +var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?") + +// ErrUnexpectedTag indicates a response with an unexpected tag was received. +var ErrUnexpectedTag = errors.New("unexpected tag in response") + +// ErrVersionsExhausted indicates that all versions to negotiate have been exhausted. +var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate") + +// ErrBadVersionString indicates that the version string is malformed or unsupported. +var ErrBadVersionString = errors.New("bad version string") + +// ErrBadResponse indicates the response didn't match the request. +type ErrBadResponse struct { + Got MsgType + Want MsgType +} + +// Error returns a highly descriptive error. +func (e *ErrBadResponse) Error() string { + return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want) +} + +// response is the asynchronous return from recv. +// +// This is used in the pending map below. +type response struct { + r message + done chan error +} + +var responsePool = sync.Pool{ + New: func() interface{} { + return &response{ + done: make(chan error, 1), + } + }, +} + +// Client is at least a 9P2000.L client. +type Client struct { + // socket is the connected socket. + socket *unet.Socket + + // tagPool is the collection of available tags. + tagPool pool.Pool + + // fidPool is the collection of available fids. + fidPool pool.Pool + + // messageSize is the maximum total size of a message. + messageSize uint32 + + // payloadSize is the maximum payload size of a read or write. + // + // For large reads and writes this means that the read or write is + // broken up into buffer-size/payloadSize requests. + payloadSize uint32 + + // version is the agreed upon version X of 9P2000.L.Google.X. + // version 0 implies 9P2000.L. + version uint32 + + // closedWg is marked as done when the Client.watch() goroutine, which is + // responsible for closing channels and the socket fd, returns. + closedWg sync.WaitGroup + + // sendRecv is the transport function. + // + // This is determined dynamically based on whether or not the server + // supports flipcall channels (preferred as it is faster and more + // efficient, and does not require tags). + sendRecv func(message, message) error + + // -- below corresponds to sendRecvChannel -- + + // channelsMu protects channels. + channelsMu sync.Mutex + + // channelsWg counts the number of channels for which channel.active == + // true. + channelsWg sync.WaitGroup + + // channels is the set of all initialized channels. + channels []*channel + + // availableChannels is a FIFO of inactive channels. + availableChannels []*channel + + // -- below corresponds to sendRecvLegacy -- + + // pending is the set of pending messages. + pending map[Tag]*response + pendingMu sync.Mutex + + // sendMu is the lock for sending a request. + sendMu sync.Mutex + + // recvr is essentially a mutex for calling recv. + // + // Whoever writes to this channel is permitted to call recv. When + // finished calling recv, this channel should be emptied. + recvr chan bool +} + +// NewClient creates a new client. It performs a Tversion exchange with +// the server to assert that messageSize is ok to use. +// +// If NewClient succeeds, ownership of socket is transferred to the new Client. +func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { + // Need at least one byte of payload. + if messageSize <= msgRegistry.largestFixedSize { + return nil, &ErrMessageTooLarge{ + size: messageSize, + msize: msgRegistry.largestFixedSize, + } + } + + // Compute a payload size and round to 512 (normal block size) + // if it's larger than a single block. + payloadSize := messageSize - msgRegistry.largestFixedSize + if payloadSize > 512 && payloadSize%512 != 0 { + payloadSize -= (payloadSize % 512) + } + c := &Client{ + socket: socket, + tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)}, + fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)}, + pending: make(map[Tag]*response), + recvr: make(chan bool, 1), + messageSize: messageSize, + payloadSize: payloadSize, + } + // Agree upon a version. + requested, ok := parseVersion(version) + if !ok { + return nil, ErrBadVersionString + } + for { + // Always exchange the version using the legacy version of the + // protocol. If the protocol supports flipcall, then we switch + // our sendRecv function to use that functionality. Otherwise, + // we stick to sendRecvLegacy. + rversion := Rversion{} + _, err := c.sendRecvLegacy(&Tversion{ + Version: versionString(requested), + MSize: messageSize, + }, &rversion) + + // The server told us to try again with a lower version. + if err == syscall.EAGAIN { + if requested == lowestSupportedVersion { + return nil, ErrVersionsExhausted + } + requested-- + continue + } + + // We requested an impossible version or our other parameters were bogus. + if err != nil { + return nil, err + } + + // Parse the version. + version, ok := parseVersion(rversion.Version) + if !ok { + // The server gave us a bad version. We return a generically worrisome error. + log.Warningf("server returned bad version string %q", rversion.Version) + return nil, ErrBadVersionString + } + c.version = version + break + } + + // Can we switch to use the more advanced channels and create + // independent channels for communication? Prefer it if possible. + if versionSupportsFlipcall(c.version) { + // Attempt to initialize IPC-based communication. + for i := 0; i < channelsPerClient; i++ { + if err := c.openChannel(i); err != nil { + log.Warningf("error opening flipcall channel: %v", err) + break // Stop. + } + } + if len(c.channels) >= 1 { + // At least one channel created. + c.sendRecv = c.sendRecvChannel + } else { + // Channel setup failed; fallback. + c.sendRecv = c.sendRecvLegacySyscallErr + } + } else { + // No channels available: use the legacy mechanism. + c.sendRecv = c.sendRecvLegacySyscallErr + } + + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + + return c, nil +} + +// watch watches the given socket and releases resources on hangup events. +// +// This is intended to be called as a goroutine. +func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + + events := []unix.PollFd{ + unix.PollFd{ + Fd: int32(socket.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == syscall.EINTR || err == syscall.EAGAIN { + continue + } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() +} + +// openChannel attempts to open a client channel. +// +// Note that this function returns naked errors which should not be propagated +// directly to a caller. It is expected that the errors will be logged and a +// fallback path will be used instead. +func (c *Client) openChannel(id int) error { + var ( + rchannel0 Rchannel + rchannel1 Rchannel + res = new(channel) + ) + + // Open the data channel. + if _, err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 0, + }, &rchannel0); err != nil { + return fmt.Errorf("error handling Tchannel message: %v", err) + } + if rchannel0.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on primary channel") + } + + // We don't need to hold this. + defer rchannel0.FilePayload().Close() + + // Open the channel for file descriptors. + if _, err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 1, + }, &rchannel1); err != nil { + return err + } + if rchannel1.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on file descriptor channel") + } + + // Construct the endpoints. + res.desc = flipcall.PacketWindowDescriptor{ + FD: rchannel0.FilePayload().FD(), + Offset: int64(rchannel0.Offset), + Length: int(rchannel0.Length), + } + if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { + rchannel1.FilePayload().Close() + return err + } + + // The fds channel owns the control payload, and it will be closed when + // the channel object is closed. + res.fds.Init(rchannel1.FilePayload().Release()) + + // Save the channel. + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + c.channels = append(c.channels, res) + c.availableChannels = append(c.availableChannels, res) + return nil +} + +// handleOne handles a single incoming message. +// +// This should only be called with the token from recvr. Note that the received +// tag will automatically be cleared from pending. +func (c *Client) handleOne() { + tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) { + c.pendingMu.Lock() + resp := c.pending[tag] + c.pendingMu.Unlock() + + // Not expecting this message? + if resp == nil { + log.Warningf("client received unexpected tag %v, ignoring", tag) + return nil, ErrUnexpectedTag + } + + // Is it an error? We specifically allow this to + // go through, and then we deserialize below. + if t == MsgRlerror { + return &Rlerror{}, nil + } + + // Does it match expectations? + if t != resp.r.Type() { + return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()} + } + + // Return the response. + return resp.r, nil + }) + + if err != nil { + // No tag was extracted (probably a socket error). + // + // Likely catastrophic. Notify all waiters and clear pending. + c.pendingMu.Lock() + for _, resp := range c.pending { + resp.done <- err + } + c.pending = make(map[Tag]*response) + c.pendingMu.Unlock() + } else { + // Process the tag. + // + // We know that is is contained in the map because our lookup function + // above must have succeeded (found the tag) to return nil err. + c.pendingMu.Lock() + resp := c.pending[tag] + delete(c.pending, tag) + c.pendingMu.Unlock() + resp.r = r + resp.done <- err + } +} + +// waitAndRecv co-ordinates with other receivers to handle responses. +func (c *Client) waitAndRecv(done chan error) error { + for { + select { + case err := <-done: + return err + case c.recvr <- true: + select { + case err := <-done: + // It's possible that we got the token, despite + // done also being available. Check for that. + <-c.recvr + return err + default: + // Handle receiving one tag. + c.handleOne() + + // Return the token. + <-c.recvr + } + } + } +} + +// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all +// non-syscall errors to EIO. +func (c *Client) sendRecvLegacySyscallErr(t message, r message) error { + received, err := c.sendRecvLegacy(t, r) + if !received { + log.Warningf("p9.Client.sendRecvChannel: %v", err) + return syscall.EIO + } + return err +} + +// sendRecvLegacy performs a roundtrip message exchange. +// +// sendRecvLegacy returns true if a message was received. This allows us to +// differentiate between failed receives and successful receives where the +// response was an error message. +// +// This is called by internal functions. +func (c *Client) sendRecvLegacy(t message, r message) (bool, error) { + tag, ok := c.tagPool.Get() + if !ok { + return false, ErrOutOfTags + } + defer c.tagPool.Put(tag) + + // Indicate we're expecting a response. + // + // Note that the tag will be cleared from pending + // automatically (see handleOne for details). + resp := responsePool.Get().(*response) + defer responsePool.Put(resp) + resp.r = r + c.pendingMu.Lock() + c.pending[Tag(tag)] = resp + c.pendingMu.Unlock() + + // Send the request over the wire. + c.sendMu.Lock() + err := send(c.socket, Tag(tag), t) + c.sendMu.Unlock() + if err != nil { + return false, err + } + + // Co-ordinate with other receivers. + if err := c.waitAndRecv(resp.done); err != nil { + return false, err + } + + // Is it an error message? + // + // For convenience, we transform these directly + // into errors. Handlers need not handle this case. + if rlerr, ok := resp.r.(*Rlerror); ok { + return true, syscall.Errno(rlerr.Error) + } + + // At this point, we know it matches. + // + // Per recv call above, we will only allow a type + // match (and give our r) or an instance of Rlerror. + return true, nil +} + +// sendRecvChannel uses channels to send a message. +func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. + c.channelsMu.Lock() + if len(c.availableChannels) == 0 { + c.channelsMu.Unlock() + return c.sendRecvLegacySyscallErr(t, r) + } + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + ch.active = true + c.channelsWg.Add(1) + c.channelsMu.Unlock() + + // Ensure that it's connected. + if !ch.connected { + ch.connected = true + if err := ch.data.Connect(); err != nil { + // The channel is unusable, so don't return it to + // c.availableChannels. However, we still have to mark it as + // inactive so c.watch() doesn't wait for it. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + // Map all transport errors to EIO, but ensure that the real error + // is logged. + log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) + return syscall.EIO + } + } + + // Send the request and receive the server's response. + rsz, err := ch.send(t) + if err != nil { + // See above. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) + return syscall.EIO + } + + // Parse the server's response. + resp, retErr := ch.recv(r, rsz) + if resp == nil { + log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr) + retErr = syscall.EIO + } + + // Release the channel. + c.channelsMu.Lock() + ch.active = false + // If c.availableChannels is nil, c.watch() has fired and we should not + // mark this channel as available. + if c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.channelsMu.Unlock() + c.channelsWg.Done() + + return retErr +} + +// Version returns the negotiated 9P2000.L.Google version number. +func (c *Client) Version() uint32 { + return c.version +} + +// Close closes the underlying socket and channels. +func (c *Client) Close() { + // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already + // been called (by c.watch()). + c.socket.Shutdown() + c.closedWg.Wait() +} diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go new file mode 100644 index 000000000..2ee07b664 --- /dev/null +++ b/pkg/p9/client_file.go @@ -0,0 +1,686 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "fmt" + "io" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/log" +) + +// Attach attaches to a server. +// +// Note that authentication is not currently supported. +func (c *Client) Attach(name string) (File, error) { + fid, ok := c.fidPool.Get() + if !ok { + return nil, ErrOutOfFIDs + } + + rattach := Rattach{} + if err := c.sendRecv(&Tattach{FID: FID(fid), Auth: Tauth{AttachName: name, AuthenticationFID: NoFID, UID: NoUID}}, &rattach); err != nil { + c.fidPool.Put(fid) + return nil, err + } + + return c.newFile(FID(fid)), nil +} + +// newFile returns a new client file. +func (c *Client) newFile(fid FID) *clientFile { + return &clientFile{ + client: c, + fid: fid, + } +} + +// clientFile is provided to clients. +// +// This proxies all of the interfaces found in file.go. +type clientFile struct { + // client is the originating client. + client *Client + + // fid is the FID for this file. + fid FID + + // closed indicates whether this file has been closed. + closed uint32 +} + +// Walk implements File.Walk. +func (c *clientFile) Walk(names []string) ([]QID, File, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, nil, syscall.EBADF + } + + fid, ok := c.client.fidPool.Get() + if !ok { + return nil, nil, ErrOutOfFIDs + } + + rwalk := Rwalk{} + if err := c.client.sendRecv(&Twalk{FID: c.fid, NewFID: FID(fid), Names: names}, &rwalk); err != nil { + c.client.fidPool.Put(fid) + return nil, nil, err + } + + // Return a new client file. + return rwalk.QIDs, c.client.newFile(FID(fid)), nil +} + +// WalkGetAttr implements File.WalkGetAttr. +func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, Attr, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, nil, AttrMask{}, Attr{}, syscall.EBADF + } + + if !versionSupportsTwalkgetattr(c.client.version) { + qids, file, err := c.Walk(components) + if err != nil { + return nil, nil, AttrMask{}, Attr{}, err + } + _, valid, attr, err := file.GetAttr(AttrMaskAll()) + if err != nil { + file.Close() + return nil, nil, AttrMask{}, Attr{}, err + } + return qids, file, valid, attr, nil + } + + fid, ok := c.client.fidPool.Get() + if !ok { + return nil, nil, AttrMask{}, Attr{}, ErrOutOfFIDs + } + + rwalkgetattr := Rwalkgetattr{} + if err := c.client.sendRecv(&Twalkgetattr{FID: c.fid, NewFID: FID(fid), Names: components}, &rwalkgetattr); err != nil { + c.client.fidPool.Put(fid) + return nil, nil, AttrMask{}, Attr{}, err + } + + // Return a new client file. + return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil +} + +// StatFS implements File.StatFS. +func (c *clientFile) StatFS() (FSStat, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return FSStat{}, syscall.EBADF + } + + rstatfs := Rstatfs{} + if err := c.client.sendRecv(&Tstatfs{FID: c.fid}, &rstatfs); err != nil { + return FSStat{}, err + } + + return rstatfs.FSStat, nil +} + +// FSync implements File.FSync. +func (c *clientFile) FSync() error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + return c.client.sendRecv(&Tfsync{FID: c.fid}, &Rfsync{}) +} + +// GetAttr implements File.GetAttr. +func (c *clientFile) GetAttr(req AttrMask) (QID, AttrMask, Attr, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return QID{}, AttrMask{}, Attr{}, syscall.EBADF + } + + rgetattr := Rgetattr{} + if err := c.client.sendRecv(&Tgetattr{FID: c.fid, AttrMask: req}, &rgetattr); err != nil { + return QID{}, AttrMask{}, Attr{}, err + } + + return rgetattr.QID, rgetattr.Valid, rgetattr.Attr, nil +} + +// SetAttr implements File.SetAttr. +func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{}) +} + +// GetXattr implements File.GetXattr. +func (c *clientFile) GetXattr(name string, size uint64) (string, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return "", syscall.EBADF + } + if !versionSupportsGetSetXattr(c.client.version) { + return "", syscall.EOPNOTSUPP + } + + rgetxattr := Rgetxattr{} + if err := c.client.sendRecv(&Tgetxattr{FID: c.fid, Name: name, Size: size}, &rgetxattr); err != nil { + return "", err + } + + return rgetxattr.Value, nil +} + +// SetXattr implements File.SetXattr. +func (c *clientFile) SetXattr(name, value string, flags uint32) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + if !versionSupportsGetSetXattr(c.client.version) { + return syscall.EOPNOTSUPP + } + + return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{}) +} + +// ListXattr implements File.ListXattr. +func (c *clientFile) ListXattr(size uint64) (map[string]struct{}, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, syscall.EBADF + } + if !versionSupportsListRemoveXattr(c.client.version) { + return nil, syscall.EOPNOTSUPP + } + + rlistxattr := Rlistxattr{} + if err := c.client.sendRecv(&Tlistxattr{FID: c.fid, Size: size}, &rlistxattr); err != nil { + return nil, err + } + + xattrs := make(map[string]struct{}, len(rlistxattr.Xattrs)) + for _, x := range rlistxattr.Xattrs { + xattrs[x] = struct{}{} + } + return xattrs, nil +} + +// RemoveXattr implements File.RemoveXattr. +func (c *clientFile) RemoveXattr(name string) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + if !versionSupportsListRemoveXattr(c.client.version) { + return syscall.EOPNOTSUPP + } + + return c.client.sendRecv(&Tremovexattr{FID: c.fid, Name: name}, &Rremovexattr{}) +} + +// Allocate implements File.Allocate. +func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + if !versionSupportsTallocate(c.client.version) { + return syscall.EOPNOTSUPP + } + + return c.client.sendRecv(&Tallocate{FID: c.fid, Mode: mode, Offset: offset, Length: length}, &Rallocate{}) +} + +// Remove implements File.Remove. +// +// N.B. This method is no longer part of the file interface and should be +// considered deprecated. +func (c *clientFile) Remove() error { + // Avoid double close. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return syscall.EBADF + } + + // Send the remove message. + if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil { + return err + } + + // "It is correct to consider remove to be a clunk with the side effect + // of removing the file if permissions allow." + // https://swtch.com/plan9port/man/man9/remove.html + + // Return the FID to the pool. + c.client.fidPool.Put(uint64(c.fid)) + return nil +} + +// Close implements File.Close. +func (c *clientFile) Close() error { + // Avoid double close. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return syscall.EBADF + } + + // Send the close message. + if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil { + // If an error occurred, we toss away the FID. This isn't ideal, + // but I'm not sure what else makes sense in this context. + log.Warningf("Tclunk failed, losing FID %v: %v", c.fid, err) + return err + } + + // Return the FID to the pool. + c.client.fidPool.Put(uint64(c.fid)) + return nil +} + +// Open implements File.Open. +func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, QID{}, 0, syscall.EBADF + } + + rlopen := Rlopen{} + if err := c.client.sendRecv(&Tlopen{FID: c.fid, Flags: flags}, &rlopen); err != nil { + return nil, QID{}, 0, err + } + + return rlopen.File, rlopen.QID, rlopen.IoUnit, nil +} + +// Connect implements File.Connect. +func (c *clientFile) Connect(flags ConnectFlags) (*fd.FD, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, syscall.EBADF + } + + if !VersionSupportsConnect(c.client.version) { + return nil, syscall.ECONNREFUSED + } + + rlconnect := Rlconnect{} + if err := c.client.sendRecv(&Tlconnect{FID: c.fid, Flags: flags}, &rlconnect); err != nil { + return nil, err + } + + return rlconnect.File, nil +} + +// chunk applies fn to p in chunkSize-sized chunks until fn returns a partial result, p is +// exhausted, or an error is encountered (which may be io.EOF). +func chunk(chunkSize uint32, fn func([]byte, uint64) (int, error), p []byte, offset uint64) (int, error) { + // Some p9.Clients depend on executing fn on zero-byte buffers. Handle this + // as a special case (normally it is fine to short-circuit and return (0, nil)). + if len(p) == 0 { + return fn(p, offset) + } + + // total is the cumulative bytes processed. + var total int + for { + var n int + var err error + + // We're done, don't bother trying to do anything more. + if total == len(p) { + return total, nil + } + + // Apply fn to a chunkSize-sized (or less) chunk of p. + if len(p) < total+int(chunkSize) { + n, err = fn(p[total:], offset) + } else { + n, err = fn(p[total:total+int(chunkSize)], offset) + } + total += n + offset += uint64(n) + + // Return whatever we have processed if we encounter an error. This error + // could be io.EOF. + if err != nil { + return total, err + } + + // Did we get a partial result? If so, return it immediately. + if n < int(chunkSize) { + return total, nil + } + + // If we received more bytes than we ever requested, this is a problem. + if total > len(p) { + panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", total, len(p))) + } + } +} + +// ReadAt proxies File.ReadAt. +func (c *clientFile) ReadAt(p []byte, offset uint64) (int, error) { + return chunk(c.client.payloadSize, c.readAt, p, offset) +} + +func (c *clientFile) readAt(p []byte, offset uint64) (int, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return 0, syscall.EBADF + } + + rread := Rread{Data: p} + if err := c.client.sendRecv(&Tread{FID: c.fid, Offset: offset, Count: uint32(len(p))}, &rread); err != nil { + return 0, err + } + + // The message may have been truncated, or for some reason a new buffer + // allocated. This isn't the common path, but we make sure that if the + // payload has changed we copy it. See transport.go for more information. + if len(p) > 0 && len(rread.Data) > 0 && &rread.Data[0] != &p[0] { + copy(p, rread.Data) + } + + // io.EOF is not an error that a p9 server can return. Use POSIX semantics to + // return io.EOF manually: zero bytes were returned and a non-zero buffer was used. + if len(rread.Data) == 0 && len(p) > 0 { + return 0, io.EOF + } + + return len(rread.Data), nil +} + +// WriteAt proxies File.WriteAt. +func (c *clientFile) WriteAt(p []byte, offset uint64) (int, error) { + return chunk(c.client.payloadSize, c.writeAt, p, offset) +} + +func (c *clientFile) writeAt(p []byte, offset uint64) (int, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return 0, syscall.EBADF + } + + rwrite := Rwrite{} + if err := c.client.sendRecv(&Twrite{FID: c.fid, Offset: offset, Data: p}, &rwrite); err != nil { + return 0, err + } + + return int(rwrite.Count), nil +} + +// ReadWriterFile wraps a File and implements io.ReadWriter, io.ReaderAt, and io.WriterAt. +type ReadWriterFile struct { + File File + Offset uint64 +} + +// Read implements part of the io.ReadWriter interface. +func (r *ReadWriterFile) Read(p []byte) (int, error) { + n, err := r.File.ReadAt(p, r.Offset) + r.Offset += uint64(n) + if err != nil { + return n, err + } + if n == 0 && len(p) > 0 { + return n, io.EOF + } + return n, nil +} + +// ReadAt implements the io.ReaderAt interface. +func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) { + n, err := r.File.ReadAt(p, uint64(offset)) + if err != nil { + return 0, err + } + if n == 0 && len(p) > 0 { + return n, io.EOF + } + return n, nil +} + +// Write implements part of the io.ReadWriter interface. +func (r *ReadWriterFile) Write(p []byte) (int, error) { + n, err := r.File.WriteAt(p, r.Offset) + r.Offset += uint64(n) + if err != nil { + return n, err + } + if n < len(p) { + return n, io.ErrShortWrite + } + return n, nil +} + +// WriteAt implements the io.WriteAt interface. +func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) { + n, err := r.File.WriteAt(p, uint64(offset)) + if err != nil { + return n, err + } + if n < len(p) { + return n, io.ErrShortWrite + } + return n, nil +} + +// Rename implements File.Rename. +func (c *clientFile) Rename(dir File, name string) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + clientDir, ok := dir.(*clientFile) + if !ok { + return syscall.EBADF + } + + return c.client.sendRecv(&Trename{FID: c.fid, Directory: clientDir.fid, Name: name}, &Rrename{}) +} + +// Create implements File.Create. +func (c *clientFile) Create(name string, openFlags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, nil, QID{}, 0, syscall.EBADF + } + + msg := Tlcreate{ + FID: c.fid, + Name: name, + OpenFlags: openFlags, + Permissions: permissions, + GID: NoGID, + } + + if versionSupportsTucreation(c.client.version) { + msg.GID = gid + rucreate := Rucreate{} + if err := c.client.sendRecv(&Tucreate{Tlcreate: msg, UID: uid}, &rucreate); err != nil { + return nil, nil, QID{}, 0, err + } + return rucreate.File, c, rucreate.QID, rucreate.IoUnit, nil + } + + rlcreate := Rlcreate{} + if err := c.client.sendRecv(&msg, &rlcreate); err != nil { + return nil, nil, QID{}, 0, err + } + + return rlcreate.File, c, rlcreate.QID, rlcreate.IoUnit, nil +} + +// Mkdir implements File.Mkdir. +func (c *clientFile) Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return QID{}, syscall.EBADF + } + + msg := Tmkdir{ + Directory: c.fid, + Name: name, + Permissions: permissions, + GID: NoGID, + } + + if versionSupportsTucreation(c.client.version) { + msg.GID = gid + rumkdir := Rumkdir{} + if err := c.client.sendRecv(&Tumkdir{Tmkdir: msg, UID: uid}, &rumkdir); err != nil { + return QID{}, err + } + return rumkdir.QID, nil + } + + rmkdir := Rmkdir{} + if err := c.client.sendRecv(&msg, &rmkdir); err != nil { + return QID{}, err + } + + return rmkdir.QID, nil +} + +// Symlink implements File.Symlink. +func (c *clientFile) Symlink(oldname string, newname string, uid UID, gid GID) (QID, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return QID{}, syscall.EBADF + } + + msg := Tsymlink{ + Directory: c.fid, + Name: newname, + Target: oldname, + GID: NoGID, + } + + if versionSupportsTucreation(c.client.version) { + msg.GID = gid + rusymlink := Rusymlink{} + if err := c.client.sendRecv(&Tusymlink{Tsymlink: msg, UID: uid}, &rusymlink); err != nil { + return QID{}, err + } + return rusymlink.QID, nil + } + + rsymlink := Rsymlink{} + if err := c.client.sendRecv(&msg, &rsymlink); err != nil { + return QID{}, err + } + + return rsymlink.QID, nil +} + +// Link implements File.Link. +func (c *clientFile) Link(target File, newname string) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + targetFile, ok := target.(*clientFile) + if !ok { + return syscall.EBADF + } + + return c.client.sendRecv(&Tlink{Directory: c.fid, Name: newname, Target: targetFile.fid}, &Rlink{}) +} + +// Mknod implements File.Mknod. +func (c *clientFile) Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return QID{}, syscall.EBADF + } + + msg := Tmknod{ + Directory: c.fid, + Name: name, + Mode: mode, + Major: major, + Minor: minor, + GID: NoGID, + } + + if versionSupportsTucreation(c.client.version) { + msg.GID = gid + rumknod := Rumknod{} + if err := c.client.sendRecv(&Tumknod{Tmknod: msg, UID: uid}, &rumknod); err != nil { + return QID{}, err + } + return rumknod.QID, nil + } + + rmknod := Rmknod{} + if err := c.client.sendRecv(&msg, &rmknod); err != nil { + return QID{}, err + } + + return rmknod.QID, nil +} + +// RenameAt implements File.RenameAt. +func (c *clientFile) RenameAt(oldname string, newdir File, newname string) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + clientNewDir, ok := newdir.(*clientFile) + if !ok { + return syscall.EBADF + } + + return c.client.sendRecv(&Trenameat{OldDirectory: c.fid, OldName: oldname, NewDirectory: clientNewDir.fid, NewName: newname}, &Rrenameat{}) +} + +// UnlinkAt implements File.UnlinkAt. +func (c *clientFile) UnlinkAt(name string, flags uint32) error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + return c.client.sendRecv(&Tunlinkat{Directory: c.fid, Name: name, Flags: flags}, &Runlinkat{}) +} + +// Readdir implements File.Readdir. +func (c *clientFile) Readdir(offset uint64, count uint32) ([]Dirent, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, syscall.EBADF + } + + rreaddir := Rreaddir{} + if err := c.client.sendRecv(&Treaddir{Directory: c.fid, Offset: offset, Count: count}, &rreaddir); err != nil { + return nil, err + } + + return rreaddir.Entries, nil +} + +// Readlink implements File.Readlink. +func (c *clientFile) Readlink() (string, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return "", syscall.EBADF + } + + rreadlink := Rreadlink{} + if err := c.client.sendRecv(&Treadlink{FID: c.fid}, &rreadlink); err != nil { + return "", err + } + + return rreadlink.Target, nil +} + +// Flush implements File.Flush. +func (c *clientFile) Flush() error { + if atomic.LoadUint32(&c.closed) != 0 { + return syscall.EBADF + } + + if !VersionSupportsTflushf(c.client.version) { + return nil + } + + return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{}) +} + +// Renamed implements File.Renamed. +func (c *clientFile) Renamed(newDir File, newName string) {} diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go new file mode 100644 index 000000000..c757583e0 --- /dev/null +++ b/pkg/p9/client_test.go @@ -0,0 +1,109 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/unet" +) + +// TestVersion tests the version negotiation. +func TestVersion(t *testing.T) { + // First, create a new server and connection. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer clientSocket.Close() + + // Create a new server and client. + s := NewServer(nil) + go s.Handle(serverSocket) + + // NewClient does a Tversion exchange, so this is our test for success. + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) + if err != nil { + t.Fatalf("got %v, expected nil", err) + } + + // Check a bogus version string. + if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { + t.Errorf("got %v expected %v", err, syscall.EINVAL) + } + + // Check a bogus version number. + if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { + t.Errorf("got %v expected %v", err, syscall.EINVAL) + } + + // Check a too high version number. + if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN { + t.Errorf("got %v expected %v", err, syscall.EAGAIN) + } + + // Check an invalid MSize. + if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion), MSize: 0}, &Rversion{}); err != syscall.EINVAL { + t.Errorf("got %v expected %v", err, syscall.EINVAL) + } +} + +func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + // See above. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + b.Fatalf("socketpair got err %v expected nil", err) + } + defer clientSocket.Close() + + // See above. + s := NewServer(nil) + go s.Handle(serverSocket) + + // See above. + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) + if err != nil { + b.Fatalf("got %v, expected nil", err) + } + + // Initialize messages. + sendRecv := fn(c) + tversion := &Tversion{ + Version: versionString(highestSupportedVersion), + MSize: DefaultMessageSize, + } + rversion := new(Rversion) + + // Run in a loop. + for i := 0; i < b.N; i++ { + if err := sendRecv(tversion, rversion); err != nil { + b.Fatalf("got unexpected err: %v", err) + } + } +} + +func BenchmarkSendRecvLegacy(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { + return func(t message, r message) error { + _, err := c.sendRecvLegacy(t, r) + return err + } + }) +} + +func BenchmarkSendRecvChannel(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) +} diff --git a/pkg/p9/file.go b/pkg/p9/file.go new file mode 100644 index 000000000..cab35896f --- /dev/null +++ b/pkg/p9/file.go @@ -0,0 +1,288 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/fd" +) + +// Attacher is provided by the server. +type Attacher interface { + // Attach returns a new File. + // + // The client-side attach will be translate to a series of walks from + // the file returned by this Attach call. + Attach() (File, error) +} + +// File is a set of operations corresponding to a single node. +// +// Note that on the server side, the server logic places constraints on +// concurrent operations to make things easier. This may reduce the need for +// complex, error-prone locking and logic in the backend. These are documented +// for each method. +// +// There are three different types of guarantees provided: +// +// none: There is no concurrency guarantee. The method may be invoked +// concurrently with any other method on any other file. +// +// read: The method is guaranteed to be exclusive of any write or global +// operation that is mutating the state of the directory tree starting at this +// node. For example, this means creating new files, symlinks, directories or +// renaming a directory entry (or renaming in to this target), but the method +// may be called concurrently with other read methods. +// +// write: The method is guaranteed to be exclusive of any read, write or global +// operation that is mutating the state of the directory tree starting at this +// node, as described in read above. There may however, be other write +// operations executing concurrently on other components in the directory tree. +// +// global: The method is guaranteed to be exclusive of any read, write or +// global operation. +type File interface { + // Walk walks to the path components given in names. + // + // Walk returns QIDs in the same order that the names were passed in. + // + // An empty list of arguments should return a copy of the current file. + // + // On the server, Walk has a read concurrency guarantee. + Walk(names []string) ([]QID, File, error) + + // WalkGetAttr walks to the next file and returns its maximal set of + // attributes. + // + // Server-side p9.Files may return syscall.ENOSYS to indicate that Walk + // and GetAttr should be used separately to satisfy this request. + // + // On the server, WalkGetAttr has a read concurrency guarantee. + WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) + + // StatFS returns information about the file system associated with + // this file. + // + // On the server, StatFS has no concurrency guarantee. + StatFS() (FSStat, error) + + // GetAttr returns attributes of this node. + // + // On the server, GetAttr has a read concurrency guarantee. + GetAttr(req AttrMask) (QID, AttrMask, Attr, error) + + // SetAttr sets attributes on this node. + // + // On the server, SetAttr has a write concurrency guarantee. + SetAttr(valid SetAttrMask, attr SetAttr) error + + // GetXattr returns extended attributes of this node. + // + // Size indicates the size of the buffer that has been allocated to hold the + // attribute value. If the value is larger than size, implementations may + // return ERANGE to indicate that the buffer is too small, but they are also + // free to ignore the hint entirely (i.e. the value returned may be larger + // than size). All size checking is done independently at the syscall layer. + // + // On the server, GetXattr has a read concurrency guarantee. + GetXattr(name string, size uint64) (string, error) + + // SetXattr sets extended attributes on this node. + // + // On the server, SetXattr has a write concurrency guarantee. + SetXattr(name, value string, flags uint32) error + + // ListXattr lists the names of the extended attributes on this node. + // + // Size indicates the size of the buffer that has been allocated to hold the + // attribute list. If the list would be larger than size, implementations may + // return ERANGE to indicate that the buffer is too small, but they are also + // free to ignore the hint entirely (i.e. the value returned may be larger + // than size). All size checking is done independently at the syscall layer. + // + // On the server, ListXattr has a read concurrency guarantee. + ListXattr(size uint64) (map[string]struct{}, error) + + // RemoveXattr removes extended attributes on this node. + // + // On the server, RemoveXattr has a write concurrency guarantee. + RemoveXattr(name string) error + + // Allocate allows the caller to directly manipulate the allocated disk space + // for the file. See fallocate(2) for more details. + Allocate(mode AllocateMode, offset, length uint64) error + + // Close is called when all references are dropped on the server side, + // and Close should be called by the client to drop all references. + // + // For server-side implementations of Close, the error is ignored. + // + // Close must be called even when Open has not been called. + // + // On the server, Close has no concurrency guarantee. + Close() error + + // Open must be called prior to using Read, Write or Readdir. Once Open + // is called, some operations, such as Walk, will no longer work. + // + // On the client, Open should be called only once. The fd return is + // optional, and may be nil. + // + // On the server, Open has a read concurrency guarantee. If an *fd.FD + // is provided, ownership now belongs to the caller. Open is guaranteed + // to be called only once. + // + // N.B. The server must resolve any lazy paths when open is called. + // After this point, read and write may be called on files with no + // deletion check, so resolving in the data path is not viable. + Open(flags OpenFlags) (*fd.FD, QID, uint32, error) + + // Read reads from this file. Open must be called first. + // + // This may return io.EOF in addition to syscall.Errno values. + // + // On the server, ReadAt has a read concurrency guarantee. See Open for + // additional requirements regarding lazy path resolution. + ReadAt(p []byte, offset uint64) (int, error) + + // Write writes to this file. Open must be called first. + // + // This may return io.EOF in addition to syscall.Errno values. + // + // On the server, WriteAt has a read concurrency guarantee. See Open + // for additional requirements regarding lazy path resolution. + WriteAt(p []byte, offset uint64) (int, error) + + // FSync syncs this node. Open must be called first. + // + // On the server, FSync has a read concurrency guarantee. + FSync() error + + // Create creates a new regular file and opens it according to the + // flags given. This file is already Open. + // + // N.B. On the client, the returned file is a reference to the current + // file, which now represents the created file. This is not the case on + // the server. These semantics are very subtle and can easily lead to + // bugs, but are a consequence of the 9P create operation. + // + // See p9.File.Open for a description of *fd.FD. + // + // On the server, Create has a write concurrency guarantee. + Create(name string, flags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error) + + // Mkdir creates a subdirectory. + // + // On the server, Mkdir has a write concurrency guarantee. + Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error) + + // Symlink makes a new symbolic link. + // + // On the server, Symlink has a write concurrency guarantee. + Symlink(oldName string, newName string, uid UID, gid GID) (QID, error) + + // Link makes a new hard link. + // + // On the server, Link has a write concurrency guarantee. + Link(target File, newName string) error + + // Mknod makes a new device node. + // + // On the server, Mknod has a write concurrency guarantee. + Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error) + + // Rename renames the file. + // + // Rename will never be called on the server, and RenameAt will always + // be used instead. + Rename(newDir File, newName string) error + + // RenameAt renames a given file to a new name in a potentially new + // directory. + // + // oldName must be a name relative to this file, which must be a + // directory. newName is a name relative to newDir. + // + // On the server, RenameAt has a global concurrency guarantee. + RenameAt(oldName string, newDir File, newName string) error + + // UnlinkAt the given named file. + // + // name must be a file relative to this directory. + // + // Flags are implementation-specific (e.g. O_DIRECTORY), but are + // generally Linux unlinkat(2) flags. + // + // On the server, UnlinkAt has a write concurrency guarantee. + UnlinkAt(name string, flags uint32) error + + // Readdir reads directory entries. + // + // This may return io.EOF in addition to syscall.Errno values. + // + // On the server, Readdir has a read concurrency guarantee. + Readdir(offset uint64, count uint32) ([]Dirent, error) + + // Readlink reads the link target. + // + // On the server, Readlink has a read concurrency guarantee. + Readlink() (string, error) + + // Flush is called prior to Close. + // + // Whereas Close drops all references to the file, Flush cleans up the + // file state. Behavior is implementation-specific. + // + // Flush is not related to flush(9p). Flush is an extension to 9P2000.L, + // see version.go. + // + // On the server, Flush has a read concurrency guarantee. + Flush() error + + // Connect establishes a new host-socket backed connection with a + // socket. A File does not need to be opened before it can be connected + // and it can be connected to multiple times resulting in a unique + // *fd.FD each time. In addition, the lifetime of the *fd.FD is + // independent from the lifetime of the p9.File and must be managed by + // the caller. + // + // The returned FD must be non-blocking. + // + // Flags indicates the requested type of socket. + // + // On the server, Connect has a read concurrency guarantee. + Connect(flags ConnectFlags) (*fd.FD, error) + + // Renamed is called when this node is renamed. + // + // This may not fail. The file will hold a reference to its parent + // within the p9 package, and is therefore safe to use for the lifetime + // of this File (until Close is called). + // + // This method should not be called by clients, who should use the + // relevant Rename methods. (Although the method will be a no-op.) + // + // On the server, Renamed has a global concurrency guarantee. + Renamed(newDir File, newName string) +} + +// DefaultWalkGetAttr implements File.WalkGetAttr to return ENOSYS for server-side Files. +type DefaultWalkGetAttr struct{} + +// WalkGetAttr implements File.WalkGetAttr. +func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { + return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go new file mode 100644 index 000000000..1db5797dd --- /dev/null +++ b/pkg/p9/handlers.go @@ -0,0 +1,1393 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "fmt" + "io" + "os" + "path" + "strings" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/log" +) + +// ExtractErrno extracts a syscall.Errno from a error, best effort. +func ExtractErrno(err error) syscall.Errno { + switch err { + case os.ErrNotExist: + return syscall.ENOENT + case os.ErrExist: + return syscall.EEXIST + case os.ErrPermission: + return syscall.EACCES + case os.ErrInvalid: + return syscall.EINVAL + } + + // Attempt to unwrap. + switch e := err.(type) { + case syscall.Errno: + return e + case *os.PathError: + return ExtractErrno(e.Err) + case *os.SyscallError: + return ExtractErrno(e.Err) + case *os.LinkError: + return ExtractErrno(e.Err) + } + + // Default case. + log.Warningf("unknown error: %v", err) + return syscall.EIO +} + +// newErr returns a new error message from an error. +func newErr(err error) *Rlerror { + return &Rlerror{Error: uint32(ExtractErrno(err))} +} + +// handler is implemented for server-handled messages. +// +// See server.go for call information. +type handler interface { + // Handle handles the given message. + // + // This may modify the server state. The handle function must return a + // message which will be sent back to the client. It may be useful to + // use newErr to automatically extract an error message. + handle(cs *connState) message +} + +// handle implements handler.handle. +func (t *Tversion) handle(cs *connState) message { + if t.MSize == 0 { + return newErr(syscall.EINVAL) + } + if t.MSize > maximumLength { + return newErr(syscall.EINVAL) + } + atomic.StoreUint32(&cs.messageSize, t.MSize) + requested, ok := parseVersion(t.Version) + if !ok { + return newErr(syscall.EINVAL) + } + // The server cannot support newer versions that it doesn't know about. In this + // case we return EAGAIN to tell the client to try again with a lower version. + if requested > highestSupportedVersion { + return newErr(syscall.EAGAIN) + } + // From Tversion(9P): "The server may respond with the client’s version + // string, or a version string identifying an earlier defined protocol version". + atomic.StoreUint32(&cs.version, requested) + return &Rversion{ + MSize: t.MSize, + Version: t.Version, + } +} + +// handle implements handler.handle. +func (t *Tflush) handle(cs *connState) message { + cs.WaitTag(t.OldTag) + return &Rflush{} +} + +// checkSafeName validates the name and returns nil or returns an error. +func checkSafeName(name string) error { + if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." { + return nil + } + return syscall.EINVAL +} + +// handle implements handler.handle. +func (t *Tclunk) handle(cs *connState) message { + if !cs.DeleteFID(t.FID) { + return newErr(syscall.EBADF) + } + return &Rclunk{} +} + +// handle implements handler.handle. +func (t *Tremove) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // Frustratingly, because we can't be guaranteed that a rename is not + // occurring simultaneously with this removal, we need to acquire the + // global rename lock for this kind of remove operation to ensure that + // ref.parent does not change out from underneath us. + // + // This is why Tremove is a bad idea, and clients should generally use + // Tunlinkat. All p9 clients will use Tunlinkat. + err := ref.safelyGlobal(func() error { + // Is this a root? Can't remove that. + if ref.isRoot() { + return syscall.EINVAL + } + + // N.B. this remove operation is permitted, even if the file is open. + // See also rename below for reasoning. + + // Is this file already deleted? + if ref.isDeleted() { + return syscall.EINVAL + } + + // Retrieve the file's proper name. + name := ref.parent.pathNode.nameFor(ref) + + // Attempt the removal. + if err := ref.parent.file.UnlinkAt(name, 0); err != nil { + return err + } + + // Mark all relevant fids as deleted. We don't need to lock any + // individual nodes because we already hold the global lock. + ref.parent.markChildDeleted(name) + return nil + }) + + // "The remove request asks the file server both to remove the file + // represented by fid and to clunk the fid, even if the remove fails." + // + // "It is correct to consider remove to be a clunk with the side effect + // of removing the file if permissions allow." + // https://swtch.com/plan9port/man/man9/remove.html + if !cs.DeleteFID(t.FID) { + return newErr(syscall.EBADF) + } + if err != nil { + return newErr(err) + } + + return &Rremove{} +} + +// handle implements handler.handle. +// +// We don't support authentication, so this just returns ENOSYS. +func (t *Tauth) handle(cs *connState) message { + return newErr(syscall.ENOSYS) +} + +// handle implements handler.handle. +func (t *Tattach) handle(cs *connState) message { + // Ensure no authentication FID is provided. + if t.Auth.AuthenticationFID != NoFID { + return newErr(syscall.EINVAL) + } + + // Must provide an absolute path. + if path.IsAbs(t.Auth.AttachName) { + // Trim off the leading / if the path is absolute. We always + // treat attach paths as absolute and call attach with the root + // argument on the server file for clarity. + t.Auth.AttachName = t.Auth.AttachName[1:] + } + + // Do the attach on the root. + sf, err := cs.server.attacher.Attach() + if err != nil { + return newErr(err) + } + qid, valid, attr, err := sf.GetAttr(AttrMaskAll()) + if err != nil { + sf.Close() // Drop file. + return newErr(err) + } + if !valid.Mode { + sf.Close() // Drop file. + return newErr(syscall.EINVAL) + } + + // Build a transient reference. + root := &fidRef{ + server: cs.server, + parent: nil, + file: sf, + refs: 1, + mode: attr.Mode.FileType(), + pathNode: cs.server.pathTree, + } + defer root.DecRef() + + // Attach the root? + if len(t.Auth.AttachName) == 0 { + cs.InsertFID(t.FID, root) + return &Rattach{QID: qid} + } + + // We want the same traversal checks to apply on attach, so always + // attach at the root and use the regular walk paths. + names := strings.Split(t.Auth.AttachName, "/") + _, newRef, _, _, err := doWalk(cs, root, names, false) + if err != nil { + return newErr(err) + } + defer newRef.DecRef() + + // Insert the FID. + cs.InsertFID(t.FID, newRef) + return &Rattach{QID: qid} +} + +// CanOpen returns whether this file open can be opened, read and written to. +// +// This includes everything except symlinks and sockets. +func CanOpen(mode FileMode) bool { + return mode.IsRegular() || mode.IsDir() || mode.IsNamedPipe() || mode.IsBlockDevice() || mode.IsCharacterDevice() +} + +// handle implements handler.handle. +func (t *Tlopen) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + ref.openedMu.Lock() + defer ref.openedMu.Unlock() + + // Has it been opened already? + if ref.opened || !CanOpen(ref.mode) { + return newErr(syscall.EINVAL) + } + + if ref.mode.IsDir() { + // Directory must be opened ReadOnly. + if t.Flags&OpenFlagsModeMask != ReadOnly { + return newErr(syscall.EISDIR) + } + // Directory not truncatable. + if t.Flags&OpenTruncate != 0 { + return newErr(syscall.EISDIR) + } + } + + var ( + qid QID + ioUnit uint32 + osFile *fd.FD + ) + if err := ref.safelyRead(func() (err error) { + // Has it been deleted already? + if ref.isDeleted() { + return syscall.EINVAL + } + + osFile, qid, ioUnit, err = ref.file.Open(t.Flags) + return err + }); err != nil { + return newErr(err) + } + + // Mark file as opened and set open mode. + ref.opened = true + ref.openFlags = t.Flags + + rlopen := &Rlopen{QID: qid, IoUnit: ioUnit} + rlopen.SetFilePayload(osFile) + return rlopen +} + +func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { + if err := checkSafeName(t.Name); err != nil { + return nil, err + } + + ref, ok := cs.LookupFID(t.FID) + if !ok { + return nil, syscall.EBADF + } + defer ref.DecRef() + + var ( + osFile *fd.FD + nsf File + qid QID + ioUnit uint32 + newRef *fidRef + ) + if err := ref.safelyWrite(func() (err error) { + // Don't allow creation from non-directories or deleted directories. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Do the create. + osFile, nsf, qid, ioUnit, err = ref.file.Create(t.Name, t.OpenFlags, t.Permissions, uid, t.GID) + if err != nil { + return err + } + + newRef = &fidRef{ + server: cs.server, + parent: ref, + file: nsf, + opened: true, + openFlags: t.OpenFlags, + mode: ModeRegular, + pathNode: ref.pathNode.pathNodeFor(t.Name), + } + ref.pathNode.addChild(newRef, t.Name) + ref.IncRef() // Acquire parent reference. + return nil + }); err != nil { + return nil, err + } + + // Replace the FID reference. + cs.InsertFID(t.FID, newRef) + + rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}} + rlcreate.SetFilePayload(osFile) + return rlcreate, nil +} + +// handle implements handler.handle. +func (t *Tlcreate) handle(cs *connState) message { + rlcreate, err := t.do(cs, NoUID) + if err != nil { + return newErr(err) + } + return rlcreate +} + +// handle implements handler.handle. +func (t *Tsymlink) handle(cs *connState) message { + rsymlink, err := t.do(cs, NoUID) + if err != nil { + return newErr(err) + } + return rsymlink +} + +func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { + if err := checkSafeName(t.Name); err != nil { + return nil, err + } + + ref, ok := cs.LookupFID(t.Directory) + if !ok { + return nil, syscall.EBADF + } + defer ref.DecRef() + + var qid QID + if err := ref.safelyWrite(func() (err error) { + // Don't allow symlinks from non-directories or deleted directories. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Do the symlink. + qid, err = ref.file.Symlink(t.Target, t.Name, uid, t.GID) + return err + }); err != nil { + return nil, err + } + + return &Rsymlink{QID: qid}, nil +} + +// handle implements handler.handle. +func (t *Tlink) handle(cs *connState) message { + if err := checkSafeName(t.Name); err != nil { + return newErr(err) + } + + ref, ok := cs.LookupFID(t.Directory) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + refTarget, ok := cs.LookupFID(t.Target) + if !ok { + return newErr(syscall.EBADF) + } + defer refTarget.DecRef() + + if err := ref.safelyWrite(func() (err error) { + // Don't allow create links from non-directories or deleted directories. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Do the link. + return ref.file.Link(refTarget.file, t.Name) + }); err != nil { + return newErr(err) + } + + return &Rlink{} +} + +// handle implements handler.handle. +func (t *Trenameat) handle(cs *connState) message { + if err := checkSafeName(t.OldName); err != nil { + return newErr(err) + } + if err := checkSafeName(t.NewName); err != nil { + return newErr(err) + } + + ref, ok := cs.LookupFID(t.OldDirectory) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + refTarget, ok := cs.LookupFID(t.NewDirectory) + if !ok { + return newErr(syscall.EBADF) + } + defer refTarget.DecRef() + + // Perform the rename holding the global lock. + if err := ref.safelyGlobal(func() (err error) { + // Don't allow renaming across deleted directories. + if ref.isDeleted() || !ref.mode.IsDir() || refTarget.isDeleted() || !refTarget.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Is this the same file? If yes, short-circuit and return success. + if ref.pathNode == refTarget.pathNode && t.OldName == t.NewName { + return nil + } + + // Attempt the actual rename. + if err := ref.file.RenameAt(t.OldName, refTarget.file, t.NewName); err != nil { + return err + } + + // Update the path tree. + ref.renameChildTo(t.OldName, refTarget, t.NewName) + return nil + }); err != nil { + return newErr(err) + } + + return &Rrenameat{} +} + +// handle implements handler.handle. +func (t *Tunlinkat) handle(cs *connState) message { + if err := checkSafeName(t.Name); err != nil { + return newErr(err) + } + + ref, ok := cs.LookupFID(t.Directory) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyWrite(func() (err error) { + // Don't allow deletion from non-directories or deleted directories. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Before we do the unlink itself, we need to ensure that there + // are no operations in flight on associated path node. The + // child's path node lock must be held to ensure that the + // unlinkat marking the child deleted below is atomic with + // respect to any other read or write operations. + // + // This is one case where we have a lock ordering issue, but + // since we always acquire deeper in the hierarchy, we know + // that we are free of lock cycles. + childPathNode := ref.pathNode.pathNodeFor(t.Name) + childPathNode.opMu.Lock() + defer childPathNode.opMu.Unlock() + + // Do the unlink. + err = ref.file.UnlinkAt(t.Name, t.Flags) + if err != nil { + return err + } + + // Mark the path as deleted. + ref.markChildDeleted(t.Name) + return nil + }); err != nil { + return newErr(err) + } + + return &Runlinkat{} +} + +// handle implements handler.handle. +func (t *Trename) handle(cs *connState) message { + if err := checkSafeName(t.Name); err != nil { + return newErr(err) + } + + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + refTarget, ok := cs.LookupFID(t.Directory) + if !ok { + return newErr(syscall.EBADF) + } + defer refTarget.DecRef() + + if err := ref.safelyGlobal(func() (err error) { + // Don't allow a root rename. + if ref.isRoot() { + return syscall.EINVAL + } + + // Don't allow renaming deleting entries, or target non-directories. + if ref.isDeleted() || refTarget.isDeleted() || !refTarget.mode.IsDir() { + return syscall.EINVAL + } + + // If the parent is deleted, but we not, something is seriously wrong. + // It's fail to die at this point with an assertion failure. + if ref.parent.isDeleted() { + panic(fmt.Sprintf("parent %+v deleted, child %+v is not", ref.parent, ref)) + } + + // N.B. The rename operation is allowed to proceed on open files. It + // does impact the state of its parent, but this is merely a sanity + // check in any case, and the operation is safe. There may be other + // files corresponding to the same path that are renamed anyways. + + // Check for the exact same file and short-circuit. + oldName := ref.parent.pathNode.nameFor(ref) + if ref.parent.pathNode == refTarget.pathNode && oldName == t.Name { + return nil + } + + // Call the rename method on the parent. + if err := ref.parent.file.RenameAt(oldName, refTarget.file, t.Name); err != nil { + return err + } + + // Update the path tree. + ref.parent.renameChildTo(oldName, refTarget, t.Name) + return nil + }); err != nil { + return newErr(err) + } + + return &Rrename{} +} + +// handle implements handler.handle. +func (t *Treadlink) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + var target string + if err := ref.safelyRead(func() (err error) { + // Don't allow readlink on deleted files. There is no need to + // check if this file is opened because symlinks cannot be + // opened. + if ref.isDeleted() || !ref.mode.IsSymlink() { + return syscall.EINVAL + } + + // Do the read. + target, err = ref.file.Readlink() + return err + }); err != nil { + return newErr(err) + } + + return &Rreadlink{target} +} + +// handle implements handler.handle. +func (t *Tread) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // Constrain the size of the read buffer. + if int(t.Count) > int(maximumLength) { + return newErr(syscall.ENOBUFS) + } + + var ( + data = make([]byte, t.Count) + n int + ) + if err := ref.safelyRead(func() (err error) { + // Has it been opened already? + openFlags, opened := ref.OpenFlags() + if !opened { + return syscall.EINVAL + } + + // Can it be read? Check permissions. + if openFlags&OpenFlagsModeMask == WriteOnly { + return syscall.EPERM + } + + n, err = ref.file.ReadAt(data, t.Offset) + return err + }); err != nil && err != io.EOF { + return newErr(err) + } + + return &Rread{Data: data[:n]} +} + +// handle implements handler.handle. +func (t *Twrite) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + var n int + if err := ref.safelyRead(func() (err error) { + // Has it been opened already? + openFlags, opened := ref.OpenFlags() + if !opened { + return syscall.EINVAL + } + + // Can it be written? Check permissions. + if openFlags&OpenFlagsModeMask == ReadOnly { + return syscall.EPERM + } + + n, err = ref.file.WriteAt(t.Data, t.Offset) + return err + }); err != nil { + return newErr(err) + } + + return &Rwrite{Count: uint32(n)} +} + +// handle implements handler.handle. +func (t *Tmknod) handle(cs *connState) message { + rmknod, err := t.do(cs, NoUID) + if err != nil { + return newErr(err) + } + return rmknod +} + +func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) { + if err := checkSafeName(t.Name); err != nil { + return nil, err + } + + ref, ok := cs.LookupFID(t.Directory) + if !ok { + return nil, syscall.EBADF + } + defer ref.DecRef() + + var qid QID + if err := ref.safelyWrite(func() (err error) { + // Don't allow mknod on deleted files. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Do the mknod. + qid, err = ref.file.Mknod(t.Name, t.Mode, t.Major, t.Minor, uid, t.GID) + return err + }); err != nil { + return nil, err + } + + return &Rmknod{QID: qid}, nil +} + +// handle implements handler.handle. +func (t *Tmkdir) handle(cs *connState) message { + rmkdir, err := t.do(cs, NoUID) + if err != nil { + return newErr(err) + } + return rmkdir +} + +func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { + if err := checkSafeName(t.Name); err != nil { + return nil, err + } + + ref, ok := cs.LookupFID(t.Directory) + if !ok { + return nil, syscall.EBADF + } + defer ref.DecRef() + + var qid QID + if err := ref.safelyWrite(func() (err error) { + // Don't allow mkdir on deleted files. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Not allowed on open directories. + if _, opened := ref.OpenFlags(); opened { + return syscall.EINVAL + } + + // Do the mkdir. + qid, err = ref.file.Mkdir(t.Name, t.Permissions, uid, t.GID) + return err + }); err != nil { + return nil, err + } + + return &Rmkdir{QID: qid}, nil +} + +// handle implements handler.handle. +func (t *Tgetattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // We allow getattr on deleted files. Depending on the backing + // implementation, it's possible that races exist that might allow + // fetching attributes of other files. But we need to generally allow + // refreshing attributes and this is a minor leak, if at all. + + var ( + qid QID + valid AttrMask + attr Attr + ) + if err := ref.safelyRead(func() (err error) { + qid, valid, attr, err = ref.file.GetAttr(t.AttrMask) + return err + }); err != nil { + return newErr(err) + } + + return &Rgetattr{QID: qid, Valid: valid, Attr: attr} +} + +// handle implements handler.handle. +func (t *Tsetattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyWrite(func() error { + // We don't allow setattr on files that have been deleted. + // This might be technically incorrect, as it's possible that + // there were multiple links and you can still change the + // corresponding inode information. + if ref.isDeleted() { + return syscall.EINVAL + } + + // Set the attributes. + return ref.file.SetAttr(t.Valid, t.SetAttr) + }); err != nil { + return newErr(err) + } + + return &Rsetattr{} +} + +// handle implements handler.handle. +func (t *Tallocate) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyWrite(func() error { + // Has it been opened already? + openFlags, opened := ref.OpenFlags() + if !opened { + return syscall.EINVAL + } + + // Can it be written? Check permissions. + if openFlags&OpenFlagsModeMask == ReadOnly { + return syscall.EBADF + } + + // We don't allow allocate on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + + return ref.file.Allocate(t.Mode, t.Offset, t.Length) + }); err != nil { + return newErr(err) + } + + return &Rallocate{} +} + +// handle implements handler.handle. +func (t *Txattrwalk) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // We don't support extended attributes. + return newErr(syscall.ENODATA) +} + +// handle implements handler.handle. +func (t *Txattrcreate) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // We don't support extended attributes. + return newErr(syscall.ENOSYS) +} + +// handle implements handler.handle. +func (t *Tgetxattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + var val string + if err := ref.safelyRead(func() (err error) { + // Don't allow getxattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + val, err = ref.file.GetXattr(t.Name, t.Size) + return err + }); err != nil { + return newErr(err) + } + return &Rgetxattr{Value: val} +} + +// handle implements handler.handle. +func (t *Tsetxattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyWrite(func() error { + // Don't allow setxattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + return ref.file.SetXattr(t.Name, t.Value, t.Flags) + }); err != nil { + return newErr(err) + } + return &Rsetxattr{} +} + +// handle implements handler.handle. +func (t *Tlistxattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + var xattrs map[string]struct{} + if err := ref.safelyRead(func() (err error) { + // Don't allow listxattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + xattrs, err = ref.file.ListXattr(t.Size) + return err + }); err != nil { + return newErr(err) + } + + xattrList := make([]string, 0, len(xattrs)) + for x := range xattrs { + xattrList = append(xattrList, x) + } + return &Rlistxattr{Xattrs: xattrList} +} + +// handle implements handler.handle. +func (t *Tremovexattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyWrite(func() error { + // Don't allow removexattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + return ref.file.RemoveXattr(t.Name) + }); err != nil { + return newErr(err) + } + return &Rremovexattr{} +} + +// handle implements handler.handle. +func (t *Treaddir) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.Directory) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + var entries []Dirent + if err := ref.safelyRead(func() (err error) { + // Don't allow reading deleted directories. + if ref.isDeleted() || !ref.mode.IsDir() { + return syscall.EINVAL + } + + // Has it been opened already? + if _, opened := ref.OpenFlags(); !opened { + return syscall.EINVAL + } + + // Read the entries. + entries, err = ref.file.Readdir(t.Offset, t.Count) + if err != nil && err != io.EOF { + return err + } + return nil + }); err != nil { + return newErr(err) + } + + return &Rreaddir{Count: t.Count, Entries: entries} +} + +// handle implements handler.handle. +func (t *Tfsync) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyRead(func() (err error) { + // Has it been opened already? + if _, opened := ref.OpenFlags(); !opened { + return syscall.EINVAL + } + + // Perform the sync. + return ref.file.FSync() + }); err != nil { + return newErr(err) + } + + return &Rfsync{} +} + +// handle implements handler.handle. +func (t *Tstatfs) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + st, err := ref.file.StatFS() + if err != nil { + return newErr(err) + } + + return &Rstatfs{st} +} + +// handle implements handler.handle. +func (t *Tflushf) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + if err := ref.safelyRead(ref.file.Flush); err != nil { + return newErr(err) + } + + return &Rflushf{} +} + +// walkOne walks zero or one path elements. +// +// The slice passed as qids is append and returned. +func walkOne(qids []QID, from File, names []string, getattr bool) ([]QID, File, AttrMask, Attr, error) { + if len(names) > 1 { + // We require exactly zero or one elements. + return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL + } + var ( + localQIDs []QID + sf File + valid AttrMask + attr Attr + err error + ) + switch { + case getattr: + localQIDs, sf, valid, attr, err = from.WalkGetAttr(names) + // Can't put fallthrough in the if because Go. + if err != syscall.ENOSYS { + break + } + fallthrough + default: + localQIDs, sf, err = from.Walk(names) + if err != nil { + // No way to walk this element. + break + } + if getattr { + _, valid, attr, err = sf.GetAttr(AttrMaskAll()) + if err != nil { + // Don't leak the file. + sf.Close() + } + } + } + if err != nil { + // Error walking, don't return anything. + return nil, nil, AttrMask{}, Attr{}, err + } + if len(localQIDs) != 1 { + // Expected a single QID. + sf.Close() + return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL + } + return append(qids, localQIDs...), sf, valid, attr, nil +} + +// doWalk walks from a given fidRef. +// +// This enforces that all intermediate nodes are walkable (directories). The +// fidRef returned (newRef) has a reference associated with it that is now +// owned by the caller and must be handled appropriately. +func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QID, newRef *fidRef, valid AttrMask, attr Attr, err error) { + // Check the names. + for _, name := range names { + err = checkSafeName(name) + if err != nil { + return + } + } + + // Has it been opened already? + if _, opened := ref.OpenFlags(); opened { + err = syscall.EBUSY + return + } + + // Is this an empty list? Handle specially. We don't actually need to + // validate anything since this is always permitted. + if len(names) == 0 { + var sf File // Temporary. + if err := ref.maybeParent().safelyRead(func() (err error) { + // Clone the single element. + qids, sf, valid, attr, err = walkOne(nil, ref.file, nil, getattr) + if err != nil { + return err + } + + newRef = &fidRef{ + server: cs.server, + parent: ref.parent, + file: sf, + mode: ref.mode, + pathNode: ref.pathNode, + + // For the clone case, the cloned fid must + // preserve the deleted property of the + // original FID. + deleted: ref.deleted, + } + if !ref.isRoot() { + if !newRef.isDeleted() { + // Add only if a non-root node; the same node. + ref.parent.pathNode.addChild(newRef, ref.parent.pathNode.nameFor(ref)) + } + ref.parent.IncRef() // Acquire parent reference. + } + // doWalk returns a reference. + newRef.IncRef() + return nil + }); err != nil { + return nil, nil, AttrMask{}, Attr{}, err + } + // Do not return the new QID. + return nil, newRef, valid, attr, nil + } + + // Do the walk, one element at a time. + walkRef := ref + walkRef.IncRef() + for i := 0; i < len(names); i++ { + // We won't allow beyond past symlinks; stop here if this isn't + // a proper directory and we have additional paths to walk. + if !walkRef.mode.IsDir() { + walkRef.DecRef() // Drop walk reference; no lock required. + return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL + } + + var sf File // Temporary. + if err := walkRef.safelyRead(func() (err error) { + // Pass getattr = true to walkOne since we need the file type for + // newRef. + qids, sf, valid, attr, err = walkOne(qids, walkRef.file, names[i:i+1], true) + if err != nil { + return err + } + + // Note that we don't need to acquire a lock on any of + // these individual instances. That's because they are + // not actually addressable via a FID. They are + // anonymous. They exist in the tree for tracking + // purposes. + newRef := &fidRef{ + server: cs.server, + parent: walkRef, + file: sf, + mode: attr.Mode.FileType(), + pathNode: walkRef.pathNode.pathNodeFor(names[i]), + } + walkRef.pathNode.addChild(newRef, names[i]) + // We allow our walk reference to become the new parent + // reference here and so we don't IncRef. Instead, just + // set walkRef to the newRef above and acquire a new + // walk reference. + walkRef = newRef + walkRef.IncRef() + return nil + }); err != nil { + walkRef.DecRef() // Drop the old walkRef. + return nil, nil, AttrMask{}, Attr{}, err + } + } + + // Success. + return qids, walkRef, valid, attr, nil +} + +// handle implements handler.handle. +func (t *Twalk) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // Do the walk. + qids, newRef, _, _, err := doWalk(cs, ref, t.Names, false) + if err != nil { + return newErr(err) + } + defer newRef.DecRef() + + // Install the new FID. + cs.InsertFID(t.NewFID, newRef) + return &Rwalk{QIDs: qids} +} + +// handle implements handler.handle. +func (t *Twalkgetattr) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + // Do the walk. + qids, newRef, valid, attr, err := doWalk(cs, ref, t.Names, true) + if err != nil { + return newErr(err) + } + defer newRef.DecRef() + + // Install the new FID. + cs.InsertFID(t.NewFID, newRef) + return &Rwalkgetattr{QIDs: qids, Valid: valid, Attr: attr} +} + +// handle implements handler.handle. +func (t *Tucreate) handle(cs *connState) message { + rlcreate, err := t.Tlcreate.do(cs, t.UID) + if err != nil { + return newErr(err) + } + return &Rucreate{*rlcreate} +} + +// handle implements handler.handle. +func (t *Tumkdir) handle(cs *connState) message { + rmkdir, err := t.Tmkdir.do(cs, t.UID) + if err != nil { + return newErr(err) + } + return &Rumkdir{*rmkdir} +} + +// handle implements handler.handle. +func (t *Tusymlink) handle(cs *connState) message { + rsymlink, err := t.Tsymlink.do(cs, t.UID) + if err != nil { + return newErr(err) + } + return &Rusymlink{*rsymlink} +} + +// handle implements handler.handle. +func (t *Tumknod) handle(cs *connState) message { + rmknod, err := t.Tmknod.do(cs, t.UID) + if err != nil { + return newErr(err) + } + return &Rumknod{*rmknod} +} + +// handle implements handler.handle. +func (t *Tlconnect) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + var osFile *fd.FD + if err := ref.safelyRead(func() (err error) { + // Don't allow connecting to deleted files. + if ref.isDeleted() || !ref.mode.IsSocket() { + return syscall.EINVAL + } + + // Do the connect. + osFile, err = ref.file.Connect(t.Flags) + return err + }); err != nil { + return newErr(err) + } + + rlconnect := &Rlconnect{} + rlconnect.SetFilePayload(osFile) + return rlconnect +} + +// handle implements handler.handle. +func (t *Tchannel) handle(cs *connState) message { + // Ensure that channels are enabled. + if err := cs.initializeChannels(); err != nil { + return newErr(err) + } + + ch := cs.lookupChannel(t.ID) + if ch == nil { + return newErr(syscall.ENOSYS) + } + + // Return the payload. Note that we need to duplicate the file + // descriptor for the channel allocator, because sending is a + // destructive operation between sendRecvLegacy (and now the newer + // channel send operations). Same goes for the client FD. + rchannel := &Rchannel{ + Offset: uint64(ch.desc.Offset), + Length: uint64(ch.desc.Length), + } + switch t.Control { + case 0: + // Open the main data channel. + mfd, err := syscall.Dup(int(cs.channelAlloc.FD())) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(mfd)) + case 1: + cfd, err := syscall.Dup(ch.client.FD()) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(cfd)) + default: + return newErr(syscall.EINVAL) + } + return rchannel +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go new file mode 100644 index 000000000..57b89ad7d --- /dev/null +++ b/pkg/p9/messages.go @@ -0,0 +1,2662 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "fmt" + "math" + + "gvisor.dev/gvisor/pkg/fd" +) + +// ErrInvalidMsgType is returned when an unsupported message type is found. +type ErrInvalidMsgType struct { + MsgType +} + +// Error returns a useful string. +func (e *ErrInvalidMsgType) Error() string { + return fmt.Sprintf("invalid message type: %d", e.MsgType) +} + +// message is a generic 9P message. +type message interface { + encoder + fmt.Stringer + + // Type returns the message type number. + Type() MsgType +} + +// payloader is a special message which may include an inline payload. +type payloader interface { + // FixedSize returns the size of the fixed portion of this message. + FixedSize() uint32 + + // Payload returns the payload for sending. + Payload() []byte + + // SetPayload returns the decoded message. + // + // This is going to be total message size - FixedSize. But this should + // be validated during decode, which will be called after SetPayload. + SetPayload([]byte) +} + +// filer is a message capable of passing a file. +type filer interface { + // FilePayload returns the file payload. + FilePayload() *fd.FD + + // SetFilePayload sets the file payload. + SetFilePayload(*fd.FD) +} + +// filePayload embeds a File object. +type filePayload struct { + File *fd.FD +} + +// FilePayload returns the file payload. +func (f *filePayload) FilePayload() *fd.FD { + return f.File +} + +// SetFilePayload sets the received file. +func (f *filePayload) SetFilePayload(file *fd.FD) { + f.File = file +} + +// Tversion is a version request. +type Tversion struct { + // MSize is the message size to use. + MSize uint32 + + // Version is the version string. + // + // For this implementation, this must be 9P2000.L. + Version string +} + +// decode implements encoder.decode. +func (t *Tversion) decode(b *buffer) { + t.MSize = b.Read32() + t.Version = b.ReadString() +} + +// encode implements encoder.encode. +func (t *Tversion) encode(b *buffer) { + b.Write32(t.MSize) + b.WriteString(t.Version) +} + +// Type implements message.Type. +func (*Tversion) Type() MsgType { + return MsgTversion +} + +// String implements fmt.Stringer. +func (t *Tversion) String() string { + return fmt.Sprintf("Tversion{MSize: %d, Version: %s}", t.MSize, t.Version) +} + +// Rversion is a version response. +type Rversion struct { + // MSize is the negotiated size. + MSize uint32 + + // Version is the negotiated version. + Version string +} + +// decode implements encoder.decode. +func (r *Rversion) decode(b *buffer) { + r.MSize = b.Read32() + r.Version = b.ReadString() +} + +// encode implements encoder.encode. +func (r *Rversion) encode(b *buffer) { + b.Write32(r.MSize) + b.WriteString(r.Version) +} + +// Type implements message.Type. +func (*Rversion) Type() MsgType { + return MsgRversion +} + +// String implements fmt.Stringer. +func (r *Rversion) String() string { + return fmt.Sprintf("Rversion{MSize: %d, Version: %s}", r.MSize, r.Version) +} + +// Tflush is a flush request. +type Tflush struct { + // OldTag is the tag to wait on. + OldTag Tag +} + +// decode implements encoder.decode. +func (t *Tflush) decode(b *buffer) { + t.OldTag = b.ReadTag() +} + +// encode implements encoder.encode. +func (t *Tflush) encode(b *buffer) { + b.WriteTag(t.OldTag) +} + +// Type implements message.Type. +func (*Tflush) Type() MsgType { + return MsgTflush +} + +// String implements fmt.Stringer. +func (t *Tflush) String() string { + return fmt.Sprintf("Tflush{OldTag: %d}", t.OldTag) +} + +// Rflush is a flush response. +type Rflush struct { +} + +// decode implements encoder.decode. +func (*Rflush) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rflush) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rflush) Type() MsgType { + return MsgRflush +} + +// String implements fmt.Stringer. +func (r *Rflush) String() string { + return "RFlush{}" +} + +// Twalk is a walk request. +type Twalk struct { + // FID is the FID to be walked. + FID FID + + // NewFID is the resulting FID. + NewFID FID + + // Names are the set of names to be walked. + Names []string +} + +// decode implements encoder.decode. +func (t *Twalk) decode(b *buffer) { + t.FID = b.ReadFID() + t.NewFID = b.ReadFID() + n := b.Read16() + t.Names = t.Names[:0] + for i := 0; i < int(n); i++ { + t.Names = append(t.Names, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (t *Twalk) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteFID(t.NewFID) + b.Write16(uint16(len(t.Names))) + for _, name := range t.Names { + b.WriteString(name) + } +} + +// Type implements message.Type. +func (*Twalk) Type() MsgType { + return MsgTwalk +} + +// String implements fmt.Stringer. +func (t *Twalk) String() string { + return fmt.Sprintf("Twalk{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names) +} + +// Rwalk is a walk response. +type Rwalk struct { + // QIDs are the set of QIDs returned. + QIDs []QID +} + +// decode implements encoder.decode. +func (r *Rwalk) decode(b *buffer) { + n := b.Read16() + r.QIDs = r.QIDs[:0] + for i := 0; i < int(n); i++ { + var q QID + q.decode(b) + r.QIDs = append(r.QIDs, q) + } +} + +// encode implements encoder.encode. +func (r *Rwalk) encode(b *buffer) { + b.Write16(uint16(len(r.QIDs))) + for _, q := range r.QIDs { + q.encode(b) + } +} + +// Type implements message.Type. +func (*Rwalk) Type() MsgType { + return MsgRwalk +} + +// String implements fmt.Stringer. +func (r *Rwalk) String() string { + return fmt.Sprintf("Rwalk{QIDs: %v}", r.QIDs) +} + +// Tclunk is a close request. +type Tclunk struct { + // FID is the FID to be closed. + FID FID +} + +// decode implements encoder.decode. +func (t *Tclunk) decode(b *buffer) { + t.FID = b.ReadFID() +} + +// encode implements encoder.encode. +func (t *Tclunk) encode(b *buffer) { + b.WriteFID(t.FID) +} + +// Type implements message.Type. +func (*Tclunk) Type() MsgType { + return MsgTclunk +} + +// String implements fmt.Stringer. +func (t *Tclunk) String() string { + return fmt.Sprintf("Tclunk{FID: %d}", t.FID) +} + +// Rclunk is a close response. +type Rclunk struct { +} + +// decode implements encoder.decode. +func (*Rclunk) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rclunk) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rclunk) Type() MsgType { + return MsgRclunk +} + +// String implements fmt.Stringer. +func (r *Rclunk) String() string { + return "Rclunk{}" +} + +// Tremove is a remove request. +// +// This will eventually be replaced by Tunlinkat. +type Tremove struct { + // FID is the FID to be removed. + FID FID +} + +// decode implements encoder.decode. +func (t *Tremove) decode(b *buffer) { + t.FID = b.ReadFID() +} + +// encode implements encoder.encode. +func (t *Tremove) encode(b *buffer) { + b.WriteFID(t.FID) +} + +// Type implements message.Type. +func (*Tremove) Type() MsgType { + return MsgTremove +} + +// String implements fmt.Stringer. +func (t *Tremove) String() string { + return fmt.Sprintf("Tremove{FID: %d}", t.FID) +} + +// Rremove is a remove response. +type Rremove struct { +} + +// decode implements encoder.decode. +func (*Rremove) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rremove) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rremove) Type() MsgType { + return MsgRremove +} + +// String implements fmt.Stringer. +func (r *Rremove) String() string { + return "Rremove{}" +} + +// Rlerror is an error response. +// +// Note that this replaces the error code used in 9p. +type Rlerror struct { + Error uint32 +} + +// decode implements encoder.decode. +func (r *Rlerror) decode(b *buffer) { + r.Error = b.Read32() +} + +// encode implements encoder.encode. +func (r *Rlerror) encode(b *buffer) { + b.Write32(r.Error) +} + +// Type implements message.Type. +func (*Rlerror) Type() MsgType { + return MsgRlerror +} + +// String implements fmt.Stringer. +func (r *Rlerror) String() string { + return fmt.Sprintf("Rlerror{Error: %d}", r.Error) +} + +// Tauth is an authentication request. +type Tauth struct { + // AuthenticationFID is the FID to attach the authentication result. + AuthenticationFID FID + + // UserName is the user to attach. + UserName string + + // AttachName is the attach name. + AttachName string + + // UserID is the numeric identifier for UserName. + UID UID +} + +// decode implements encoder.decode. +func (t *Tauth) decode(b *buffer) { + t.AuthenticationFID = b.ReadFID() + t.UserName = b.ReadString() + t.AttachName = b.ReadString() + t.UID = b.ReadUID() +} + +// encode implements encoder.encode. +func (t *Tauth) encode(b *buffer) { + b.WriteFID(t.AuthenticationFID) + b.WriteString(t.UserName) + b.WriteString(t.AttachName) + b.WriteUID(t.UID) +} + +// Type implements message.Type. +func (*Tauth) Type() MsgType { + return MsgTauth +} + +// String implements fmt.Stringer. +func (t *Tauth) String() string { + return fmt.Sprintf("Tauth{AuthFID: %d, UserName: %s, AttachName: %s, UID: %d", t.AuthenticationFID, t.UserName, t.AttachName, t.UID) +} + +// Rauth is an authentication response. +// +// encode and decode are inherited directly from QID. +type Rauth struct { + QID +} + +// Type implements message.Type. +func (*Rauth) Type() MsgType { + return MsgRauth +} + +// String implements fmt.Stringer. +func (r *Rauth) String() string { + return fmt.Sprintf("Rauth{QID: %s}", r.QID) +} + +// Tattach is an attach request. +type Tattach struct { + // FID is the FID to be attached. + FID FID + + // Auth is the embedded authentication request. + // + // See client.Attach for information regarding authentication. + Auth Tauth +} + +// decode implements encoder.decode. +func (t *Tattach) decode(b *buffer) { + t.FID = b.ReadFID() + t.Auth.decode(b) +} + +// encode implements encoder.encode. +func (t *Tattach) encode(b *buffer) { + b.WriteFID(t.FID) + t.Auth.encode(b) +} + +// Type implements message.Type. +func (*Tattach) Type() MsgType { + return MsgTattach +} + +// String implements fmt.Stringer. +func (t *Tattach) String() string { + return fmt.Sprintf("Tattach{FID: %d, AuthFID: %d, UserName: %s, AttachName: %s, UID: %d}", t.FID, t.Auth.AuthenticationFID, t.Auth.UserName, t.Auth.AttachName, t.Auth.UID) +} + +// Rattach is an attach response. +type Rattach struct { + QID +} + +// Type implements message.Type. +func (*Rattach) Type() MsgType { + return MsgRattach +} + +// String implements fmt.Stringer. +func (r *Rattach) String() string { + return fmt.Sprintf("Rattach{QID: %s}", r.QID) +} + +// Tlopen is an open request. +type Tlopen struct { + // FID is the FID to be opened. + FID FID + + // Flags are the open flags. + Flags OpenFlags +} + +// decode implements encoder.decode. +func (t *Tlopen) decode(b *buffer) { + t.FID = b.ReadFID() + t.Flags = b.ReadOpenFlags() +} + +// encode implements encoder.encode. +func (t *Tlopen) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteOpenFlags(t.Flags) +} + +// Type implements message.Type. +func (*Tlopen) Type() MsgType { + return MsgTlopen +} + +// String implements fmt.Stringer. +func (t *Tlopen) String() string { + return fmt.Sprintf("Tlopen{FID: %d, Flags: %v}", t.FID, t.Flags) +} + +// Rlopen is a open response. +type Rlopen struct { + // QID is the file's QID. + QID QID + + // IoUnit is the recommended I/O unit. + IoUnit uint32 + + filePayload +} + +// decode implements encoder.decode. +func (r *Rlopen) decode(b *buffer) { + r.QID.decode(b) + r.IoUnit = b.Read32() +} + +// encode implements encoder.encode. +func (r *Rlopen) encode(b *buffer) { + r.QID.encode(b) + b.Write32(r.IoUnit) +} + +// Type implements message.Type. +func (*Rlopen) Type() MsgType { + return MsgRlopen +} + +// String implements fmt.Stringer. +func (r *Rlopen) String() string { + return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File) +} + +// Tlcreate is a create request. +type Tlcreate struct { + // FID is the parent FID. + // + // This becomes the new file. + FID FID + + // Name is the file name to create. + Name string + + // Mode is the open mode (O_RDWR, etc.). + // + // Note that flags like O_TRUNC are ignored, as is O_EXCL. All + // create operations are exclusive. + OpenFlags OpenFlags + + // Permissions is the set of permission bits. + Permissions FileMode + + // GID is the group ID to use for creating the file. + GID GID +} + +// decode implements encoder.decode. +func (t *Tlcreate) decode(b *buffer) { + t.FID = b.ReadFID() + t.Name = b.ReadString() + t.OpenFlags = b.ReadOpenFlags() + t.Permissions = b.ReadPermissions() + t.GID = b.ReadGID() +} + +// encode implements encoder.encode. +func (t *Tlcreate) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteString(t.Name) + b.WriteOpenFlags(t.OpenFlags) + b.WritePermissions(t.Permissions) + b.WriteGID(t.GID) +} + +// Type implements message.Type. +func (*Tlcreate) Type() MsgType { + return MsgTlcreate +} + +// String implements fmt.Stringer. +func (t *Tlcreate) String() string { + return fmt.Sprintf("Tlcreate{FID: %d, Name: %s, OpenFlags: %s, Permissions: 0o%o, GID: %d}", t.FID, t.Name, t.OpenFlags, t.Permissions, t.GID) +} + +// Rlcreate is a create response. +// +// The encode, decode, etc. methods are inherited from Rlopen. +type Rlcreate struct { + Rlopen +} + +// Type implements message.Type. +func (*Rlcreate) Type() MsgType { + return MsgRlcreate +} + +// String implements fmt.Stringer. +func (r *Rlcreate) String() string { + return fmt.Sprintf("Rlcreate{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File) +} + +// Tsymlink is a symlink request. +type Tsymlink struct { + // Directory is the directory FID. + Directory FID + + // Name is the new in the directory. + Name string + + // Target is the symlink target. + Target string + + // GID is the owning group. + GID GID +} + +// decode implements encoder.decode. +func (t *Tsymlink) decode(b *buffer) { + t.Directory = b.ReadFID() + t.Name = b.ReadString() + t.Target = b.ReadString() + t.GID = b.ReadGID() +} + +// encode implements encoder.encode. +func (t *Tsymlink) encode(b *buffer) { + b.WriteFID(t.Directory) + b.WriteString(t.Name) + b.WriteString(t.Target) + b.WriteGID(t.GID) +} + +// Type implements message.Type. +func (*Tsymlink) Type() MsgType { + return MsgTsymlink +} + +// String implements fmt.Stringer. +func (t *Tsymlink) String() string { + return fmt.Sprintf("Tsymlink{DirectoryFID: %d, Name: %s, Target: %s, GID: %d}", t.Directory, t.Name, t.Target, t.GID) +} + +// Rsymlink is a symlink response. +type Rsymlink struct { + // QID is the new symlink's QID. + QID QID +} + +// decode implements encoder.decode. +func (r *Rsymlink) decode(b *buffer) { + r.QID.decode(b) +} + +// encode implements encoder.encode. +func (r *Rsymlink) encode(b *buffer) { + r.QID.encode(b) +} + +// Type implements message.Type. +func (*Rsymlink) Type() MsgType { + return MsgRsymlink +} + +// String implements fmt.Stringer. +func (r *Rsymlink) String() string { + return fmt.Sprintf("Rsymlink{QID: %s}", r.QID) +} + +// Tlink is a link request. +type Tlink struct { + // Directory is the directory to contain the link. + Directory FID + + // FID is the target. + Target FID + + // Name is the new source name. + Name string +} + +// decode implements encoder.decode. +func (t *Tlink) decode(b *buffer) { + t.Directory = b.ReadFID() + t.Target = b.ReadFID() + t.Name = b.ReadString() +} + +// encode implements encoder.encode. +func (t *Tlink) encode(b *buffer) { + b.WriteFID(t.Directory) + b.WriteFID(t.Target) + b.WriteString(t.Name) +} + +// Type implements message.Type. +func (*Tlink) Type() MsgType { + return MsgTlink +} + +// String implements fmt.Stringer. +func (t *Tlink) String() string { + return fmt.Sprintf("Tlink{DirectoryFID: %d, TargetFID: %d, Name: %s}", t.Directory, t.Target, t.Name) +} + +// Rlink is a link response. +type Rlink struct { +} + +// Type implements message.Type. +func (*Rlink) Type() MsgType { + return MsgRlink +} + +// decode implements encoder.decode. +func (*Rlink) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rlink) encode(*buffer) { +} + +// String implements fmt.Stringer. +func (r *Rlink) String() string { + return "Rlink{}" +} + +// Trenameat is a rename request. +type Trenameat struct { + // OldDirectory is the source directory. + OldDirectory FID + + // OldName is the source file name. + OldName string + + // NewDirectory is the target directory. + NewDirectory FID + + // NewName is the new file name. + NewName string +} + +// decode implements encoder.decode. +func (t *Trenameat) decode(b *buffer) { + t.OldDirectory = b.ReadFID() + t.OldName = b.ReadString() + t.NewDirectory = b.ReadFID() + t.NewName = b.ReadString() +} + +// encode implements encoder.encode. +func (t *Trenameat) encode(b *buffer) { + b.WriteFID(t.OldDirectory) + b.WriteString(t.OldName) + b.WriteFID(t.NewDirectory) + b.WriteString(t.NewName) +} + +// Type implements message.Type. +func (*Trenameat) Type() MsgType { + return MsgTrenameat +} + +// String implements fmt.Stringer. +func (t *Trenameat) String() string { + return fmt.Sprintf("TrenameAt{OldDirectoryFID: %d, OldName: %s, NewDirectoryFID: %d, NewName: %s}", t.OldDirectory, t.OldName, t.NewDirectory, t.NewName) +} + +// Rrenameat is a rename response. +type Rrenameat struct { +} + +// decode implements encoder.decode. +func (*Rrenameat) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rrenameat) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rrenameat) Type() MsgType { + return MsgRrenameat +} + +// String implements fmt.Stringer. +func (r *Rrenameat) String() string { + return "Rrenameat{}" +} + +// Tunlinkat is an unlink request. +type Tunlinkat struct { + // Directory is the originating directory. + Directory FID + + // Name is the name of the entry to unlink. + Name string + + // Flags are extra flags (e.g. O_DIRECTORY). These are not interpreted by p9. + Flags uint32 +} + +// decode implements encoder.decode. +func (t *Tunlinkat) decode(b *buffer) { + t.Directory = b.ReadFID() + t.Name = b.ReadString() + t.Flags = b.Read32() +} + +// encode implements encoder.encode. +func (t *Tunlinkat) encode(b *buffer) { + b.WriteFID(t.Directory) + b.WriteString(t.Name) + b.Write32(t.Flags) +} + +// Type implements message.Type. +func (*Tunlinkat) Type() MsgType { + return MsgTunlinkat +} + +// String implements fmt.Stringer. +func (t *Tunlinkat) String() string { + return fmt.Sprintf("Tunlinkat{DirectoryFID: %d, Name: %s, Flags: 0x%X}", t.Directory, t.Name, t.Flags) +} + +// Runlinkat is an unlink response. +type Runlinkat struct { +} + +// decode implements encoder.decode. +func (*Runlinkat) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Runlinkat) encode(*buffer) { +} + +// Type implements message.Type. +func (*Runlinkat) Type() MsgType { + return MsgRunlinkat +} + +// String implements fmt.Stringer. +func (r *Runlinkat) String() string { + return "Runlinkat{}" +} + +// Trename is a rename request. +// +// Note that this generally isn't used anymore, and ideally all rename calls +// should Trenameat below. +type Trename struct { + // FID is the FID to rename. + FID FID + + // Directory is the target directory. + Directory FID + + // Name is the new file name. + Name string +} + +// decode implements encoder.decode. +func (t *Trename) decode(b *buffer) { + t.FID = b.ReadFID() + t.Directory = b.ReadFID() + t.Name = b.ReadString() +} + +// encode implements encoder.encode. +func (t *Trename) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteFID(t.Directory) + b.WriteString(t.Name) +} + +// Type implements message.Type. +func (*Trename) Type() MsgType { + return MsgTrename +} + +// String implements fmt.Stringer. +func (t *Trename) String() string { + return fmt.Sprintf("Trename{FID: %d, DirectoryFID: %d, Name: %s}", t.FID, t.Directory, t.Name) +} + +// Rrename is a rename response. +type Rrename struct { +} + +// decode implements encoder.decode. +func (*Rrename) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rrename) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rrename) Type() MsgType { + return MsgRrename +} + +// String implements fmt.Stringer. +func (r *Rrename) String() string { + return "Rrename{}" +} + +// Treadlink is a readlink request. +type Treadlink struct { + // FID is the symlink. + FID FID +} + +// decode implements encoder.decode. +func (t *Treadlink) decode(b *buffer) { + t.FID = b.ReadFID() +} + +// encode implements encoder.encode. +func (t *Treadlink) encode(b *buffer) { + b.WriteFID(t.FID) +} + +// Type implements message.Type. +func (*Treadlink) Type() MsgType { + return MsgTreadlink +} + +// String implements fmt.Stringer. +func (t *Treadlink) String() string { + return fmt.Sprintf("Treadlink{FID: %d}", t.FID) +} + +// Rreadlink is a readlink response. +type Rreadlink struct { + // Target is the symlink target. + Target string +} + +// decode implements encoder.decode. +func (r *Rreadlink) decode(b *buffer) { + r.Target = b.ReadString() +} + +// encode implements encoder.encode. +func (r *Rreadlink) encode(b *buffer) { + b.WriteString(r.Target) +} + +// Type implements message.Type. +func (*Rreadlink) Type() MsgType { + return MsgRreadlink +} + +// String implements fmt.Stringer. +func (r *Rreadlink) String() string { + return fmt.Sprintf("Rreadlink{Target: %s}", r.Target) +} + +// Tread is a read request. +type Tread struct { + // FID is the FID to read. + FID FID + + // Offset indicates the file offset. + Offset uint64 + + // Count indicates the number of bytes to read. + Count uint32 +} + +// decode implements encoder.decode. +func (t *Tread) decode(b *buffer) { + t.FID = b.ReadFID() + t.Offset = b.Read64() + t.Count = b.Read32() +} + +// encode implements encoder.encode. +func (t *Tread) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write64(t.Offset) + b.Write32(t.Count) +} + +// Type implements message.Type. +func (*Tread) Type() MsgType { + return MsgTread +} + +// String implements fmt.Stringer. +func (t *Tread) String() string { + return fmt.Sprintf("Tread{FID: %d, Offset: %d, Count: %d}", t.FID, t.Offset, t.Count) +} + +// Rread is the response for a Tread. +type Rread struct { + // Data is the resulting data. + Data []byte +} + +// decode implements encoder.decode. +// +// Data is automatically decoded via Payload. +func (r *Rread) decode(b *buffer) { + count := b.Read32() + if count != uint32(len(r.Data)) { + b.markOverrun() + } +} + +// encode implements encoder.encode. +// +// Data is automatically encoded via Payload. +func (r *Rread) encode(b *buffer) { + b.Write32(uint32(len(r.Data))) +} + +// Type implements message.Type. +func (*Rread) Type() MsgType { + return MsgRread +} + +// FixedSize implements payloader.FixedSize. +func (*Rread) FixedSize() uint32 { + return 4 +} + +// Payload implements payloader.Payload. +func (r *Rread) Payload() []byte { + return r.Data +} + +// SetPayload implements payloader.SetPayload. +func (r *Rread) SetPayload(p []byte) { + r.Data = p +} + +// String implements fmt.Stringer. +func (r *Rread) String() string { + return fmt.Sprintf("Rread{len(Data): %d}", len(r.Data)) +} + +// Twrite is a write request. +type Twrite struct { + // FID is the FID to read. + FID FID + + // Offset indicates the file offset. + Offset uint64 + + // Data is the data to be written. + Data []byte +} + +// decode implements encoder.decode. +func (t *Twrite) decode(b *buffer) { + t.FID = b.ReadFID() + t.Offset = b.Read64() + count := b.Read32() + if count != uint32(len(t.Data)) { + b.markOverrun() + } +} + +// encode implements encoder.encode. +// +// This uses the buffer payload to avoid a copy. +func (t *Twrite) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write64(t.Offset) + b.Write32(uint32(len(t.Data))) +} + +// Type implements message.Type. +func (*Twrite) Type() MsgType { + return MsgTwrite +} + +// FixedSize implements payloader.FixedSize. +func (*Twrite) FixedSize() uint32 { + return 16 +} + +// Payload implements payloader.Payload. +func (t *Twrite) Payload() []byte { + return t.Data +} + +// SetPayload implements payloader.SetPayload. +func (t *Twrite) SetPayload(p []byte) { + t.Data = p +} + +// String implements fmt.Stringer. +func (t *Twrite) String() string { + return fmt.Sprintf("Twrite{FID: %v, Offset %d, len(Data): %d}", t.FID, t.Offset, len(t.Data)) +} + +// Rwrite is the response for a Twrite. +type Rwrite struct { + // Count indicates the number of bytes successfully written. + Count uint32 +} + +// decode implements encoder.decode. +func (r *Rwrite) decode(b *buffer) { + r.Count = b.Read32() +} + +// encode implements encoder.encode. +func (r *Rwrite) encode(b *buffer) { + b.Write32(r.Count) +} + +// Type implements message.Type. +func (*Rwrite) Type() MsgType { + return MsgRwrite +} + +// String implements fmt.Stringer. +func (r *Rwrite) String() string { + return fmt.Sprintf("Rwrite{Count: %d}", r.Count) +} + +// Tmknod is a mknod request. +type Tmknod struct { + // Directory is the parent directory. + Directory FID + + // Name is the device name. + Name string + + // Mode is the device mode and permissions. + Mode FileMode + + // Major is the device major number. + Major uint32 + + // Minor is the device minor number. + Minor uint32 + + // GID is the device GID. + GID GID +} + +// decode implements encoder.decode. +func (t *Tmknod) decode(b *buffer) { + t.Directory = b.ReadFID() + t.Name = b.ReadString() + t.Mode = b.ReadFileMode() + t.Major = b.Read32() + t.Minor = b.Read32() + t.GID = b.ReadGID() +} + +// encode implements encoder.encode. +func (t *Tmknod) encode(b *buffer) { + b.WriteFID(t.Directory) + b.WriteString(t.Name) + b.WriteFileMode(t.Mode) + b.Write32(t.Major) + b.Write32(t.Minor) + b.WriteGID(t.GID) +} + +// Type implements message.Type. +func (*Tmknod) Type() MsgType { + return MsgTmknod +} + +// String implements fmt.Stringer. +func (t *Tmknod) String() string { + return fmt.Sprintf("Tmknod{DirectoryFID: %d, Name: %s, Mode: 0o%o, Major: %d, Minor: %d, GID: %d}", t.Directory, t.Name, t.Mode, t.Major, t.Minor, t.GID) +} + +// Rmknod is a mknod response. +type Rmknod struct { + // QID is the resulting QID. + QID QID +} + +// decode implements encoder.decode. +func (r *Rmknod) decode(b *buffer) { + r.QID.decode(b) +} + +// encode implements encoder.encode. +func (r *Rmknod) encode(b *buffer) { + r.QID.encode(b) +} + +// Type implements message.Type. +func (*Rmknod) Type() MsgType { + return MsgRmknod +} + +// String implements fmt.Stringer. +func (r *Rmknod) String() string { + return fmt.Sprintf("Rmknod{QID: %s}", r.QID) +} + +// Tmkdir is a mkdir request. +type Tmkdir struct { + // Directory is the parent directory. + Directory FID + + // Name is the new directory name. + Name string + + // Permissions is the set of permission bits. + Permissions FileMode + + // GID is the owning group. + GID GID +} + +// decode implements encoder.decode. +func (t *Tmkdir) decode(b *buffer) { + t.Directory = b.ReadFID() + t.Name = b.ReadString() + t.Permissions = b.ReadPermissions() + t.GID = b.ReadGID() +} + +// encode implements encoder.encode. +func (t *Tmkdir) encode(b *buffer) { + b.WriteFID(t.Directory) + b.WriteString(t.Name) + b.WritePermissions(t.Permissions) + b.WriteGID(t.GID) +} + +// Type implements message.Type. +func (*Tmkdir) Type() MsgType { + return MsgTmkdir +} + +// String implements fmt.Stringer. +func (t *Tmkdir) String() string { + return fmt.Sprintf("Tmkdir{DirectoryFID: %d, Name: %s, Permissions: 0o%o, GID: %d}", t.Directory, t.Name, t.Permissions, t.GID) +} + +// Rmkdir is a mkdir response. +type Rmkdir struct { + // QID is the resulting QID. + QID QID +} + +// decode implements encoder.decode. +func (r *Rmkdir) decode(b *buffer) { + r.QID.decode(b) +} + +// encode implements encoder.encode. +func (r *Rmkdir) encode(b *buffer) { + r.QID.encode(b) +} + +// Type implements message.Type. +func (*Rmkdir) Type() MsgType { + return MsgRmkdir +} + +// String implements fmt.Stringer. +func (r *Rmkdir) String() string { + return fmt.Sprintf("Rmkdir{QID: %s}", r.QID) +} + +// Tgetattr is a getattr request. +type Tgetattr struct { + // FID is the FID to get attributes for. + FID FID + + // AttrMask is the set of attributes to get. + AttrMask AttrMask +} + +// decode implements encoder.decode. +func (t *Tgetattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.AttrMask.decode(b) +} + +// encode implements encoder.encode. +func (t *Tgetattr) encode(b *buffer) { + b.WriteFID(t.FID) + t.AttrMask.encode(b) +} + +// Type implements message.Type. +func (*Tgetattr) Type() MsgType { + return MsgTgetattr +} + +// String implements fmt.Stringer. +func (t *Tgetattr) String() string { + return fmt.Sprintf("Tgetattr{FID: %d, AttrMask: %s}", t.FID, t.AttrMask) +} + +// Rgetattr is a getattr response. +type Rgetattr struct { + // Valid indicates which fields are valid. + Valid AttrMask + + // QID is the QID for this file. + QID + + // Attr is the set of attributes. + Attr Attr +} + +// decode implements encoder.decode. +func (r *Rgetattr) decode(b *buffer) { + r.Valid.decode(b) + r.QID.decode(b) + r.Attr.decode(b) +} + +// encode implements encoder.encode. +func (r *Rgetattr) encode(b *buffer) { + r.Valid.encode(b) + r.QID.encode(b) + r.Attr.encode(b) +} + +// Type implements message.Type. +func (*Rgetattr) Type() MsgType { + return MsgRgetattr +} + +// String implements fmt.Stringer. +func (r *Rgetattr) String() string { + return fmt.Sprintf("Rgetattr{Valid: %v, QID: %s, Attr: %s}", r.Valid, r.QID, r.Attr) +} + +// Tsetattr is a setattr request. +type Tsetattr struct { + // FID is the FID to change. + FID FID + + // Valid is the set of bits which will be used. + Valid SetAttrMask + + // SetAttr is the set request. + SetAttr SetAttr +} + +// decode implements encoder.decode. +func (t *Tsetattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.Valid.decode(b) + t.SetAttr.decode(b) +} + +// encode implements encoder.encode. +func (t *Tsetattr) encode(b *buffer) { + b.WriteFID(t.FID) + t.Valid.encode(b) + t.SetAttr.encode(b) +} + +// Type implements message.Type. +func (*Tsetattr) Type() MsgType { + return MsgTsetattr +} + +// String implements fmt.Stringer. +func (t *Tsetattr) String() string { + return fmt.Sprintf("Tsetattr{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr) +} + +// Rsetattr is a setattr response. +type Rsetattr struct { +} + +// decode implements encoder.decode. +func (*Rsetattr) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rsetattr) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rsetattr) Type() MsgType { + return MsgRsetattr +} + +// String implements fmt.Stringer. +func (r *Rsetattr) String() string { + return "Rsetattr{}" +} + +// Tallocate is an allocate request. This is an extension to 9P protocol, not +// present in the 9P2000.L standard. +type Tallocate struct { + FID FID + Mode AllocateMode + Offset uint64 + Length uint64 +} + +// decode implements encoder.decode. +func (t *Tallocate) decode(b *buffer) { + t.FID = b.ReadFID() + t.Mode.decode(b) + t.Offset = b.Read64() + t.Length = b.Read64() +} + +// encode implements encoder.encode. +func (t *Tallocate) encode(b *buffer) { + b.WriteFID(t.FID) + t.Mode.encode(b) + b.Write64(t.Offset) + b.Write64(t.Length) +} + +// Type implements message.Type. +func (*Tallocate) Type() MsgType { + return MsgTallocate +} + +// String implements fmt.Stringer. +func (t *Tallocate) String() string { + return fmt.Sprintf("Tallocate{FID: %d, Offset: %d, Length: %d}", t.FID, t.Offset, t.Length) +} + +// Rallocate is an allocate response. +type Rallocate struct { +} + +// decode implements encoder.decode. +func (*Rallocate) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rallocate) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rallocate) Type() MsgType { + return MsgRallocate +} + +// String implements fmt.Stringer. +func (r *Rallocate) String() string { + return "Rallocate{}" +} + +// Tlistxattr is a listxattr request. +type Tlistxattr struct { + // FID refers to the file on which to list xattrs. + FID FID + + // Size is the buffer size for the xattr list. + Size uint64 +} + +// decode implements encoder.decode. +func (t *Tlistxattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.Size = b.Read64() +} + +// encode implements encoder.encode. +func (t *Tlistxattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write64(t.Size) +} + +// Type implements message.Type. +func (*Tlistxattr) Type() MsgType { + return MsgTlistxattr +} + +// String implements fmt.Stringer. +func (t *Tlistxattr) String() string { + return fmt.Sprintf("Tlistxattr{FID: %d, Size: %d}", t.FID, t.Size) +} + +// Rlistxattr is a listxattr response. +type Rlistxattr struct { + // Xattrs is a list of extended attribute names. + Xattrs []string +} + +// decode implements encoder.decode. +func (r *Rlistxattr) decode(b *buffer) { + n := b.Read16() + r.Xattrs = r.Xattrs[:0] + for i := 0; i < int(n); i++ { + r.Xattrs = append(r.Xattrs, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (r *Rlistxattr) encode(b *buffer) { + b.Write16(uint16(len(r.Xattrs))) + for _, x := range r.Xattrs { + b.WriteString(x) + } +} + +// Type implements message.Type. +func (*Rlistxattr) Type() MsgType { + return MsgRlistxattr +} + +// String implements fmt.Stringer. +func (r *Rlistxattr) String() string { + return fmt.Sprintf("Rlistxattr{Xattrs: %v}", r.Xattrs) +} + +// Txattrwalk walks extended attributes. +type Txattrwalk struct { + // FID is the FID to check for attributes. + FID FID + + // NewFID is the new FID associated with the attributes. + NewFID FID + + // Name is the attribute name. + Name string +} + +// decode implements encoder.decode. +func (t *Txattrwalk) decode(b *buffer) { + t.FID = b.ReadFID() + t.NewFID = b.ReadFID() + t.Name = b.ReadString() +} + +// encode implements encoder.encode. +func (t *Txattrwalk) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteFID(t.NewFID) + b.WriteString(t.Name) +} + +// Type implements message.Type. +func (*Txattrwalk) Type() MsgType { + return MsgTxattrwalk +} + +// String implements fmt.Stringer. +func (t *Txattrwalk) String() string { + return fmt.Sprintf("Txattrwalk{FID: %d, NewFID: %d, Name: %s}", t.FID, t.NewFID, t.Name) +} + +// Rxattrwalk is a xattrwalk response. +type Rxattrwalk struct { + // Size is the size of the extended attribute. + Size uint64 +} + +// decode implements encoder.decode. +func (r *Rxattrwalk) decode(b *buffer) { + r.Size = b.Read64() +} + +// encode implements encoder.encode. +func (r *Rxattrwalk) encode(b *buffer) { + b.Write64(r.Size) +} + +// Type implements message.Type. +func (*Rxattrwalk) Type() MsgType { + return MsgRxattrwalk +} + +// String implements fmt.Stringer. +func (r *Rxattrwalk) String() string { + return fmt.Sprintf("Rxattrwalk{Size: %d}", r.Size) +} + +// Txattrcreate prepare to set extended attributes. +type Txattrcreate struct { + // FID is input/output parameter, it identifies the file on which + // extended attributes will be set but after successful Rxattrcreate + // it is used to write the extended attribute value. + FID FID + + // Name is the attribute name. + Name string + + // Size of the attribute value. When the FID is clunked it has to match + // the number of bytes written to the FID. + AttrSize uint64 + + // Linux setxattr(2) flags. + Flags uint32 +} + +// decode implements encoder.decode. +func (t *Txattrcreate) decode(b *buffer) { + t.FID = b.ReadFID() + t.Name = b.ReadString() + t.AttrSize = b.Read64() + t.Flags = b.Read32() +} + +// encode implements encoder.encode. +func (t *Txattrcreate) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteString(t.Name) + b.Write64(t.AttrSize) + b.Write32(t.Flags) +} + +// Type implements message.Type. +func (*Txattrcreate) Type() MsgType { + return MsgTxattrcreate +} + +// String implements fmt.Stringer. +func (t *Txattrcreate) String() string { + return fmt.Sprintf("Txattrcreate{FID: %d, Name: %s, AttrSize: %d, Flags: %d}", t.FID, t.Name, t.AttrSize, t.Flags) +} + +// Rxattrcreate is a xattrcreate response. +type Rxattrcreate struct { +} + +// decode implements encoder.decode. +func (r *Rxattrcreate) decode(*buffer) { +} + +// encode implements encoder.encode. +func (r *Rxattrcreate) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rxattrcreate) Type() MsgType { + return MsgRxattrcreate +} + +// String implements fmt.Stringer. +func (r *Rxattrcreate) String() string { + return "Rxattrcreate{}" +} + +// Tgetxattr is a getxattr request. +type Tgetxattr struct { + // FID refers to the file for which to get xattrs. + FID FID + + // Name is the xattr to get. + Name string + + // Size is the buffer size for the xattr to get. + Size uint64 +} + +// decode implements encoder.decode. +func (t *Tgetxattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.Name = b.ReadString() + t.Size = b.Read64() +} + +// encode implements encoder.encode. +func (t *Tgetxattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteString(t.Name) + b.Write64(t.Size) +} + +// Type implements message.Type. +func (*Tgetxattr) Type() MsgType { + return MsgTgetxattr +} + +// String implements fmt.Stringer. +func (t *Tgetxattr) String() string { + return fmt.Sprintf("Tgetxattr{FID: %d, Name: %s, Size: %d}", t.FID, t.Name, t.Size) +} + +// Rgetxattr is a getxattr response. +type Rgetxattr struct { + // Value is the extended attribute value. + Value string +} + +// decode implements encoder.decode. +func (r *Rgetxattr) decode(b *buffer) { + r.Value = b.ReadString() +} + +// encode implements encoder.encode. +func (r *Rgetxattr) encode(b *buffer) { + b.WriteString(r.Value) +} + +// Type implements message.Type. +func (*Rgetxattr) Type() MsgType { + return MsgRgetxattr +} + +// String implements fmt.Stringer. +func (r *Rgetxattr) String() string { + return fmt.Sprintf("Rgetxattr{Value: %s}", r.Value) +} + +// Tsetxattr sets extended attributes. +type Tsetxattr struct { + // FID refers to the file on which to set xattrs. + FID FID + + // Name is the attribute name. + Name string + + // Value is the attribute value. + Value string + + // Linux setxattr(2) flags. + Flags uint32 +} + +// decode implements encoder.decode. +func (t *Tsetxattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.Name = b.ReadString() + t.Value = b.ReadString() + t.Flags = b.Read32() +} + +// encode implements encoder.encode. +func (t *Tsetxattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteString(t.Name) + b.WriteString(t.Value) + b.Write32(t.Flags) +} + +// Type implements message.Type. +func (*Tsetxattr) Type() MsgType { + return MsgTsetxattr +} + +// String implements fmt.Stringer. +func (t *Tsetxattr) String() string { + return fmt.Sprintf("Tsetxattr{FID: %d, Name: %s, Value: %s, Flags: %d}", t.FID, t.Name, t.Value, t.Flags) +} + +// Rsetxattr is a setxattr response. +type Rsetxattr struct { +} + +// decode implements encoder.decode. +func (r *Rsetxattr) decode(*buffer) { +} + +// encode implements encoder.encode. +func (r *Rsetxattr) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rsetxattr) Type() MsgType { + return MsgRsetxattr +} + +// String implements fmt.Stringer. +func (r *Rsetxattr) String() string { + return "Rsetxattr{}" +} + +// Tremovexattr is a removexattr request. +type Tremovexattr struct { + // FID refers to the file on which to set xattrs. + FID FID + + // Name is the attribute name. + Name string +} + +// decode implements encoder.decode. +func (t *Tremovexattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.Name = b.ReadString() +} + +// encode implements encoder.encode. +func (t *Tremovexattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteString(t.Name) +} + +// Type implements message.Type. +func (*Tremovexattr) Type() MsgType { + return MsgTremovexattr +} + +// String implements fmt.Stringer. +func (t *Tremovexattr) String() string { + return fmt.Sprintf("Tremovexattr{FID: %d, Name: %s}", t.FID, t.Name) +} + +// Rremovexattr is a removexattr response. +type Rremovexattr struct { +} + +// decode implements encoder.decode. +func (r *Rremovexattr) decode(*buffer) { +} + +// encode implements encoder.encode. +func (r *Rremovexattr) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rremovexattr) Type() MsgType { + return MsgRremovexattr +} + +// String implements fmt.Stringer. +func (r *Rremovexattr) String() string { + return "Rremovexattr{}" +} + +// Treaddir is a readdir request. +type Treaddir struct { + // Directory is the directory FID to read. + Directory FID + + // Offset is the offset to read at. + Offset uint64 + + // Count is the number of bytes to read. + Count uint32 +} + +// decode implements encoder.decode. +func (t *Treaddir) decode(b *buffer) { + t.Directory = b.ReadFID() + t.Offset = b.Read64() + t.Count = b.Read32() +} + +// encode implements encoder.encode. +func (t *Treaddir) encode(b *buffer) { + b.WriteFID(t.Directory) + b.Write64(t.Offset) + b.Write32(t.Count) +} + +// Type implements message.Type. +func (*Treaddir) Type() MsgType { + return MsgTreaddir +} + +// String implements fmt.Stringer. +func (t *Treaddir) String() string { + return fmt.Sprintf("Treaddir{DirectoryFID: %d, Offset: %d, Count: %d}", t.Directory, t.Offset, t.Count) +} + +// Rreaddir is a readdir response. +type Rreaddir struct { + // Count is the byte limit. + // + // This should always be set from the Treaddir request. + Count uint32 + + // Entries are the resulting entries. + // + // This may be constructed in decode. + Entries []Dirent + + // payload is the encoded payload. + // + // This is constructed by encode. + payload []byte +} + +// decode implements encoder.decode. +func (r *Rreaddir) decode(b *buffer) { + r.Count = b.Read32() + entriesBuf := buffer{data: r.payload} + r.Entries = r.Entries[:0] + for { + var d Dirent + d.decode(&entriesBuf) + if entriesBuf.isOverrun() { + // Couldn't decode a complete entry. + break + } + r.Entries = append(r.Entries, d) + } +} + +// encode implements encoder.encode. +func (r *Rreaddir) encode(b *buffer) { + entriesBuf := buffer{} + payloadSize := 0 + for _, d := range r.Entries { + d.encode(&entriesBuf) + if len(entriesBuf.data) > int(r.Count) { + break + } + payloadSize = len(entriesBuf.data) + } + r.Count = uint32(payloadSize) + r.payload = entriesBuf.data[:payloadSize] + b.Write32(r.Count) +} + +// Type implements message.Type. +func (*Rreaddir) Type() MsgType { + return MsgRreaddir +} + +// FixedSize implements payloader.FixedSize. +func (*Rreaddir) FixedSize() uint32 { + return 4 +} + +// Payload implements payloader.Payload. +func (r *Rreaddir) Payload() []byte { + return r.payload +} + +// SetPayload implements payloader.SetPayload. +func (r *Rreaddir) SetPayload(p []byte) { + r.payload = p +} + +// String implements fmt.Stringer. +func (r *Rreaddir) String() string { + return fmt.Sprintf("Rreaddir{Count: %d, Entries: %s}", r.Count, r.Entries) +} + +// Tfsync is an fsync request. +type Tfsync struct { + // FID is the fid to sync. + FID FID +} + +// decode implements encoder.decode. +func (t *Tfsync) decode(b *buffer) { + t.FID = b.ReadFID() +} + +// encode implements encoder.encode. +func (t *Tfsync) encode(b *buffer) { + b.WriteFID(t.FID) +} + +// Type implements message.Type. +func (*Tfsync) Type() MsgType { + return MsgTfsync +} + +// String implements fmt.Stringer. +func (t *Tfsync) String() string { + return fmt.Sprintf("Tfsync{FID: %d}", t.FID) +} + +// Rfsync is an fsync response. +type Rfsync struct { +} + +// decode implements encoder.decode. +func (*Rfsync) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rfsync) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rfsync) Type() MsgType { + return MsgRfsync +} + +// String implements fmt.Stringer. +func (r *Rfsync) String() string { + return "Rfsync{}" +} + +// Tstatfs is a stat request. +type Tstatfs struct { + // FID is the root. + FID FID +} + +// decode implements encoder.decode. +func (t *Tstatfs) decode(b *buffer) { + t.FID = b.ReadFID() +} + +// encode implements encoder.encode. +func (t *Tstatfs) encode(b *buffer) { + b.WriteFID(t.FID) +} + +// Type implements message.Type. +func (*Tstatfs) Type() MsgType { + return MsgTstatfs +} + +// String implements fmt.Stringer. +func (t *Tstatfs) String() string { + return fmt.Sprintf("Tstatfs{FID: %d}", t.FID) +} + +// Rstatfs is the response for a Tstatfs. +type Rstatfs struct { + // FSStat is the stat result. + FSStat FSStat +} + +// decode implements encoder.decode. +func (r *Rstatfs) decode(b *buffer) { + r.FSStat.decode(b) +} + +// encode implements encoder.encode. +func (r *Rstatfs) encode(b *buffer) { + r.FSStat.encode(b) +} + +// Type implements message.Type. +func (*Rstatfs) Type() MsgType { + return MsgRstatfs +} + +// String implements fmt.Stringer. +func (r *Rstatfs) String() string { + return fmt.Sprintf("Rstatfs{FSStat: %v}", r.FSStat) +} + +// Tflushf is a flush file request, not to be confused with Tflush. +type Tflushf struct { + // FID is the FID to be flushed. + FID FID +} + +// decode implements encoder.decode. +func (t *Tflushf) decode(b *buffer) { + t.FID = b.ReadFID() +} + +// encode implements encoder.encode. +func (t *Tflushf) encode(b *buffer) { + b.WriteFID(t.FID) +} + +// Type implements message.Type. +func (*Tflushf) Type() MsgType { + return MsgTflushf +} + +// String implements fmt.Stringer. +func (t *Tflushf) String() string { + return fmt.Sprintf("Tflushf{FID: %d}", t.FID) +} + +// Rflushf is a flush file response. +type Rflushf struct { +} + +// decode implements encoder.decode. +func (*Rflushf) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rflushf) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rflushf) Type() MsgType { + return MsgRflushf +} + +// String implements fmt.Stringer. +func (*Rflushf) String() string { + return "Rflushf{}" +} + +// Twalkgetattr is a walk request. +type Twalkgetattr struct { + // FID is the FID to be walked. + FID FID + + // NewFID is the resulting FID. + NewFID FID + + // Names are the set of names to be walked. + Names []string +} + +// decode implements encoder.decode. +func (t *Twalkgetattr) decode(b *buffer) { + t.FID = b.ReadFID() + t.NewFID = b.ReadFID() + n := b.Read16() + t.Names = t.Names[:0] + for i := 0; i < int(n); i++ { + t.Names = append(t.Names, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (t *Twalkgetattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteFID(t.NewFID) + b.Write16(uint16(len(t.Names))) + for _, name := range t.Names { + b.WriteString(name) + } +} + +// Type implements message.Type. +func (*Twalkgetattr) Type() MsgType { + return MsgTwalkgetattr +} + +// String implements fmt.Stringer. +func (t *Twalkgetattr) String() string { + return fmt.Sprintf("Twalkgetattr{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names) +} + +// Rwalkgetattr is a walk response. +type Rwalkgetattr struct { + // Valid indicates which fields are valid in the Attr below. + Valid AttrMask + + // Attr is the set of attributes for the last QID (the file walked to). + Attr Attr + + // QIDs are the set of QIDs returned. + QIDs []QID +} + +// decode implements encoder.decode. +func (r *Rwalkgetattr) decode(b *buffer) { + r.Valid.decode(b) + r.Attr.decode(b) + n := b.Read16() + r.QIDs = r.QIDs[:0] + for i := 0; i < int(n); i++ { + var q QID + q.decode(b) + r.QIDs = append(r.QIDs, q) + } +} + +// encode implements encoder.encode. +func (r *Rwalkgetattr) encode(b *buffer) { + r.Valid.encode(b) + r.Attr.encode(b) + b.Write16(uint16(len(r.QIDs))) + for _, q := range r.QIDs { + q.encode(b) + } +} + +// Type implements message.Type. +func (*Rwalkgetattr) Type() MsgType { + return MsgRwalkgetattr +} + +// String implements fmt.Stringer. +func (r *Rwalkgetattr) String() string { + return fmt.Sprintf("Rwalkgetattr{Valid: %s, Attr: %s, QIDs: %v}", r.Valid, r.Attr, r.QIDs) +} + +// Tucreate is a Tlcreate message that includes a UID. +type Tucreate struct { + Tlcreate + + // UID is the UID to use as the effective UID in creation messages. + UID UID +} + +// decode implements encoder.decode. +func (t *Tucreate) decode(b *buffer) { + t.Tlcreate.decode(b) + t.UID = b.ReadUID() +} + +// encode implements encoder.encode. +func (t *Tucreate) encode(b *buffer) { + t.Tlcreate.encode(b) + b.WriteUID(t.UID) +} + +// Type implements message.Type. +func (t *Tucreate) Type() MsgType { + return MsgTucreate +} + +// String implements fmt.Stringer. +func (t *Tucreate) String() string { + return fmt.Sprintf("Tucreate{Tlcreate: %v, UID: %d}", &t.Tlcreate, t.UID) +} + +// Rucreate is a file creation response. +type Rucreate struct { + Rlcreate +} + +// Type implements message.Type. +func (*Rucreate) Type() MsgType { + return MsgRucreate +} + +// String implements fmt.Stringer. +func (r *Rucreate) String() string { + return fmt.Sprintf("Rucreate{%v}", &r.Rlcreate) +} + +// Tumkdir is a Tmkdir message that includes a UID. +type Tumkdir struct { + Tmkdir + + // UID is the UID to use as the effective UID in creation messages. + UID UID +} + +// decode implements encoder.decode. +func (t *Tumkdir) decode(b *buffer) { + t.Tmkdir.decode(b) + t.UID = b.ReadUID() +} + +// encode implements encoder.encode. +func (t *Tumkdir) encode(b *buffer) { + t.Tmkdir.encode(b) + b.WriteUID(t.UID) +} + +// Type implements message.Type. +func (t *Tumkdir) Type() MsgType { + return MsgTumkdir +} + +// String implements fmt.Stringer. +func (t *Tumkdir) String() string { + return fmt.Sprintf("Tumkdir{Tmkdir: %v, UID: %d}", &t.Tmkdir, t.UID) +} + +// Rumkdir is a umkdir response. +type Rumkdir struct { + Rmkdir +} + +// Type implements message.Type. +func (*Rumkdir) Type() MsgType { + return MsgRumkdir +} + +// String implements fmt.Stringer. +func (r *Rumkdir) String() string { + return fmt.Sprintf("Rumkdir{%v}", &r.Rmkdir) +} + +// Tumknod is a Tmknod message that includes a UID. +type Tumknod struct { + Tmknod + + // UID is the UID to use as the effective UID in creation messages. + UID UID +} + +// decode implements encoder.decode. +func (t *Tumknod) decode(b *buffer) { + t.Tmknod.decode(b) + t.UID = b.ReadUID() +} + +// encode implements encoder.encode. +func (t *Tumknod) encode(b *buffer) { + t.Tmknod.encode(b) + b.WriteUID(t.UID) +} + +// Type implements message.Type. +func (t *Tumknod) Type() MsgType { + return MsgTumknod +} + +// String implements fmt.Stringer. +func (t *Tumknod) String() string { + return fmt.Sprintf("Tumknod{Tmknod: %v, UID: %d}", &t.Tmknod, t.UID) +} + +// Rumknod is a umknod response. +type Rumknod struct { + Rmknod +} + +// Type implements message.Type. +func (*Rumknod) Type() MsgType { + return MsgRumknod +} + +// String implements fmt.Stringer. +func (r *Rumknod) String() string { + return fmt.Sprintf("Rumknod{%v}", &r.Rmknod) +} + +// Tusymlink is a Tsymlink message that includes a UID. +type Tusymlink struct { + Tsymlink + + // UID is the UID to use as the effective UID in creation messages. + UID UID +} + +// decode implements encoder.decode. +func (t *Tusymlink) decode(b *buffer) { + t.Tsymlink.decode(b) + t.UID = b.ReadUID() +} + +// encode implements encoder.encode. +func (t *Tusymlink) encode(b *buffer) { + t.Tsymlink.encode(b) + b.WriteUID(t.UID) +} + +// Type implements message.Type. +func (t *Tusymlink) Type() MsgType { + return MsgTusymlink +} + +// String implements fmt.Stringer. +func (t *Tusymlink) String() string { + return fmt.Sprintf("Tusymlink{Tsymlink: %v, UID: %d}", &t.Tsymlink, t.UID) +} + +// Rusymlink is a usymlink response. +type Rusymlink struct { + Rsymlink +} + +// Type implements message.Type. +func (*Rusymlink) Type() MsgType { + return MsgRusymlink +} + +// String implements fmt.Stringer. +func (r *Rusymlink) String() string { + return fmt.Sprintf("Rusymlink{%v}", &r.Rsymlink) +} + +// Tlconnect is a connect request. +type Tlconnect struct { + // FID is the FID to be connected. + FID FID + + // Flags are the connect flags. + Flags ConnectFlags +} + +// decode implements encoder.decode. +func (t *Tlconnect) decode(b *buffer) { + t.FID = b.ReadFID() + t.Flags = b.ReadConnectFlags() +} + +// encode implements encoder.encode. +func (t *Tlconnect) encode(b *buffer) { + b.WriteFID(t.FID) + b.WriteConnectFlags(t.Flags) +} + +// Type implements message.Type. +func (*Tlconnect) Type() MsgType { + return MsgTlconnect +} + +// String implements fmt.Stringer. +func (t *Tlconnect) String() string { + return fmt.Sprintf("Tlconnect{FID: %d, Flags: %v}", t.FID, t.Flags) +} + +// Rlconnect is a connect response. +type Rlconnect struct { + filePayload +} + +// decode implements encoder.decode. +func (r *Rlconnect) decode(*buffer) {} + +// encode implements encoder.encode. +func (r *Rlconnect) encode(*buffer) {} + +// Type implements message.Type. +func (*Rlconnect) Type() MsgType { + return MsgRlconnect +} + +// String implements fmt.Stringer. +func (r *Rlconnect) String() string { + return fmt.Sprintf("Rlconnect{File: %v}", r.File) +} + +// Tchannel creates a new channel. +type Tchannel struct { + // ID is the channel ID. + ID uint32 + + // Control is 0 if the Rchannel response should provide the flipcall + // component of the channel, and 1 if the Rchannel response should + // provide the fdchannel component of the channel. + Control uint32 +} + +// decode implements encoder.decode. +func (t *Tchannel) decode(b *buffer) { + t.ID = b.Read32() + t.Control = b.Read32() +} + +// encode implements encoder.encode. +func (t *Tchannel) encode(b *buffer) { + b.Write32(t.ID) + b.Write32(t.Control) +} + +// Type implements message.Type. +func (*Tchannel) Type() MsgType { + return MsgTchannel +} + +// String implements fmt.Stringer. +func (t *Tchannel) String() string { + return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control) +} + +// Rchannel is the channel response. +type Rchannel struct { + Offset uint64 + Length uint64 + filePayload +} + +// decode implements encoder.decode. +func (r *Rchannel) decode(b *buffer) { + r.Offset = b.Read64() + r.Length = b.Read64() +} + +// encode implements encoder.encode. +func (r *Rchannel) encode(b *buffer) { + b.Write64(r.Offset) + b.Write64(r.Length) +} + +// Type implements message.Type. +func (*Rchannel) Type() MsgType { + return MsgRchannel +} + +// String implements fmt.Stringer. +func (r *Rchannel) String() string { + return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) +} + +const maxCacheSize = 3 + +// msgFactory is used to reduce allocations by caching messages for reuse. +type msgFactory struct { + create func() message + cache chan message +} + +// msgRegistry indexes all message factories by type. +var msgRegistry registry + +type registry struct { + factories [math.MaxUint8]msgFactory + + // largestFixedSize is computed so that given some message size M, you can + // compute the maximum payload size (e.g. for Twrite, Rread) with + // M-largestFixedSize. You could do this individual on a per-message basis, + // but it's easier to compute a single maximum safe payload. + largestFixedSize uint32 +} + +// get returns a new message by type. +// +// An error is returned in the case of an unknown message. +// +// This takes, and ignores, a message tag so that it may be used directly as a +// lookupTagAndType function for recv (by design). +func (r *registry) get(_ Tag, t MsgType) (message, error) { + entry := &r.factories[t] + if entry.create == nil { + return nil, &ErrInvalidMsgType{t} + } + + select { + case msg := <-entry.cache: + return msg, nil + default: + return entry.create(), nil + } +} + +func (r *registry) put(msg message) { + if p, ok := msg.(payloader); ok { + p.SetPayload(nil) + } + if f, ok := msg.(filer); ok { + f.SetFilePayload(nil) + } + + entry := &r.factories[msg.Type()] + select { + case entry.cache <- msg: + default: + } +} + +// register registers the given message type. +// +// This may cause panic on failure and should only be used from init. +func (r *registry) register(t MsgType, fn func() message) { + if int(t) >= len(r.factories) { + panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories))) + } + if r.factories[t].create != nil { + panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn())) + } + r.factories[t] = msgFactory{ + create: fn, + cache: make(chan message, maxCacheSize), + } + + if size := calculateSize(fn()); size > r.largestFixedSize { + r.largestFixedSize = size + } +} + +func calculateSize(m message) uint32 { + if p, ok := m.(payloader); ok { + return p.FixedSize() + } + var dataBuf buffer + m.encode(&dataBuf) + return uint32(len(dataBuf.data)) +} + +func init() { + msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) + msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} }) + msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} }) + msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} }) + msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} }) + msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} }) + msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} }) + msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} }) + msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} }) + msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} }) + msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} }) + msgRegistry.register(MsgTrename, func() message { return &Trename{} }) + msgRegistry.register(MsgRrename, func() message { return &Rrename{} }) + msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} }) + msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} }) + msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} }) + msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} }) + msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} }) + msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} }) + msgRegistry.register(MsgTlistxattr, func() message { return &Tlistxattr{} }) + msgRegistry.register(MsgRlistxattr, func() message { return &Rlistxattr{} }) + msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} }) + msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} }) + msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} }) + msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} }) + msgRegistry.register(MsgTgetxattr, func() message { return &Tgetxattr{} }) + msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} }) + msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} }) + msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} }) + msgRegistry.register(MsgTremovexattr, func() message { return &Tremovexattr{} }) + msgRegistry.register(MsgRremovexattr, func() message { return &Rremovexattr{} }) + msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} }) + msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} }) + msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} }) + msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} }) + msgRegistry.register(MsgTlink, func() message { return &Tlink{} }) + msgRegistry.register(MsgRlink, func() message { return &Rlink{} }) + msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} }) + msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} }) + msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} }) + msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} }) + msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} }) + msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} }) + msgRegistry.register(MsgTversion, func() message { return &Tversion{} }) + msgRegistry.register(MsgRversion, func() message { return &Rversion{} }) + msgRegistry.register(MsgTauth, func() message { return &Tauth{} }) + msgRegistry.register(MsgRauth, func() message { return &Rauth{} }) + msgRegistry.register(MsgTattach, func() message { return &Tattach{} }) + msgRegistry.register(MsgRattach, func() message { return &Rattach{} }) + msgRegistry.register(MsgTflush, func() message { return &Tflush{} }) + msgRegistry.register(MsgRflush, func() message { return &Rflush{} }) + msgRegistry.register(MsgTwalk, func() message { return &Twalk{} }) + msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} }) + msgRegistry.register(MsgTread, func() message { return &Tread{} }) + msgRegistry.register(MsgRread, func() message { return &Rread{} }) + msgRegistry.register(MsgTwrite, func() message { return &Twrite{} }) + msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} }) + msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} }) + msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} }) + msgRegistry.register(MsgTremove, func() message { return &Tremove{} }) + msgRegistry.register(MsgRremove, func() message { return &Rremove{} }) + msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} }) + msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} }) + msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} }) + msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} }) + msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} }) + msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} }) + msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} }) + msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} }) + msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} }) + msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} }) + msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} }) + msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} }) + msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} }) + msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) + msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) + msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) + msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) +} diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go new file mode 100644 index 000000000..7facc9f5e --- /dev/null +++ b/pkg/p9/messages_test.go @@ -0,0 +1,483 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "fmt" + "reflect" + "testing" +) + +func TestEncodeDecode(t *testing.T) { + objs := []encoder{ + &QID{ + Type: 1, + Version: 2, + Path: 3, + }, + &FSStat{ + Type: 1, + BlockSize: 2, + Blocks: 3, + BlocksFree: 4, + BlocksAvailable: 5, + Files: 6, + FilesFree: 7, + FSID: 8, + NameLength: 9, + }, + &AttrMask{ + Mode: true, + NLink: true, + UID: true, + GID: true, + RDev: true, + ATime: true, + MTime: true, + CTime: true, + INo: true, + Size: true, + Blocks: true, + BTime: true, + Gen: true, + DataVersion: true, + }, + &Attr{ + Mode: Exec, + UID: 2, + GID: 3, + NLink: 4, + RDev: 5, + Size: 6, + BlockSize: 7, + Blocks: 8, + ATimeSeconds: 9, + ATimeNanoSeconds: 10, + MTimeSeconds: 11, + MTimeNanoSeconds: 12, + CTimeSeconds: 13, + CTimeNanoSeconds: 14, + BTimeSeconds: 15, + BTimeNanoSeconds: 16, + Gen: 17, + DataVersion: 18, + }, + &SetAttrMask{ + Permissions: true, + UID: true, + GID: true, + Size: true, + ATime: true, + MTime: true, + CTime: true, + ATimeNotSystemTime: true, + MTimeNotSystemTime: true, + }, + &SetAttr{ + Permissions: 1, + UID: 2, + GID: 3, + Size: 4, + ATimeSeconds: 5, + ATimeNanoSeconds: 6, + MTimeSeconds: 7, + MTimeNanoSeconds: 8, + }, + &Dirent{ + QID: QID{Type: 1}, + Offset: 2, + Type: 3, + Name: "a", + }, + &Rlerror{ + Error: 1, + }, + &Tstatfs{ + FID: 1, + }, + &Rstatfs{ + FSStat: FSStat{Type: 1}, + }, + &Tlopen{ + FID: 1, + Flags: WriteOnly, + }, + &Rlopen{ + QID: QID{Type: 1}, + IoUnit: 2, + }, + &Tlconnect{ + FID: 1, + }, + &Rlconnect{}, + &Tlcreate{ + FID: 1, + Name: "a", + OpenFlags: 2, + Permissions: 3, + GID: 4, + }, + &Rlcreate{ + Rlopen{QID: QID{Type: 1}}, + }, + &Tsymlink{ + Directory: 1, + Name: "a", + Target: "b", + GID: 2, + }, + &Rsymlink{ + QID: QID{Type: 1}, + }, + &Tmknod{ + Directory: 1, + Name: "a", + Mode: 2, + Major: 3, + Minor: 4, + GID: 5, + }, + &Rmknod{ + QID: QID{Type: 1}, + }, + &Trename{ + FID: 1, + Directory: 2, + Name: "a", + }, + &Rrename{}, + &Treadlink{ + FID: 1, + }, + &Rreadlink{ + Target: "a", + }, + &Tgetattr{ + FID: 1, + AttrMask: AttrMask{Mode: true}, + }, + &Rgetattr{ + Valid: AttrMask{Mode: true}, + QID: QID{Type: 1}, + Attr: Attr{Mode: Write}, + }, + &Tsetattr{ + FID: 1, + Valid: SetAttrMask{Permissions: true}, + SetAttr: SetAttr{Permissions: Write}, + }, + &Rsetattr{}, + &Txattrwalk{ + FID: 1, + NewFID: 2, + Name: "a", + }, + &Rxattrwalk{ + Size: 1, + }, + &Txattrcreate{ + FID: 1, + Name: "a", + AttrSize: 2, + Flags: 3, + }, + &Rxattrcreate{}, + &Tgetxattr{ + FID: 1, + Name: "abc", + Size: 2, + }, + &Rgetxattr{ + Value: "xyz", + }, + &Tsetxattr{ + FID: 1, + Name: "abc", + Value: "xyz", + Flags: 2, + }, + &Rsetxattr{}, + &Treaddir{ + Directory: 1, + Offset: 2, + Count: 3, + }, + &Rreaddir{ + // Count must be sufficient to encode a dirent. + Count: 0x1a, + Entries: []Dirent{{QID: QID{Type: 2}}}, + }, + &Tfsync{ + FID: 1, + }, + &Rfsync{}, + &Tlink{ + Directory: 1, + Target: 2, + Name: "a", + }, + &Rlink{}, + &Tmkdir{ + Directory: 1, + Name: "a", + Permissions: 2, + GID: 3, + }, + &Rmkdir{ + QID: QID{Type: 1}, + }, + &Trenameat{ + OldDirectory: 1, + OldName: "a", + NewDirectory: 2, + NewName: "b", + }, + &Rrenameat{}, + &Tunlinkat{ + Directory: 1, + Name: "a", + Flags: 2, + }, + &Runlinkat{}, + &Tversion{ + MSize: 1, + Version: "a", + }, + &Rversion{ + MSize: 1, + Version: "a", + }, + &Tauth{ + AuthenticationFID: 1, + UserName: "a", + AttachName: "b", + UID: 2, + }, + &Rauth{ + QID: QID{Type: 1}, + }, + &Tattach{ + FID: 1, + Auth: Tauth{AuthenticationFID: 2}, + }, + &Rattach{ + QID: QID{Type: 1}, + }, + &Tflush{ + OldTag: 1, + }, + &Rflush{}, + &Twalk{ + FID: 1, + NewFID: 2, + Names: []string{"a"}, + }, + &Rwalk{ + QIDs: []QID{{Type: 1}}, + }, + &Tread{ + FID: 1, + Offset: 2, + Count: 3, + }, + &Rread{ + Data: []byte{'a'}, + }, + &Twrite{ + FID: 1, + Offset: 2, + Data: []byte{'a'}, + }, + &Rwrite{ + Count: 1, + }, + &Tclunk{ + FID: 1, + }, + &Rclunk{}, + &Tremove{ + FID: 1, + }, + &Rremove{}, + &Tflushf{ + FID: 1, + }, + &Rflushf{}, + &Twalkgetattr{ + FID: 1, + NewFID: 2, + Names: []string{"a"}, + }, + &Rwalkgetattr{ + QIDs: []QID{{Type: 1}}, + Valid: AttrMask{Mode: true}, + Attr: Attr{Mode: Write}, + }, + &Tucreate{ + Tlcreate: Tlcreate{ + FID: 1, + Name: "a", + OpenFlags: 2, + Permissions: 3, + GID: 4, + }, + UID: 5, + }, + &Rucreate{ + Rlcreate{Rlopen{QID: QID{Type: 1}}}, + }, + &Tumkdir{ + Tmkdir: Tmkdir{ + Directory: 1, + Name: "a", + Permissions: 2, + GID: 3, + }, + UID: 4, + }, + &Rumkdir{ + Rmkdir{QID: QID{Type: 1}}, + }, + &Tusymlink{ + Tsymlink: Tsymlink{ + Directory: 1, + Name: "a", + Target: "b", + GID: 2, + }, + UID: 3, + }, + &Rusymlink{ + Rsymlink{QID: QID{Type: 1}}, + }, + &Tumknod{ + Tmknod: Tmknod{ + Directory: 1, + Name: "a", + Mode: 2, + Major: 3, + Minor: 4, + GID: 5, + }, + UID: 6, + }, + &Rumknod{ + Rmknod{QID: QID{Type: 1}}, + }, + } + + for _, enc := range objs { + // Encode the original. + data := make([]byte, initialBufferLength) + buf := buffer{data: data[:0]} + enc.encode(&buf) + + // Create a new object, same as the first. + enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder) + buf2 := buffer{data: buf.data} + + // To be fair, we need to add any payloads (directly). + if pl, ok := enc.(payloader); ok { + enc2.(payloader).SetPayload(pl.Payload()) + } + + // And any file payloads (directly). + if fl, ok := enc.(filer); ok { + enc2.(filer).SetFilePayload(fl.FilePayload()) + } + + // Mark sure it was okay. + enc2.decode(&buf2) + if buf2.isOverrun() { + t.Errorf("object %#v->%#v got overrun on decode", enc, enc2) + continue + } + + // Check that they are equal. + if !reflect.DeepEqual(enc, enc2) { + t.Errorf("object %#v and %#v differ", enc, enc2) + continue + } + } +} + +func TestMessageStrings(t *testing.T) { + for typ := range msgRegistry.factories { + entry := &msgRegistry.factories[typ] + if entry.create != nil { + name := fmt.Sprintf("%+v", typ) + t.Run(name, func(t *testing.T) { + defer func() { // Ensure no panic. + if r := recover(); r != nil { + t.Errorf("printing %s failed: %v", name, r) + } + }() + m := entry.create() + _ = fmt.Sprintf("%v", m) + err := ErrInvalidMsgType{MsgType(typ)} + _ = err.Error() + }) + } + } +} + +func TestRegisterDuplicate(t *testing.T) { + defer func() { + if r := recover(); r == nil { + // We expect a panic. + t.FailNow() + } + }() + + // Register a duplicate. + msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) +} + +func TestMsgCache(t *testing.T) { + // Cache starts empty. + if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } + + // Message can be created with an empty cache. + msg, err := msgRegistry.get(0, MsgRlerror) + if err != nil { + t.Errorf("msgRegistry.get(): %v", err) + } + if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } + + // Check that message is added to the cache when returned. + msgRegistry.put(msg) + if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } + + // Check that returned message is reused. + if got, err := msgRegistry.get(0, MsgRlerror); err != nil { + t.Errorf("msgRegistry.get(): %v", err) + } else if msg != got { + t.Errorf("Message not reused, got: %d, want: %d", got, msg) + } + + // Check that cache doesn't grow beyond max size. + for i := 0; i < maxCacheSize+1; i++ { + msgRegistry.put(&Rlerror{}) + } + if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } +} diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go new file mode 100644 index 000000000..28d851ff5 --- /dev/null +++ b/pkg/p9/p9.go @@ -0,0 +1,1158 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package p9 is a 9P2000.L implementation. +package p9 + +import ( + "fmt" + "math" + "os" + "strings" + "sync/atomic" + "syscall" + + "golang.org/x/sys/unix" +) + +// OpenFlags is the mode passed to Open and Create operations. +// +// These correspond to bits sent over the wire. +type OpenFlags uint32 + +const ( + // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode. + ReadOnly OpenFlags = 0 + + // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode. + WriteOnly OpenFlags = 1 + + // ReadWrite is a Tlopen flag indicates read-write mode. + ReadWrite OpenFlags = 2 + + // OpenFlagsModeMask is a mask of valid OpenFlags mode bits. + OpenFlagsModeMask OpenFlags = 3 + + // OpenTruncate is a Tlopen flag indicating that the opened file should be + // truncated. + OpenTruncate OpenFlags = 01000 +) + +// ConnectFlags is the mode passed to Connect operations. +// +// These correspond to bits sent over the wire. +type ConnectFlags uint32 + +const ( + // StreamSocket is a Tlconnect flag indicating SOCK_STREAM mode. + StreamSocket ConnectFlags = 0 + + // DgramSocket is a Tlconnect flag indicating SOCK_DGRAM mode. + DgramSocket ConnectFlags = 1 + + // SeqpacketSocket is a Tlconnect flag indicating SOCK_SEQPACKET mode. + SeqpacketSocket ConnectFlags = 2 + + // AnonymousSocket is a Tlconnect flag indicating that the mode does not + // matter and that the requester will accept any socket type. + AnonymousSocket ConnectFlags = 3 +) + +// OSFlags converts a p9.OpenFlags to an int compatible with open(2). +func (o OpenFlags) OSFlags() int { + // "flags contains Linux open(2) flags bits" - 9P2000.L + return int(o) +} + +// String implements fmt.Stringer. +func (o OpenFlags) String() string { + var buf strings.Builder + switch mode := o & OpenFlagsModeMask; mode { + case ReadOnly: + buf.WriteString("ReadOnly") + case WriteOnly: + buf.WriteString("WriteOnly") + case ReadWrite: + buf.WriteString("ReadWrite") + default: + fmt.Fprintf(&buf, "%#o", mode) + } + otherFlags := o &^ OpenFlagsModeMask + if otherFlags&OpenTruncate != 0 { + buf.WriteString("|OpenTruncate") + otherFlags &^= OpenTruncate + } + if otherFlags != 0 { + fmt.Fprintf(&buf, "|%#o", otherFlags) + } + return buf.String() +} + +// Tag is a message tag. +type Tag uint16 + +// FID is a file identifier. +type FID uint64 + +// FileMode are flags corresponding to file modes. +// +// These correspond to bits sent over the wire. +// These also correspond to mode_t bits. +type FileMode uint32 + +const ( + // FileModeMask is a mask of all the file mode bits of FileMode. + FileModeMask FileMode = 0170000 + + // ModeSocket is an (unused) mode bit for a socket. + ModeSocket FileMode = 0140000 + + // ModeSymlink is a mode bit for a symlink. + ModeSymlink FileMode = 0120000 + + // ModeRegular is a mode bit for regular files. + ModeRegular FileMode = 0100000 + + // ModeBlockDevice is a mode bit for block devices. + ModeBlockDevice FileMode = 060000 + + // ModeDirectory is a mode bit for directories. + ModeDirectory FileMode = 040000 + + // ModeCharacterDevice is a mode bit for a character device. + ModeCharacterDevice FileMode = 020000 + + // ModeNamedPipe is a mode bit for a named pipe. + ModeNamedPipe FileMode = 010000 + + // Read is a mode bit indicating read permission. + Read FileMode = 04 + + // Write is a mode bit indicating write permission. + Write FileMode = 02 + + // Exec is a mode bit indicating exec permission. + Exec FileMode = 01 + + // AllPermissions is a mask with rwx bits set for user, group and others. + AllPermissions FileMode = 0777 + + // Sticky is a mode bit indicating sticky directories. + Sticky FileMode = 01000 + + // permissionsMask is the mask to apply to FileModes for permissions. It + // includes rwx bits for user, group and others, and sticky bit. + permissionsMask FileMode = 01777 +) + +// QIDType is the most significant byte of the FileMode word, to be used as the +// Type field of p9.QID. +func (m FileMode) QIDType() QIDType { + switch { + case m.IsDir(): + return TypeDir + case m.IsSocket(), m.IsNamedPipe(), m.IsCharacterDevice(): + // Best approximation. + return TypeAppendOnly + case m.IsSymlink(): + return TypeSymlink + default: + return TypeRegular + } +} + +// FileType returns the file mode without the permission bits. +func (m FileMode) FileType() FileMode { + return m & FileModeMask +} + +// Permissions returns just the permission bits of the mode. +func (m FileMode) Permissions() FileMode { + return m & permissionsMask +} + +// Writable returns the mode with write bits added. +func (m FileMode) Writable() FileMode { + return m | 0222 +} + +// IsReadable returns true if m represents a file that can be read. +func (m FileMode) IsReadable() bool { + return m&0444 != 0 +} + +// IsWritable returns true if m represents a file that can be written to. +func (m FileMode) IsWritable() bool { + return m&0222 != 0 +} + +// IsExecutable returns true if m represents a file that can be executed. +func (m FileMode) IsExecutable() bool { + return m&0111 != 0 +} + +// IsRegular returns true if m is a regular file. +func (m FileMode) IsRegular() bool { + return m&FileModeMask == ModeRegular +} + +// IsDir returns true if m represents a directory. +func (m FileMode) IsDir() bool { + return m&FileModeMask == ModeDirectory +} + +// IsNamedPipe returns true if m represents a named pipe. +func (m FileMode) IsNamedPipe() bool { + return m&FileModeMask == ModeNamedPipe +} + +// IsCharacterDevice returns true if m represents a character device. +func (m FileMode) IsCharacterDevice() bool { + return m&FileModeMask == ModeCharacterDevice +} + +// IsBlockDevice returns true if m represents a character device. +func (m FileMode) IsBlockDevice() bool { + return m&FileModeMask == ModeBlockDevice +} + +// IsSocket returns true if m represents a socket. +func (m FileMode) IsSocket() bool { + return m&FileModeMask == ModeSocket +} + +// IsSymlink returns true if m represents a symlink. +func (m FileMode) IsSymlink() bool { + return m&FileModeMask == ModeSymlink +} + +// ModeFromOS returns a FileMode from an os.FileMode. +func ModeFromOS(mode os.FileMode) FileMode { + m := FileMode(mode.Perm()) + switch { + case mode.IsDir(): + m |= ModeDirectory + case mode&os.ModeSymlink != 0: + m |= ModeSymlink + case mode&os.ModeSocket != 0: + m |= ModeSocket + case mode&os.ModeNamedPipe != 0: + m |= ModeNamedPipe + case mode&os.ModeCharDevice != 0: + m |= ModeCharacterDevice + case mode&os.ModeDevice != 0: + m |= ModeBlockDevice + default: + m |= ModeRegular + } + return m +} + +// OSMode converts a p9.FileMode to an os.FileMode. +func (m FileMode) OSMode() os.FileMode { + var osMode os.FileMode + osMode |= os.FileMode(m.Permissions()) + switch { + case m.IsDir(): + osMode |= os.ModeDir + case m.IsSymlink(): + osMode |= os.ModeSymlink + case m.IsSocket(): + osMode |= os.ModeSocket + case m.IsNamedPipe(): + osMode |= os.ModeNamedPipe + case m.IsCharacterDevice(): + osMode |= os.ModeCharDevice | os.ModeDevice + case m.IsBlockDevice(): + osMode |= os.ModeDevice + } + return osMode +} + +// UID represents a user ID. +type UID uint32 + +// Ok returns true if uid is not NoUID. +func (uid UID) Ok() bool { + return uid != NoUID +} + +// GID represents a group ID. +type GID uint32 + +// Ok returns true if gid is not NoGID. +func (gid GID) Ok() bool { + return gid != NoGID +} + +const ( + // NoTag is a sentinel used to indicate no valid tag. + NoTag Tag = math.MaxUint16 + + // NoFID is a sentinel used to indicate no valid FID. + NoFID FID = math.MaxUint32 + + // NoUID is a sentinel used to indicate no valid UID. + NoUID UID = math.MaxUint32 + + // NoGID is a sentinel used to indicate no valid GID. + NoGID GID = math.MaxUint32 +) + +// MsgType is a type identifier. +type MsgType uint8 + +// MsgType declarations. +const ( + MsgTlerror MsgType = 6 + MsgRlerror = 7 + MsgTstatfs = 8 + MsgRstatfs = 9 + MsgTlopen = 12 + MsgRlopen = 13 + MsgTlcreate = 14 + MsgRlcreate = 15 + MsgTsymlink = 16 + MsgRsymlink = 17 + MsgTmknod = 18 + MsgRmknod = 19 + MsgTrename = 20 + MsgRrename = 21 + MsgTreadlink = 22 + MsgRreadlink = 23 + MsgTgetattr = 24 + MsgRgetattr = 25 + MsgTsetattr = 26 + MsgRsetattr = 27 + MsgTlistxattr = 28 + MsgRlistxattr = 29 + MsgTxattrwalk = 30 + MsgRxattrwalk = 31 + MsgTxattrcreate = 32 + MsgRxattrcreate = 33 + MsgTgetxattr = 34 + MsgRgetxattr = 35 + MsgTsetxattr = 36 + MsgRsetxattr = 37 + MsgTremovexattr = 38 + MsgRremovexattr = 39 + MsgTreaddir = 40 + MsgRreaddir = 41 + MsgTfsync = 50 + MsgRfsync = 51 + MsgTlink = 70 + MsgRlink = 71 + MsgTmkdir = 72 + MsgRmkdir = 73 + MsgTrenameat = 74 + MsgRrenameat = 75 + MsgTunlinkat = 76 + MsgRunlinkat = 77 + MsgTversion = 100 + MsgRversion = 101 + MsgTauth = 102 + MsgRauth = 103 + MsgTattach = 104 + MsgRattach = 105 + MsgTflush = 108 + MsgRflush = 109 + MsgTwalk = 110 + MsgRwalk = 111 + MsgTread = 116 + MsgRread = 117 + MsgTwrite = 118 + MsgRwrite = 119 + MsgTclunk = 120 + MsgRclunk = 121 + MsgTremove = 122 + MsgRremove = 123 + MsgTflushf = 124 + MsgRflushf = 125 + MsgTwalkgetattr = 126 + MsgRwalkgetattr = 127 + MsgTucreate = 128 + MsgRucreate = 129 + MsgTumkdir = 130 + MsgRumkdir = 131 + MsgTumknod = 132 + MsgRumknod = 133 + MsgTusymlink = 134 + MsgRusymlink = 135 + MsgTlconnect = 136 + MsgRlconnect = 137 + MsgTallocate = 138 + MsgRallocate = 139 + MsgTchannel = 250 + MsgRchannel = 251 +) + +// QIDType represents the file type for QIDs. +// +// QIDType corresponds to the high 8 bits of a Plan 9 file mode. +type QIDType uint8 + +const ( + // TypeDir represents a directory type. + TypeDir QIDType = 0x80 + + // TypeAppendOnly represents an append only file. + TypeAppendOnly QIDType = 0x40 + + // TypeExclusive represents an exclusive-use file. + TypeExclusive QIDType = 0x20 + + // TypeMount represents a mounted channel. + TypeMount QIDType = 0x10 + + // TypeAuth represents an authentication file. + TypeAuth QIDType = 0x08 + + // TypeTemporary represents a temporary file. + TypeTemporary QIDType = 0x04 + + // TypeSymlink represents a symlink. + TypeSymlink QIDType = 0x02 + + // TypeLink represents a hard link. + TypeLink QIDType = 0x01 + + // TypeRegular represents a regular file. + TypeRegular QIDType = 0x00 +) + +// QID is a unique file identifier. +// +// This may be embedded in other requests and responses. +type QID struct { + // Type is the highest order byte of the file mode. + Type QIDType + + // Version is an arbitrary server version number. + Version uint32 + + // Path is a unique server identifier for this path (e.g. inode). + Path uint64 +} + +// String implements fmt.Stringer. +func (q QID) String() string { + return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path) +} + +// decode implements encoder.decode. +func (q *QID) decode(b *buffer) { + q.Type = b.ReadQIDType() + q.Version = b.Read32() + q.Path = b.Read64() +} + +// encode implements encoder.encode. +func (q *QID) encode(b *buffer) { + b.WriteQIDType(q.Type) + b.Write32(q.Version) + b.Write64(q.Path) +} + +// QIDGenerator is a simple generator for QIDs that atomically increments Path +// values. +type QIDGenerator struct { + // uids is an ever increasing value that can be atomically incremented + // to provide unique Path values for QIDs. + uids uint64 +} + +// Get returns a new 9P unique ID with a unique Path given a QID type. +// +// While the 9P spec allows Version to be incremented every time the file is +// modified, we currently do not use the Version member for anything. Hence, +// it is set to 0. +func (q *QIDGenerator) Get(t QIDType) QID { + return QID{ + Type: t, + Version: 0, + Path: atomic.AddUint64(&q.uids, 1), + } +} + +// FSStat is used by statfs. +type FSStat struct { + // Type is the filesystem type. + Type uint32 + + // BlockSize is the blocksize. + BlockSize uint32 + + // Blocks is the number of blocks. + Blocks uint64 + + // BlocksFree is the number of free blocks. + BlocksFree uint64 + + // BlocksAvailable is the number of blocks *available*. + BlocksAvailable uint64 + + // Files is the number of files available. + Files uint64 + + // FilesFree is the number of free file nodes. + FilesFree uint64 + + // FSID is the filesystem ID. + FSID uint64 + + // NameLength is the maximum name length. + NameLength uint32 +} + +// decode implements encoder.decode. +func (f *FSStat) decode(b *buffer) { + f.Type = b.Read32() + f.BlockSize = b.Read32() + f.Blocks = b.Read64() + f.BlocksFree = b.Read64() + f.BlocksAvailable = b.Read64() + f.Files = b.Read64() + f.FilesFree = b.Read64() + f.FSID = b.Read64() + f.NameLength = b.Read32() +} + +// encode implements encoder.encode. +func (f *FSStat) encode(b *buffer) { + b.Write32(f.Type) + b.Write32(f.BlockSize) + b.Write64(f.Blocks) + b.Write64(f.BlocksFree) + b.Write64(f.BlocksAvailable) + b.Write64(f.Files) + b.Write64(f.FilesFree) + b.Write64(f.FSID) + b.Write32(f.NameLength) +} + +// AttrMask is a mask of attributes for getattr. +type AttrMask struct { + Mode bool + NLink bool + UID bool + GID bool + RDev bool + ATime bool + MTime bool + CTime bool + INo bool + Size bool + Blocks bool + BTime bool + Gen bool + DataVersion bool +} + +// Contains returns true if a contains all of the attributes masked as b. +func (a AttrMask) Contains(b AttrMask) bool { + if b.Mode && !a.Mode { + return false + } + if b.NLink && !a.NLink { + return false + } + if b.UID && !a.UID { + return false + } + if b.GID && !a.GID { + return false + } + if b.RDev && !a.RDev { + return false + } + if b.ATime && !a.ATime { + return false + } + if b.MTime && !a.MTime { + return false + } + if b.CTime && !a.CTime { + return false + } + if b.INo && !a.INo { + return false + } + if b.Size && !a.Size { + return false + } + if b.Blocks && !a.Blocks { + return false + } + if b.BTime && !a.BTime { + return false + } + if b.Gen && !a.Gen { + return false + } + if b.DataVersion && !a.DataVersion { + return false + } + return true +} + +// Empty returns true if no fields are masked. +func (a AttrMask) Empty() bool { + return !a.Mode && !a.NLink && !a.UID && !a.GID && !a.RDev && !a.ATime && !a.MTime && !a.CTime && !a.INo && !a.Size && !a.Blocks && !a.BTime && !a.Gen && !a.DataVersion +} + +// AttrMaskAll returns an AttrMask with all fields masked. +func AttrMaskAll() AttrMask { + return AttrMask{ + Mode: true, + NLink: true, + UID: true, + GID: true, + RDev: true, + ATime: true, + MTime: true, + CTime: true, + INo: true, + Size: true, + Blocks: true, + BTime: true, + Gen: true, + DataVersion: true, + } +} + +// String implements fmt.Stringer. +func (a AttrMask) String() string { + var masks []string + if a.Mode { + masks = append(masks, "Mode") + } + if a.NLink { + masks = append(masks, "NLink") + } + if a.UID { + masks = append(masks, "UID") + } + if a.GID { + masks = append(masks, "GID") + } + if a.RDev { + masks = append(masks, "RDev") + } + if a.ATime { + masks = append(masks, "ATime") + } + if a.MTime { + masks = append(masks, "MTime") + } + if a.CTime { + masks = append(masks, "CTime") + } + if a.INo { + masks = append(masks, "INo") + } + if a.Size { + masks = append(masks, "Size") + } + if a.Blocks { + masks = append(masks, "Blocks") + } + if a.BTime { + masks = append(masks, "BTime") + } + if a.Gen { + masks = append(masks, "Gen") + } + if a.DataVersion { + masks = append(masks, "DataVersion") + } + return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " ")) +} + +// decode implements encoder.decode. +func (a *AttrMask) decode(b *buffer) { + mask := b.Read64() + a.Mode = mask&0x00000001 != 0 + a.NLink = mask&0x00000002 != 0 + a.UID = mask&0x00000004 != 0 + a.GID = mask&0x00000008 != 0 + a.RDev = mask&0x00000010 != 0 + a.ATime = mask&0x00000020 != 0 + a.MTime = mask&0x00000040 != 0 + a.CTime = mask&0x00000080 != 0 + a.INo = mask&0x00000100 != 0 + a.Size = mask&0x00000200 != 0 + a.Blocks = mask&0x00000400 != 0 + a.BTime = mask&0x00000800 != 0 + a.Gen = mask&0x00001000 != 0 + a.DataVersion = mask&0x00002000 != 0 +} + +// encode implements encoder.encode. +func (a *AttrMask) encode(b *buffer) { + var mask uint64 + if a.Mode { + mask |= 0x00000001 + } + if a.NLink { + mask |= 0x00000002 + } + if a.UID { + mask |= 0x00000004 + } + if a.GID { + mask |= 0x00000008 + } + if a.RDev { + mask |= 0x00000010 + } + if a.ATime { + mask |= 0x00000020 + } + if a.MTime { + mask |= 0x00000040 + } + if a.CTime { + mask |= 0x00000080 + } + if a.INo { + mask |= 0x00000100 + } + if a.Size { + mask |= 0x00000200 + } + if a.Blocks { + mask |= 0x00000400 + } + if a.BTime { + mask |= 0x00000800 + } + if a.Gen { + mask |= 0x00001000 + } + if a.DataVersion { + mask |= 0x00002000 + } + b.Write64(mask) +} + +// Attr is a set of attributes for getattr. +type Attr struct { + Mode FileMode + UID UID + GID GID + NLink uint64 + RDev uint64 + Size uint64 + BlockSize uint64 + Blocks uint64 + ATimeSeconds uint64 + ATimeNanoSeconds uint64 + MTimeSeconds uint64 + MTimeNanoSeconds uint64 + CTimeSeconds uint64 + CTimeNanoSeconds uint64 + BTimeSeconds uint64 + BTimeNanoSeconds uint64 + Gen uint64 + DataVersion uint64 +} + +// String implements fmt.Stringer. +func (a Attr) String() string { + return fmt.Sprintf("Attr{Mode: 0o%o, UID: %d, GID: %d, NLink: %d, RDev: %d, Size: %d, BlockSize: %d, Blocks: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}, CTime: {Sec: %d, NanoSec: %d}, BTime: {Sec: %d, NanoSec: %d}, Gen: %d, DataVersion: %d}", + a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion) +} + +// encode implements encoder.encode. +func (a *Attr) encode(b *buffer) { + b.WriteFileMode(a.Mode) + b.WriteUID(a.UID) + b.WriteGID(a.GID) + b.Write64(a.NLink) + b.Write64(a.RDev) + b.Write64(a.Size) + b.Write64(a.BlockSize) + b.Write64(a.Blocks) + b.Write64(a.ATimeSeconds) + b.Write64(a.ATimeNanoSeconds) + b.Write64(a.MTimeSeconds) + b.Write64(a.MTimeNanoSeconds) + b.Write64(a.CTimeSeconds) + b.Write64(a.CTimeNanoSeconds) + b.Write64(a.BTimeSeconds) + b.Write64(a.BTimeNanoSeconds) + b.Write64(a.Gen) + b.Write64(a.DataVersion) +} + +// decode implements encoder.decode. +func (a *Attr) decode(b *buffer) { + a.Mode = b.ReadFileMode() + a.UID = b.ReadUID() + a.GID = b.ReadGID() + a.NLink = b.Read64() + a.RDev = b.Read64() + a.Size = b.Read64() + a.BlockSize = b.Read64() + a.Blocks = b.Read64() + a.ATimeSeconds = b.Read64() + a.ATimeNanoSeconds = b.Read64() + a.MTimeSeconds = b.Read64() + a.MTimeNanoSeconds = b.Read64() + a.CTimeSeconds = b.Read64() + a.CTimeNanoSeconds = b.Read64() + a.BTimeSeconds = b.Read64() + a.BTimeNanoSeconds = b.Read64() + a.Gen = b.Read64() + a.DataVersion = b.Read64() +} + +// StatToAttr converts a Linux syscall stat structure to an Attr. +func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) { + attr := Attr{ + UID: NoUID, + GID: NoGID, + } + if req.Mode { + // p9.FileMode corresponds to Linux mode_t. + attr.Mode = FileMode(s.Mode) + } + if req.NLink { + attr.NLink = uint64(s.Nlink) + } + if req.UID { + attr.UID = UID(s.Uid) + } + if req.GID { + attr.GID = GID(s.Gid) + } + if req.RDev { + attr.RDev = s.Dev + } + if req.ATime { + attr.ATimeSeconds = uint64(s.Atim.Sec) + attr.ATimeNanoSeconds = uint64(s.Atim.Nsec) + } + if req.MTime { + attr.MTimeSeconds = uint64(s.Mtim.Sec) + attr.MTimeNanoSeconds = uint64(s.Mtim.Nsec) + } + if req.CTime { + attr.CTimeSeconds = uint64(s.Ctim.Sec) + attr.CTimeNanoSeconds = uint64(s.Ctim.Nsec) + } + if req.Size { + attr.Size = uint64(s.Size) + } + if req.Blocks { + attr.BlockSize = uint64(s.Blksize) + attr.Blocks = uint64(s.Blocks) + } + + // Use the req field because we already have it. + req.BTime = false + req.Gen = false + req.DataVersion = false + + return attr, req +} + +// SetAttrMask specifies a valid mask for setattr. +type SetAttrMask struct { + Permissions bool + UID bool + GID bool + Size bool + ATime bool + MTime bool + CTime bool + ATimeNotSystemTime bool + MTimeNotSystemTime bool +} + +// IsSubsetOf returns whether s is a subset of m. +func (s SetAttrMask) IsSubsetOf(m SetAttrMask) bool { + sb := s.bitmask() + sm := m.bitmask() + return sm|sb == sm +} + +// String implements fmt.Stringer. +func (s SetAttrMask) String() string { + var masks []string + if s.Permissions { + masks = append(masks, "Permissions") + } + if s.UID { + masks = append(masks, "UID") + } + if s.GID { + masks = append(masks, "GID") + } + if s.Size { + masks = append(masks, "Size") + } + if s.ATime { + masks = append(masks, "ATime") + } + if s.MTime { + masks = append(masks, "MTime") + } + if s.CTime { + masks = append(masks, "CTime") + } + if s.ATimeNotSystemTime { + masks = append(masks, "ATimeNotSystemTime") + } + if s.MTimeNotSystemTime { + masks = append(masks, "MTimeNotSystemTime") + } + return fmt.Sprintf("SetAttrMask{with: %s}", strings.Join(masks, " ")) +} + +// Empty returns true if no fields are masked. +func (s SetAttrMask) Empty() bool { + return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime +} + +// decode implements encoder.decode. +func (s *SetAttrMask) decode(b *buffer) { + mask := b.Read32() + s.Permissions = mask&0x00000001 != 0 + s.UID = mask&0x00000002 != 0 + s.GID = mask&0x00000004 != 0 + s.Size = mask&0x00000008 != 0 + s.ATime = mask&0x00000010 != 0 + s.MTime = mask&0x00000020 != 0 + s.CTime = mask&0x00000040 != 0 + s.ATimeNotSystemTime = mask&0x00000080 != 0 + s.MTimeNotSystemTime = mask&0x00000100 != 0 +} + +func (s SetAttrMask) bitmask() uint32 { + var mask uint32 + if s.Permissions { + mask |= 0x00000001 + } + if s.UID { + mask |= 0x00000002 + } + if s.GID { + mask |= 0x00000004 + } + if s.Size { + mask |= 0x00000008 + } + if s.ATime { + mask |= 0x00000010 + } + if s.MTime { + mask |= 0x00000020 + } + if s.CTime { + mask |= 0x00000040 + } + if s.ATimeNotSystemTime { + mask |= 0x00000080 + } + if s.MTimeNotSystemTime { + mask |= 0x00000100 + } + return mask +} + +// encode implements encoder.encode. +func (s *SetAttrMask) encode(b *buffer) { + b.Write32(s.bitmask()) +} + +// SetAttr specifies a set of attributes for a setattr. +type SetAttr struct { + Permissions FileMode + UID UID + GID GID + Size uint64 + ATimeSeconds uint64 + ATimeNanoSeconds uint64 + MTimeSeconds uint64 + MTimeNanoSeconds uint64 +} + +// String implements fmt.Stringer. +func (s SetAttr) String() string { + return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds) +} + +// decode implements encoder.decode. +func (s *SetAttr) decode(b *buffer) { + s.Permissions = b.ReadPermissions() + s.UID = b.ReadUID() + s.GID = b.ReadGID() + s.Size = b.Read64() + s.ATimeSeconds = b.Read64() + s.ATimeNanoSeconds = b.Read64() + s.MTimeSeconds = b.Read64() + s.MTimeNanoSeconds = b.Read64() +} + +// encode implements encoder.encode. +func (s *SetAttr) encode(b *buffer) { + b.WritePermissions(s.Permissions) + b.WriteUID(s.UID) + b.WriteGID(s.GID) + b.Write64(s.Size) + b.Write64(s.ATimeSeconds) + b.Write64(s.ATimeNanoSeconds) + b.Write64(s.MTimeSeconds) + b.Write64(s.MTimeNanoSeconds) +} + +// Apply applies this to the given Attr. +func (a *Attr) Apply(mask SetAttrMask, attr SetAttr) { + if mask.Permissions { + a.Mode = a.Mode&^permissionsMask | (attr.Permissions & permissionsMask) + } + if mask.UID { + a.UID = attr.UID + } + if mask.GID { + a.GID = attr.GID + } + if mask.Size { + a.Size = attr.Size + } + if mask.ATime { + a.ATimeSeconds = attr.ATimeSeconds + a.ATimeNanoSeconds = attr.ATimeNanoSeconds + } + if mask.MTime { + a.MTimeSeconds = attr.MTimeSeconds + a.MTimeNanoSeconds = attr.MTimeNanoSeconds + } +} + +// Dirent is used for readdir. +type Dirent struct { + // QID is the entry QID. + QID QID + + // Offset is the offset in the directory. + // + // This will be communicated back the original caller. + Offset uint64 + + // Type is the 9P type. + Type QIDType + + // Name is the name of the entry (i.e. basename). + Name string +} + +// String implements fmt.Stringer. +func (d Dirent) String() string { + return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name) +} + +// decode implements encoder.decode. +func (d *Dirent) decode(b *buffer) { + d.QID.decode(b) + d.Offset = b.Read64() + d.Type = b.ReadQIDType() + d.Name = b.ReadString() +} + +// encode implements encoder.encode. +func (d *Dirent) encode(b *buffer) { + d.QID.encode(b) + b.Write64(d.Offset) + b.WriteQIDType(d.Type) + b.WriteString(d.Name) +} + +// AllocateMode are possible modes to p9.File.Allocate(). +type AllocateMode struct { + KeepSize bool + PunchHole bool + NoHideStale bool + CollapseRange bool + ZeroRange bool + InsertRange bool + Unshare bool +} + +// ToLinux converts to a value compatible with fallocate(2)'s mode. +func (a *AllocateMode) ToLinux() uint32 { + rv := uint32(0) + if a.KeepSize { + rv |= unix.FALLOC_FL_KEEP_SIZE + } + if a.PunchHole { + rv |= unix.FALLOC_FL_PUNCH_HOLE + } + if a.NoHideStale { + rv |= unix.FALLOC_FL_NO_HIDE_STALE + } + if a.CollapseRange { + rv |= unix.FALLOC_FL_COLLAPSE_RANGE + } + if a.ZeroRange { + rv |= unix.FALLOC_FL_ZERO_RANGE + } + if a.InsertRange { + rv |= unix.FALLOC_FL_INSERT_RANGE + } + if a.Unshare { + rv |= unix.FALLOC_FL_UNSHARE_RANGE + } + return rv +} + +// decode implements encoder.decode. +func (a *AllocateMode) decode(b *buffer) { + mask := b.Read32() + a.KeepSize = mask&0x01 != 0 + a.PunchHole = mask&0x02 != 0 + a.NoHideStale = mask&0x04 != 0 + a.CollapseRange = mask&0x08 != 0 + a.ZeroRange = mask&0x10 != 0 + a.InsertRange = mask&0x20 != 0 + a.Unshare = mask&0x40 != 0 +} + +// encode implements encoder.encode. +func (a *AllocateMode) encode(b *buffer) { + mask := uint32(0) + if a.KeepSize { + mask |= 0x01 + } + if a.PunchHole { + mask |= 0x02 + } + if a.NoHideStale { + mask |= 0x04 + } + if a.CollapseRange { + mask |= 0x08 + } + if a.ZeroRange { + mask |= 0x10 + } + if a.InsertRange { + mask |= 0x20 + } + if a.Unshare { + mask |= 0x40 + } + b.Write32(mask) +} diff --git a/pkg/p9/p9_test.go b/pkg/p9/p9_test.go new file mode 100644 index 000000000..8dda6cc64 --- /dev/null +++ b/pkg/p9/p9_test.go @@ -0,0 +1,188 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "os" + "testing" +) + +func TestFileModeHelpers(t *testing.T) { + fns := map[FileMode]struct { + // name identifies the file mode. + name string + + // function is the function that should return true given the + // right FileMode. + function func(m FileMode) bool + }{ + ModeRegular: { + name: "regular", + function: FileMode.IsRegular, + }, + ModeDirectory: { + name: "directory", + function: FileMode.IsDir, + }, + ModeNamedPipe: { + name: "named pipe", + function: FileMode.IsNamedPipe, + }, + ModeCharacterDevice: { + name: "character device", + function: FileMode.IsCharacterDevice, + }, + ModeBlockDevice: { + name: "block device", + function: FileMode.IsBlockDevice, + }, + ModeSymlink: { + name: "symlink", + function: FileMode.IsSymlink, + }, + ModeSocket: { + name: "socket", + function: FileMode.IsSocket, + }, + } + for mode, info := range fns { + // Make sure the mode doesn't identify as anything but itself. + for testMode, testfns := range fns { + if mode != testMode && testfns.function(mode) { + t.Errorf("Mode %s returned true when asked if it was mode %s", info.name, testfns.name) + } + } + + // Make sure mode identifies as itself. + if !info.function(mode) { + t.Errorf("Mode %s returned false when asked if it was itself", info.name) + } + } +} + +func TestFileModeToQID(t *testing.T) { + for _, test := range []struct { + // name identifies the test. + name string + + // mode is the FileMode we start out with. + mode FileMode + + // want is the corresponding QIDType we expect. + want QIDType + }{ + { + name: "Directories are of type directory", + mode: ModeDirectory, + want: TypeDir, + }, + { + name: "Sockets are append-only files", + mode: ModeSocket, + want: TypeAppendOnly, + }, + { + name: "Named pipes are append-only files", + mode: ModeNamedPipe, + want: TypeAppendOnly, + }, + { + name: "Character devices are append-only files", + mode: ModeCharacterDevice, + want: TypeAppendOnly, + }, + { + name: "Symlinks are of type symlink", + mode: ModeSymlink, + want: TypeSymlink, + }, + { + name: "Regular files are of type regular", + mode: ModeRegular, + want: TypeRegular, + }, + { + name: "Block devices are regular files", + mode: ModeBlockDevice, + want: TypeRegular, + }, + } { + if qidType := test.mode.QIDType(); qidType != test.want { + t.Errorf("ModeToQID test %s failed: got %o, wanted %o", test.name, qidType, test.want) + } + } +} + +func TestP9ModeConverters(t *testing.T) { + for _, m := range []FileMode{ + ModeRegular, + ModeDirectory, + ModeCharacterDevice, + ModeBlockDevice, + ModeSocket, + ModeSymlink, + ModeNamedPipe, + } { + if mb := ModeFromOS(m.OSMode()); mb != m { + t.Errorf("Converting %o to OS.FileMode gives %o and is converted back as %o", m, m.OSMode(), mb) + } + } +} + +func TestOSModeConverters(t *testing.T) { + // Modes that can be converted back and forth. + for _, m := range []os.FileMode{ + 0, // Regular file. + os.ModeDir, + os.ModeCharDevice | os.ModeDevice, + os.ModeDevice, + os.ModeSocket, + os.ModeSymlink, + os.ModeNamedPipe, + } { + if mb := ModeFromOS(m).OSMode(); mb != m { + t.Errorf("Converting %o to p9.FileMode gives %o and is converted back as %o", m, ModeFromOS(m), mb) + } + } + + // Modes that will be converted to a regular file since p9 cannot + // express these. + for _, m := range []os.FileMode{ + os.ModeAppend, + os.ModeExclusive, + os.ModeTemporary, + } { + if p9Mode := ModeFromOS(m); p9Mode != ModeRegular { + t.Errorf("Converting %o to p9.FileMode should have given ModeRegular, but yielded %o", m, p9Mode) + } + } +} + +func TestAttrMaskContains(t *testing.T) { + req := AttrMask{Mode: true, Size: true} + have := AttrMask{} + if have.Contains(req) { + t.Fatalf("AttrMask %v should not be a superset of %v", have, req) + } + have.Mode = true + if have.Contains(req) { + t.Fatalf("AttrMask %v should not be a superset of %v", have, req) + } + have.Size = true + have.MTime = true + if !have.Contains(req) { + t.Fatalf("AttrMask %v should be a superset of %v", have, req) + } +} diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD new file mode 100644 index 000000000..7ca67cb19 --- /dev/null +++ b/pkg/p9/p9test/BUILD @@ -0,0 +1,88 @@ +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") + +package(licenses = ["notice"]) + +alias( + name = "mockgen", + actual = "@com_github_golang_mock//mockgen:mockgen", +) + +MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9" + +# mockgen_reflect is a source file that contains mock generation code that +# imports the p9 package and generates a specification via reflection. The +# usual generation path must be split into two distinct parts because the full +# source tree is not available to all build targets. Only declared depencies +# are available (and even then, not the Go source files). +genrule( + name = "mockgen_reflect", + testonly = 1, + outs = ["mockgen_reflect.go"], + cmd = ( + "$(location :mockgen) " + + "-package p9test " + + "-prog_only " + MOCK_SRC_PACKAGE + " " + + "Attacher,File > $@" + ), + tools = [":mockgen"], +) + +# mockgen_exec is the binary that includes the above reflection generator. +# Running this binary will emit an encoded version of the p9 Attacher and File +# structures. This is consumed by the mocks genrule, below. +go_binary( + name = "mockgen_exec", + testonly = 1, + srcs = ["mockgen_reflect.go"], + deps = [ + "//pkg/p9", + "@com_github_golang_mock//mockgen/model:go_default_library", + ], +) + +# mocks consumes the encoded output above, and generates the full source for a +# set of mocks. These are included directly in the p9test library. +genrule( + name = "mocks", + testonly = 1, + outs = ["mocks.go"], + cmd = ( + "$(location :mockgen) " + + "-package p9test " + + "-exec_only $(location :mockgen_exec) " + MOCK_SRC_PACKAGE + " File > $@" + ), + tools = [ + ":mockgen", + ":mockgen_exec", + ], +) + +go_library( + name = "p9test", + srcs = [ + "mocks.go", + "p9test.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/fd", + "//pkg/log", + "//pkg/p9", + "//pkg/sync", + "//pkg/unet", + "@com_github_golang_mock//gomock:go_default_library", + ], +) + +go_test( + name = "client_test", + size = "medium", + srcs = ["client_test.go"], + library = ":p9test", + deps = [ + "//pkg/fd", + "//pkg/p9", + "//pkg/sync", + "@com_github_golang_mock//gomock:go_default_library", + ], +) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go new file mode 100644 index 000000000..6e7bb3db2 --- /dev/null +++ b/pkg/p9/p9test/client_test.go @@ -0,0 +1,2242 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9test + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "os" + "reflect" + "strings" + "syscall" + "testing" + "time" + + "github.com/golang/mock/gomock" + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" +) + +func TestPanic(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + // Create a new root. + d := h.NewDirectory(nil)(nil) + defer d.Close() // Needed manually. + h.Attacher.EXPECT().Attach().Return(d, nil).Do(func() { + // Panic here, and ensure that we get back EFAULT. + panic("handler") + }) + + // Attach to the client. + if _, err := c.Attach("/"); err != syscall.EFAULT { + t.Fatalf("got attach err %v, want EFAULT", err) + } +} + +func TestAttachNoLeak(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + // Create a new root. + d := h.NewDirectory(nil)(nil) + h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) + + // Attach to the client. + f, err := c.Attach("/") + if err != nil { + t.Fatalf("got attach err %v, want nil", err) + } + + // Don't close the file. This should be closed automatically when the + // client disconnects. The mock asserts that everything is closed + // exactly once. This statement just removes the unused variable error. + _ = f +} + +func TestBadAttach(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + // Return an error on attach. + h.Attacher.EXPECT().Attach().Return(nil, syscall.EINVAL).Times(1) + + // Attach to the client. + if _, err := c.Attach("/"); err != syscall.EINVAL { + t.Fatalf("got attach err %v, want syscall.EINVAL", err) + } +} + +func TestWalkAttach(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + // Create a new root. + d := h.NewDirectory(map[string]Generator{ + "a": h.NewDirectory(map[string]Generator{ + "b": h.NewFile(), + }), + })(nil) + h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) + + // Attach to the client as a non-root, and ensure that the walk above + // occurs as expected. We should get back b, and all references should + // be dropped when the file is closed. + f, err := c.Attach("/a/b") + if err != nil { + t.Fatalf("got attach err %v, want nil", err) + } + defer f.Close() + + // Check that's a regular file. + if _, _, attr, err := f.GetAttr(p9.AttrMaskAll()); err != nil { + t.Errorf("got err %v, want nil", err) + } else if !attr.Mode.IsRegular() { + t.Errorf("got mode %v, want regular file", err) + } +} + +// newTypeMap returns a new type map dictionary. +func newTypeMap(h *Harness) map[string]Generator { + return map[string]Generator{ + "directory": h.NewDirectory(map[string]Generator{}), + "file": h.NewFile(), + "symlink": h.NewSymlink(), + "block-device": h.NewBlockDevice(), + "character-device": h.NewCharacterDevice(), + "named-pipe": h.NewNamedPipe(), + "socket": h.NewSocket(), + } +} + +// newRoot returns a new root filesystem. +// +// This is set up in a deterministic way for testing most operations. +// +// The represented file system looks like: +// - file +// - symlink +// - directory +// ... +// + one +// - file +// - symlink +// - directory +// ... +// + two +// - file +// - symlink +// - directory +// ... +// + three +// - file +// - symlink +// - directory +// ... +func newRoot(h *Harness, c *p9.Client) (*Mock, p9.File) { + root := newTypeMap(h) + one := newTypeMap(h) + two := newTypeMap(h) + three := newTypeMap(h) + one["two"] = h.NewDirectory(two) // Will be nested in one. + root["one"] = h.NewDirectory(one) // Top level. + root["three"] = h.NewDirectory(three) // Alternate top-level. + + // Create a new root. + rootBackend := h.NewDirectory(root)(nil) + h.Attacher.EXPECT().Attach().Return(rootBackend, nil) + + // Attach to the client. + r, err := c.Attach("/") + if err != nil { + h.t.Fatalf("got attach err %v, want nil", err) + } + + return rootBackend, r +} + +func allInvalidNames(from string) []string { + return []string{ + from + "/other", + from + "/..", + from + "/.", + from + "/", + "other/" + from, + "/" + from, + "./" + from, + "../" + from, + ".", + "..", + "/", + "", + } +} + +func TestWalkInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Run relevant tests. + for name := range newTypeMap(h) { + // These are all the various ways that one might attempt to + // construct compound paths. They should all be rejected, as + // any compound that contains a / is not allowed, as well as + // the singular paths of '.' and '..'. + if _, _, err := root.Walk([]string{".", name}); err != syscall.EINVAL { + t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) + } + if _, _, err := root.Walk([]string{"..", name}); err != syscall.EINVAL { + t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) + } + if _, _, err := root.Walk([]string{name, "."}); err != syscall.EINVAL { + t.Errorf("Walk through %s . wanted EINVAL, got %v", name, err) + } + if _, _, err := root.Walk([]string{name, ".."}); err != syscall.EINVAL { + t.Errorf("Walk through %s .. wanted EINVAL, got %v", name, err) + } + for _, invalidName := range allInvalidNames(name) { + if _, _, err := root.Walk([]string{invalidName}); err != syscall.EINVAL { + t.Errorf("Walk through %s wanted EINVAL, got %v", invalidName, err) + } + } + wantErr := syscall.EINVAL + if name == "directory" { + // We can attempt a walk through a directory. However, + // we should never see a file named "other", so we + // expect this to return ENOENT. + wantErr = syscall.ENOENT + } + if _, _, err := root.Walk([]string{name, "other"}); err != wantErr { + t.Errorf("Walk through %s/other wanted %v, got %v", name, wantErr, err) + } + + // Do a successful walk. + _, f, err := root.Walk([]string{name}) + if err != nil { + t.Errorf("Walk to %s wanted nil, got %v", name, err) + } + defer f.Close() + local := h.Pop(f) + + // Check that the file matches. + _, localMask, localAttr, localErr := local.GetAttr(p9.AttrMaskAll()) + if _, mask, attr, err := f.GetAttr(p9.AttrMaskAll()); mask != localMask || attr != localAttr || err != localErr { + t.Errorf("GetAttr got (%v, %v, %v), wanted (%v, %v, %v)", + mask, attr, err, localMask, localAttr, localErr) + } + + // Ensure we can't walk backwards. + if _, _, err := f.Walk([]string{"."}); err != syscall.EINVAL { + t.Errorf("Walk through %s/. wanted EINVAL, got %v", name, err) + } + if _, _, err := f.Walk([]string{".."}); err != syscall.EINVAL { + t.Errorf("Walk through %s/.. wanted EINVAL, got %v", name, err) + } + } +} + +// fileGenerator is a function to generate files via walk or create. +// +// Examples are: +// - walkHelper +// - walkAndOpenHelper +// - createHelper +type fileGenerator func(*Harness, string, p9.File) (*Mock, *Mock, p9.File) + +// walkHelper walks to the given file. +// +// The backends of the parent and walked file are returned, as well as the +// walked client file. +func walkHelper(h *Harness, name string, dir p9.File) (parentBackend *Mock, walkedBackend *Mock, walked p9.File) { + _, parent, err := dir.Walk(nil) + if err != nil { + h.t.Fatalf("Walk(nil) got err %v, want nil", err) + } + defer parent.Close() + parentBackend = h.Pop(parent) + + _, walked, err = parent.Walk([]string{name}) + if err != nil { + h.t.Fatalf("Walk(%s) got err %v, want nil", name, err) + } + walkedBackend = h.Pop(walked) + + return parentBackend, walkedBackend, walked +} + +// walkAndOpenHelper additionally opens the walked file, if possible. +func walkAndOpenHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { + parentBackend, walkedBackend, walked := walkHelper(h, name, dir) + if p9.CanOpen(walkedBackend.Attr.Mode) { + // Open for all file types that we can. We stick to a read-only + // open here because directories may not be opened otherwise. + walkedBackend.EXPECT().Open(p9.ReadOnly).Times(1) + if _, _, _, err := walked.Open(p9.ReadOnly); err != nil { + h.t.Errorf("got open err %v, want nil", err) + } + } else { + // ... or assert an error for others. + if _, _, _, err := walked.Open(p9.ReadOnly); err != syscall.EINVAL { + h.t.Errorf("got open err %v, want EINVAL", err) + } + } + return parentBackend, walkedBackend, walked +} + +// createHelper creates the given file and returns the parent directory, +// created file and client file, which must be closed when done. +func createHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { + // Clone the directory first, since Create replaces the existing file. + // We change the type after calling create. + _, dirThenFile, err := dir.Walk(nil) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + + // Create a new server-side file. On the server-side, the a new file is + // returned from a create call. The client will reuse the same file, + // but we still expect the normal chain of closes. This complicates + // things a bit because the "parent" will always chain to the cloned + // dir above. + dirBackend := h.Pop(dirThenFile) // New backend directory. + newFile := h.NewFile()(dirBackend) // New file with backend parent. + dirBackend.EXPECT().Create(name, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, newFile, newFile.QID, uint32(0), nil) + + // Create via the client. + _, dirThenFile, _, _, err = dirThenFile.Create(name, p9.ReadOnly, 0, 0, 0) + if err != nil { + h.t.Fatalf("got create err %v, want nil", err) + } + + // Ensure subsequent walks succeed. + dirBackend.AddChild(name, h.NewFile()) + return dirBackend, newFile, dirThenFile +} + +// deprecatedRemover allows us to access the deprecated Remove operation within +// the p9.File client object. +type deprecatedRemover interface { + Remove() error +} + +// checkDeleted asserts that relevant methods fail for an unlinked file. +// +// This function will close the file at the end. +func checkDeleted(h *Harness, file p9.File) { + defer file.Close() // See doc. + + if _, _, _, err := file.Open(p9.ReadOnly); err != syscall.EINVAL { + h.t.Errorf("open while deleted, got %v, want EINVAL", err) + } + if _, _, _, _, err := file.Create("created", p9.ReadOnly, 0, 0, 0); err != syscall.EINVAL { + h.t.Errorf("create while deleted, got %v, want EINVAL", err) + } + if _, err := file.Symlink("old", "new", 0, 0); err != syscall.EINVAL { + h.t.Errorf("symlink while deleted, got %v, want EINVAL", err) + } + // N.B. This link is technically invalid, but if a call to link is + // actually made in the backend then the mock will panic. + if err := file.Link(file, "new"); err != syscall.EINVAL { + h.t.Errorf("link while deleted, got %v, want EINVAL", err) + } + if err := file.RenameAt("src", file, "dst"); err != syscall.EINVAL { + h.t.Errorf("renameAt while deleted, got %v, want EINVAL", err) + } + if err := file.UnlinkAt("file", 0); err != syscall.EINVAL { + h.t.Errorf("unlinkAt while deleted, got %v, want EINVAL", err) + } + if err := file.Rename(file, "dst"); err != syscall.EINVAL { + h.t.Errorf("rename while deleted, got %v, want EINVAL", err) + } + if _, err := file.Readlink(); err != syscall.EINVAL { + h.t.Errorf("readlink while deleted, got %v, want EINVAL", err) + } + if _, err := file.Mkdir("dir", p9.ModeDirectory, 0, 0); err != syscall.EINVAL { + h.t.Errorf("mkdir while deleted, got %v, want EINVAL", err) + } + if _, err := file.Mknod("dir", p9.ModeDirectory, 0, 0, 0, 0); err != syscall.EINVAL { + h.t.Errorf("mknod while deleted, got %v, want EINVAL", err) + } + if _, err := file.Readdir(0, 1); err != syscall.EINVAL { + h.t.Errorf("readdir while deleted, got %v, want EINVAL", err) + } + if _, err := file.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL { + h.t.Errorf("connect while deleted, got %v, want EINVAL", err) + } + + // The remove method is technically deprecated, but we want to ensure + // that it still checks for deleted appropriately. We must first clone + // the file because remove is equivalent to close. + _, newFile, err := file.Walk(nil) + if err == syscall.EBUSY { + // We can't walk from here because this reference is open + // already. Okay, we will also have unopened cases through + // TestUnlink, just skip the remove operation for now. + return + } else if err != nil { + h.t.Fatalf("clone failed, got %v, want nil", err) + } + if err := newFile.(deprecatedRemover).Remove(); err != syscall.EINVAL { + h.t.Errorf("remove while deleted, got %v, want EINVAL", err) + } +} + +// deleter is a function to remove a file. +type deleter func(parent p9.File, name string) error + +// unlinkAt is a deleter. +func unlinkAt(parent p9.File, name string) error { + // Call unlink. Note that a filesystem may normally impose additional + // constaints on unlinkat success, such as ensuring that a directory is + // empty, requiring AT_REMOVEDIR in flags to remove a directory, etc. + // None of that is required internally (entire trees can be marked + // deleted when this operation succeeds), so the mock will succeed. + return parent.UnlinkAt(name, 0) +} + +// remove is a deleter. +func remove(parent p9.File, name string) error { + // See notes above re: remove. + _, newFile, err := parent.Walk([]string{name}) + if err != nil { + // Should not be expected. + return err + } + + // Do the actual remove. + if err := newFile.(deprecatedRemover).Remove(); err != nil { + return err + } + + // Ensure that the remove closed the file. + if err := newFile.(deprecatedRemover).Remove(); err != syscall.EBADF { + return syscall.EBADF // Propagate this code. + } + + return nil +} + +// unlinkHelper unlinks the noted path, and ensures that all relevant +// operations on that path, acquired from multiple paths, start failing. +func unlinkHelper(h *Harness, root p9.File, targetNames []string, targetGen fileGenerator, deleteFn deleter) { + // name is the file to be unlinked. + name := targetNames[len(targetNames)-1] + + // Walk to the directory containing the target. + _, parent, err := root.Walk(targetNames[:len(targetNames)-1]) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer parent.Close() + parentBackend := h.Pop(parent) + + // Walk to or generate the target file. + _, _, target := targetGen(h, name, parent) + defer checkDeleted(h, target) + + // Walk to a second reference. + _, second, err := parent.Walk([]string{name}) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer checkDeleted(h, second) + + // Walk to a third reference, from the start. + _, third, err := root.Walk(targetNames) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer checkDeleted(h, third) + + // This will be translated in the backend to an unlinkat. + parentBackend.EXPECT().UnlinkAt(name, uint32(0)).Return(nil) + + // Actually perform the deletion. + if err := deleteFn(parent, name); err != nil { + h.t.Fatalf("got delete err %v, want nil", err) + } +} + +func unlinkTest(t *testing.T, targetNames []string, targetGen fileGenerator) { + t.Run(fmt.Sprintf("unlinkAt(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + unlinkHelper(h, root, targetNames, targetGen, unlinkAt) + }) + t.Run(fmt.Sprintf("remove(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + unlinkHelper(h, root, targetNames, targetGen, remove) + }) +} + +func TestUnlink(t *testing.T) { + // Unlink all files. + for name := range newTypeMap(nil) { + unlinkTest(t, []string{name}, walkHelper) + unlinkTest(t, []string{name}, walkAndOpenHelper) + unlinkTest(t, []string{"one", name}, walkHelper) + unlinkTest(t, []string{"one", name}, walkAndOpenHelper) + unlinkTest(t, []string{"one", "two", name}, walkHelper) + unlinkTest(t, []string{"one", "two", name}, walkAndOpenHelper) + } + + // Unlink a directory. + unlinkTest(t, []string{"one"}, walkHelper) + unlinkTest(t, []string{"one"}, walkAndOpenHelper) + unlinkTest(t, []string{"one", "two"}, walkHelper) + unlinkTest(t, []string{"one", "two"}, walkAndOpenHelper) + + // Unlink created files. + unlinkTest(t, []string{"created"}, createHelper) + unlinkTest(t, []string{"one", "created"}, createHelper) + unlinkTest(t, []string{"one", "two", "created"}, createHelper) +} + +func TestUnlinkAtInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + if err := root.UnlinkAt(invalidName, 0); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +// expectRenamed asserts an ordered sequence of rename calls, based on all the +// elements in elements being the source, and the first element therein +// changing to dstName, parented at dstParent. +func expectRenamed(file *Mock, elements []string, dstParent *Mock, dstName string) *gomock.Call { + if len(elements) > 0 { + // Recurse to the parent, if necessary. + call := expectRenamed(file.parent, elements[:len(elements)-1], dstParent, dstName) + + // Recursive case: this element is unchanged, but should have + // it's hook called after the parent. + return file.EXPECT().Renamed(file.parent, elements[len(elements)-1]).Do(func(p p9.File, _ string) { + file.parent = p.(*Mock) + }).After(call) + } + + // Base case: this is the changed element. + return file.EXPECT().Renamed(dstParent, dstName).Do(func(p p9.File, name string) { + file.parent = p.(*Mock) + }) +} + +// renamer is a rename function. +type renamer func(h *Harness, srcParent, dstParent p9.File, origName, newName string, selfRename bool) error + +// renameAt is a renamer. +func renameAt(_ *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { + return srcParent.RenameAt(srcName, dstParent, dstName) +} + +// rename is a renamer. +func rename(h *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { + _, f, err := srcParent.Walk([]string{srcName}) + if err != nil { + return err + } + defer f.Close() + if !selfRename { + backend := h.Pop(f) + backend.EXPECT().Renamed(gomock.Any(), dstName).Do(func(p p9.File, name string) { + backend.parent = p.(*Mock) // Required for close ordering. + }) + } + return f.Rename(dstParent, dstName) +} + +// renameHelper executes a rename, and asserts that all relevant elements +// receive expected notifications. If overwriting a file, this includes +// ensuring that the target has been appropriately marked as unlinked. +func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string, target fileGenerator, renameFn renamer) { + // Walk to the directory containing the target. + srcQID, targetParent, err := root.Walk(srcNames[:len(srcNames)-1]) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer targetParent.Close() + targetParentBackend := h.Pop(targetParent) + + // Walk to or generate the target file. + _, targetBackend, src := target(h, srcNames[len(srcNames)-1], targetParent) + defer src.Close() + + // Walk to a second reference. + _, second, err := targetParent.Walk([]string{srcNames[len(srcNames)-1]}) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer second.Close() + secondBackend := h.Pop(second) + + // Walk to a third reference, from the start. + _, third, err := root.Walk(srcNames) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer third.Close() + thirdBackend := h.Pop(third) + + // Find the common suffix to identify the rename parent. + var ( + renameDestPath []string + renameSrcPath []string + selfRename bool + ) + for i := 1; i <= len(srcNames) && i <= len(dstNames); i++ { + if srcNames[len(srcNames)-i] != dstNames[len(dstNames)-i] { + // Take the full prefix of dstNames up until this + // point, including the first mismatched name. The + // first mismatch must be the renamed entry. + renameDestPath = dstNames[:len(dstNames)-i+1] + renameSrcPath = srcNames[:len(srcNames)-i+1] + + // Does the renameDestPath fully contain the + // renameSrcPath here? If yes, then this is a mismatch. + // We can't rename the src to some subpath of itself. + if len(renameDestPath) > len(renameSrcPath) && + reflect.DeepEqual(renameDestPath[:len(renameSrcPath)], renameSrcPath) { + renameDestPath = nil + renameSrcPath = nil + continue + } + break + } + } + if len(renameSrcPath) == 0 || len(renameDestPath) == 0 { + // This must be a rename to self, or a tricky look-alike. This + // happens iff we fail to find a suitable divergence in the two + // paths. It's a true self move if the path length is the same. + renameDestPath = dstNames + renameSrcPath = srcNames + selfRename = len(srcNames) == len(dstNames) + } + + // Walk to the source parent. + _, srcParent, err := root.Walk(renameSrcPath[:len(renameSrcPath)-1]) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer srcParent.Close() + srcParentBackend := h.Pop(srcParent) + + // Walk to the destination parent. + _, dstParent, err := root.Walk(renameDestPath[:len(renameDestPath)-1]) + if err != nil { + h.t.Fatalf("got walk err %v, want nil", err) + } + defer dstParent.Close() + dstParentBackend := h.Pop(dstParent) + + // expectedErr is the result of the rename operation. + var expectedErr error + + // Walk to the target file, if one exists. + dstQID, dst, err := root.Walk(renameDestPath) + if err == nil { + if !selfRename && srcQID[0].Type == dstQID[0].Type { + // If there is a destination file, and is it of the + // same type as the source file, then we expect the + // rename to succeed. We expect the destination file to + // be deleted, so we run a deletion test on it in this + // case. + defer checkDeleted(h, dst) + } else { + if !selfRename { + // If the type is different than the + // destination, then we expect the rename to + // fail. We expect ensure that this is + // returned. + expectedErr = syscall.EINVAL + } else { + // This is the file being renamed to itself. + // This is technically allowed and a no-op, but + // all the triggers will fire. + } + dst.Close() + } + } + dstName := renameDestPath[len(renameDestPath)-1] // Renamed element. + srcName := renameSrcPath[len(renameSrcPath)-1] // Renamed element. + if expectedErr == nil && !selfRename { + // Expect all to be renamed appropriately. Note that if this is + // a final file being renamed, then we expect the file to be + // called with the new parent. If not, then we expect the + // rename hook to be called, but the parent will remain + // unchanged. + elements := srcNames[len(renameSrcPath):] + expectRenamed(targetBackend, elements, dstParentBackend, dstName) + expectRenamed(secondBackend, elements, dstParentBackend, dstName) + expectRenamed(thirdBackend, elements, dstParentBackend, dstName) + + // The target parent has also been opened, and may be moved + // directly or indirectly. + if len(elements) > 1 { + expectRenamed(targetParentBackend, elements[:len(elements)-1], dstParentBackend, dstName) + } + } + + // Expect the rename if it's not the same file. Note that like unlink, + // renames are always translated to the at variant in the backend. + if !selfRename { + srcParentBackend.EXPECT().RenameAt(srcName, dstParentBackend, dstName).Return(expectedErr) + } + + // Perform the actual rename; everything has been lined up. + if err := renameFn(h, srcParent, dstParent, srcName, dstName, selfRename); err != expectedErr { + h.t.Fatalf("got rename err %v, want %v", err, expectedErr) + } +} + +func renameTest(t *testing.T, srcNames []string, dstNames []string, target fileGenerator) { + t.Run(fmt.Sprintf("renameAt(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + renameHelper(h, root, srcNames, dstNames, target, renameAt) + }) + t.Run(fmt.Sprintf("rename(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + renameHelper(h, root, srcNames, dstNames, target, rename) + }) +} + +func TestRename(t *testing.T) { + // In-directory rename, simple case. + for name := range newTypeMap(nil) { + // Within the root. + renameTest(t, []string{name}, []string{"renamed"}, walkHelper) + renameTest(t, []string{name}, []string{"renamed"}, walkAndOpenHelper) + + // Within a subdirectory. + renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkHelper) + renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkAndOpenHelper) + } + + // ... with created files. + renameTest(t, []string{"created"}, []string{"renamed"}, createHelper) + renameTest(t, []string{"one", "created"}, []string{"one", "renamed"}, createHelper) + + // Across directories. + for name := range newTypeMap(nil) { + // Down one level. + renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkHelper) + renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkAndOpenHelper) + + // Up one level. + renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkHelper) + renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkAndOpenHelper) + + // Across at the same level. + renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkHelper) + renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkAndOpenHelper) + } + + // ... with created files. + renameTest(t, []string{"one", "created"}, []string{"one", "two", "renamed"}, createHelper) + renameTest(t, []string{"one", "two", "created"}, []string{"one", "renamed"}, createHelper) + renameTest(t, []string{"one", "created"}, []string{"three", "renamed"}, createHelper) + + // Renaming parents. + for name := range newTypeMap(nil) { + // Rename a parent. + renameTest(t, []string{"one", name}, []string{"renamed", name}, walkHelper) + renameTest(t, []string{"one", name}, []string{"renamed", name}, walkAndOpenHelper) + + // Rename a super parent. + renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkHelper) + renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkAndOpenHelper) + } + + // ... with created files. + renameTest(t, []string{"one", "created"}, []string{"renamed", "created"}, createHelper) + renameTest(t, []string{"one", "two", "created"}, []string{"renamed", "created"}, createHelper) + + // Over existing files, including itself. + for name := range newTypeMap(nil) { + for other := range newTypeMap(nil) { + // Overwrite the noted file (may be itself). + renameTest(t, []string{"one", name}, []string{"one", other}, walkHelper) + renameTest(t, []string{"one", name}, []string{"one", other}, walkAndOpenHelper) + + // Overwrite other files in another directory. + renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkHelper) + renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkAndOpenHelper) + } + + // Overwrite by moving the parent. + renameTest(t, []string{"three", name}, []string{"one", name}, walkHelper) + renameTest(t, []string{"three", name}, []string{"one", name}, walkAndOpenHelper) + + // Create over the types. + renameTest(t, []string{"one", "created"}, []string{"one", name}, createHelper) + renameTest(t, []string{"one", "created"}, []string{"one", "two", name}, createHelper) + renameTest(t, []string{"three", "created"}, []string{"one", name}, createHelper) + } +} + +func TestRenameInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + if err := root.Rename(root, invalidName); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +func TestRenameAtInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + if err := root.RenameAt(invalidName, root, "okay"); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + if err := root.RenameAt("okay", root, invalidName); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +// TestRenameSecondOrder tests that indirect rename targets continue to receive +// Renamed calls after a rename of its renamed parent. i.e., +// +// 1. Create /one/file +// 2. Create /directory +// 3. Rename /one -> /directory/one +// 4. Rename /directory -> /three/foo +// 5. file from (1) should still receive Renamed. +// +// This is a regression test for b/135219260. +func TestRenameSecondOrder(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + rootBackend, root := newRoot(h, c) + defer root.Close() + + // Walk to /one. + _, oneBackend, oneFile := walkHelper(h, "one", root) + defer oneFile.Close() + + // Walk to and generate /one/file. + // + // walkHelper re-walks to oneFile, so we need the second backend, + // which will also receive Renamed calls. + oneSecondBackend, fileBackend, fileFile := walkHelper(h, "file", oneFile) + defer fileFile.Close() + + // Walk to and generate /directory. + _, directoryBackend, directoryFile := walkHelper(h, "directory", root) + defer directoryFile.Close() + + // Rename /one to /directory/one. + rootBackend.EXPECT().RenameAt("one", directoryBackend, "one").Return(nil) + expectRenamed(oneBackend, []string{}, directoryBackend, "one") + expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one") + expectRenamed(fileBackend, []string{}, oneBackend, "file") + if err := renameAt(h, root, directoryFile, "one", "one", false); err != nil { + h.t.Fatalf("got rename err %v, want nil", err) + } + + // Walk to /three. + _, threeBackend, threeFile := walkHelper(h, "three", root) + defer threeFile.Close() + + // Rename /directory to /three/foo. + rootBackend.EXPECT().RenameAt("directory", threeBackend, "foo").Return(nil) + expectRenamed(directoryBackend, []string{}, threeBackend, "foo") + expectRenamed(oneBackend, []string{}, directoryBackend, "one") + expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one") + expectRenamed(fileBackend, []string{}, oneBackend, "file") + if err := renameAt(h, root, threeFile, "directory", "foo", false); err != nil { + h.t.Fatalf("got rename err %v, want nil", err) + } +} + +func TestReadlink(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Walk to the file normally. + _, f, err := root.Walk([]string{name}) + if err != nil { + t.Fatalf("walk failed: got %v, wanted nil", err) + } + defer f.Close() + backend := h.Pop(f) + + const symlinkTarget = "symlink-target" + + if backend.Attr.Mode.IsSymlink() { + // This should only go through on symlinks. + backend.EXPECT().Readlink().Return(symlinkTarget, nil) + } + + // Attempt a Readlink operation. + target, err := f.Readlink() + if err != nil && err != syscall.EINVAL { + t.Errorf("readlink got %v, wanted EINVAL", err) + } else if err == nil && target != symlinkTarget { + t.Errorf("readlink got %v, wanted %v", target, symlinkTarget) + } + }) + } +} + +// fdTest is a wrapper around operations that may send file descriptors. This +// asserts that the file descriptors are working as intended. +func fdTest(t *testing.T, sendFn func(*fd.FD) *fd.FD) { + // Create a pipe that we can read from. + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("unable to create pipe: %v", err) + } + defer r.Close() + defer w.Close() + + // Attempt to send the write end. + wFD, err := fd.NewFromFile(w) + if err != nil { + t.Fatalf("unable to convert file: %v", err) + } + defer wFD.Close() // This is a copy. + + // Send wFD and receive newFD. + newFD := sendFn(wFD) + defer newFD.Close() + + // Attempt to write. + const message = "hello" + if _, err := newFD.Write([]byte(message)); err != nil { + t.Fatalf("write got %v, wanted nil", err) + } + + // Should see the message on our end. + buffer := []byte(message) + if _, err := io.ReadFull(r, buffer); err != nil { + t.Fatalf("read got %v, wanted nil", err) + } + if string(buffer) != message { + t.Errorf("got message %v, wanted %v", string(buffer), message) + } +} + +func TestConnect(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Walk to the file normally. + _, backend, f := walkHelper(h, name, root) + defer f.Close() + + // Catch all the non-socket cases. + if !backend.Attr.Mode.IsSocket() { + // This has been set up to fail if Connect is called. + if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL { + t.Errorf("connect got %v, wanted EINVAL", err) + } + return + } + + // Ensure the fd exchange works. + fdTest(t, func(send *fd.FD) *fd.FD { + backend.EXPECT().Connect(p9.ConnectFlags(0)).Return(send, nil) + recv, err := backend.Connect(p9.ConnectFlags(0)) + if err != nil { + t.Fatalf("connect got %v, wanted nil", err) + } + return recv + }) + }) + } +} + +func TestReaddir(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Walk to the file normally. + _, backend, f := walkHelper(h, name, root) + defer f.Close() + + // Catch all the non-directory cases. + if !backend.Attr.Mode.IsDir() { + // This has also been set up to fail if Readdir is called. + if _, err := f.Readdir(0, 1); err != syscall.EINVAL { + t.Errorf("readdir got %v, wanted EINVAL", err) + } + return + } + + // Ensure that readdir works for directories. + if _, err := f.Readdir(0, 1); err != syscall.EINVAL { + t.Errorf("readdir got %v, wanted EINVAL", err) + } + if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) + } + if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) + } + backend.EXPECT().Open(p9.ReadOnly).Times(1) + if _, _, _, err := f.Open(p9.ReadOnly); err != nil { + t.Errorf("readdir got %v, wanted nil", err) + } + backend.EXPECT().Readdir(uint64(0), uint32(1)).Times(1) + if _, err := f.Readdir(0, 1); err != nil { + t.Errorf("readdir got %v, wanted nil", err) + } + }) + } +} + +func TestOpen(t *testing.T) { + type openTest struct { + name string + flags p9.OpenFlags + err error + match func(p9.FileMode) bool + } + + cases := []openTest{ + { + name: "not-openable-read-only", + flags: p9.ReadOnly, + err: syscall.EINVAL, + match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, + }, + { + name: "not-openable-write-only", + flags: p9.WriteOnly, + err: syscall.EINVAL, + match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, + }, + { + name: "not-openable-read-write", + flags: p9.ReadWrite, + err: syscall.EINVAL, + match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, + }, + { + name: "directory-read-only", + flags: p9.ReadOnly, + err: nil, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "directory-read-write", + flags: p9.ReadWrite, + err: syscall.EISDIR, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "directory-write-only", + flags: p9.WriteOnly, + err: syscall.EISDIR, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "read-only", + flags: p9.ReadOnly, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, + }, + { + name: "write-only", + flags: p9.WriteOnly, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "read-write", + flags: p9.ReadWrite, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "directory-read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: syscall.EISDIR, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "write-only-truncate", + flags: p9.WriteOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "read-write-truncate", + flags: p9.ReadWrite | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + } + + // Open(flags OpenFlags) (*fd.FD, QID, uint32, error) + // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice + // - returning a file works as expected + for name := range newTypeMap(nil) { + for _, tc := range cases { + t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Walk to the file normally. + _, backend, f := walkHelper(h, name, root) + defer f.Close() + + // Does this match the case? + if !tc.match(backend.Attr.Mode) { + t.SkipNow() + } + + // Ensure open-required operations fail. + if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EINVAL { + t.Errorf("readAt got %v, wanted EINVAL", err) + } + if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EINVAL { + t.Errorf("writeAt got %v, wanted EINVAL", err) + } + if err := f.FSync(); err != syscall.EINVAL { + t.Errorf("fsync got %v, wanted EINVAL", err) + } + if _, err := f.Readdir(0, 1); err != syscall.EINVAL { + t.Errorf("readdir got %v, wanted EINVAL", err) + } + + // Attempt the given open. + if tc.err != nil { + // We expect an error, just test and return. + if _, _, _, err := f.Open(tc.flags); err != tc.err { + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) + } + return + } + + // Run an FD test, since we expect success. + fdTest(t, func(send *fd.FD) *fd.FD { + backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1) + recv, _, _, err := f.Open(tc.flags) + if err != tc.err { + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) + } + return recv + }) + + // If the open was successful, attempt another one. + if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL { + t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err) + } + + // Ensure that all illegal operations fail. + if _, _, err := f.Walk(nil); err != syscall.EINVAL && err != syscall.EBUSY { + t.Errorf("walk got %v, wanted EINVAL or EBUSY", err) + } + if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EINVAL && err != syscall.EBUSY { + t.Errorf("walkgetattr got %v, wanted EINVAL or EBUSY", err) + } + }) + } + } +} + +func TestClose(t *testing.T) { + type closeTest struct { + name string + closeFn func(backend *Mock, f p9.File) + } + + cases := []closeTest{ + { + name: "close", + closeFn: func(_ *Mock, f p9.File) { + f.Close() + }, + }, + { + name: "remove", + closeFn: func(backend *Mock, f p9.File) { + // Allow the rename call in the parent, automatically translated. + backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) + f.(deprecatedRemover).Remove() + }, + }, + } + + for name := range newTypeMap(nil) { + for _, tc := range cases { + t.Run(fmt.Sprintf("%s(%s)", tc.name, name), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Walk to the file normally. + _, backend, f := walkHelper(h, name, root) + + // Close via the prescribed method. + tc.closeFn(backend, f) + + // Everything should fail with EBADF. + if _, _, err := f.Walk(nil); err != syscall.EBADF { + t.Errorf("walk got %v, wanted EBADF", err) + } + if _, err := f.StatFS(); err != syscall.EBADF { + t.Errorf("statfs got %v, wanted EBADF", err) + } + if _, _, _, err := f.GetAttr(p9.AttrMaskAll()); err != syscall.EBADF { + t.Errorf("getattr got %v, wanted EBADF", err) + } + if err := f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}); err != syscall.EBADF { + t.Errorf("setattrk got %v, wanted EBADF", err) + } + if err := f.Rename(root, "new-name"); err != syscall.EBADF { + t.Errorf("rename got %v, wanted EBADF", err) + } + if err := f.Close(); err != syscall.EBADF { + t.Errorf("close got %v, wanted EBADF", err) + } + if _, _, _, err := f.Open(p9.ReadOnly); err != syscall.EBADF { + t.Errorf("open got %v, wanted EBADF", err) + } + if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EBADF { + t.Errorf("readAt got %v, wanted EBADF", err) + } + if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EBADF { + t.Errorf("writeAt got %v, wanted EBADF", err) + } + if err := f.FSync(); err != syscall.EBADF { + t.Errorf("fsync got %v, wanted EBADF", err) + } + if _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 0, 0); err != syscall.EBADF { + t.Errorf("create got %v, wanted EBADF", err) + } + if _, err := f.Mkdir("new-directory", 0, 0, 0); err != syscall.EBADF { + t.Errorf("mkdir got %v, wanted EBADF", err) + } + if _, err := f.Symlink("old-name", "new-name", 0, 0); err != syscall.EBADF { + t.Errorf("symlink got %v, wanted EBADF", err) + } + if err := f.Link(root, "new-name"); err != syscall.EBADF { + t.Errorf("link got %v, wanted EBADF", err) + } + if _, err := f.Mknod("new-block-device", 0, 0, 0, 0, 0); err != syscall.EBADF { + t.Errorf("mknod got %v, wanted EBADF", err) + } + if err := f.RenameAt("old-name", root, "new-name"); err != syscall.EBADF { + t.Errorf("renameAt got %v, wanted EBADF", err) + } + if err := f.UnlinkAt("name", 0); err != syscall.EBADF { + t.Errorf("unlinkAt got %v, wanted EBADF", err) + } + if _, err := f.Readdir(0, 1); err != syscall.EBADF { + t.Errorf("readdir got %v, wanted EBADF", err) + } + if _, err := f.Readlink(); err != syscall.EBADF { + t.Errorf("readlink got %v, wanted EBADF", err) + } + if err := f.Flush(); err != syscall.EBADF { + t.Errorf("flush got %v, wanted EBADF", err) + } + if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EBADF { + t.Errorf("walkgetattr got %v, wanted EBADF", err) + } + if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EBADF { + t.Errorf("connect got %v, wanted EBADF", err) + } + }) + } + } +} + +// onlyWorksOnOpenThings is a helper test method for operations that should +// only work on files that have been explicitly opened. +func onlyWorksOnOpenThings(h *Harness, t *testing.T, name string, root p9.File, mode p9.OpenFlags, expectedErr error, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { + // Walk to the file normally. + _, backend, f := walkHelper(h, name, root) + defer f.Close() + + // Does it work before opening? + if err := fn(backend, f, false); err != syscall.EINVAL { + t.Errorf("operation got %v, wanted EINVAL", err) + } + + // Is this openable? + if !p9.CanOpen(backend.Attr.Mode) { + return // Nothing to do. + } + + // If this is a directory, we can't handle writing. + if backend.Attr.Mode.IsDir() && (mode == p9.ReadWrite || mode == p9.WriteOnly) { + return // Skip. + } + + // Open the file. + backend.EXPECT().Open(mode) + if _, _, _, err := f.Open(mode); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + + // Attempt the operation. + if err := fn(backend, f, expectedErr == nil); err != expectedErr { + t.Fatalf("operation got %v, wanted %v", err, expectedErr) + } +} + +func TestRead(t *testing.T) { + type readTest struct { + name string + mode p9.OpenFlags + err error + } + + cases := []readTest{ + { + name: "read-only", + mode: p9.ReadOnly, + err: nil, + }, + { + name: "read-write", + mode: p9.ReadWrite, + err: nil, + }, + { + name: "write-only", + mode: p9.WriteOnly, + err: syscall.EPERM, + }, + } + + for name := range newTypeMap(nil) { + for _, tc := range cases { + t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const message = "hello" + + onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if !shouldSucceed { + _, err := f.ReadAt([]byte(message), 0) + return err + } + + // Prepare for the call to readAt in the backend. + backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + copy(p, message) + }).Return(len(message), nil) + + // Make the client call. + p := make([]byte, 2*len(message)) // Double size. + n, err := f.ReadAt(p, 0) + + // Sanity check result. + if err != nil { + return err + } + if n != len(message) { + t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) + } + if !bytes.Equal(p[:n], []byte(message)) { + t.Fatalf("message incorrect, got %v, want %v", p, []byte(message)) + } + return nil // Success. + }) + }) + } + } +} + +func TestWrite(t *testing.T) { + type writeTest struct { + name string + mode p9.OpenFlags + err error + } + + cases := []writeTest{ + { + name: "read-only", + mode: p9.ReadOnly, + err: syscall.EPERM, + }, + { + name: "read-write", + mode: p9.ReadWrite, + err: nil, + }, + { + name: "write-only", + mode: p9.WriteOnly, + err: nil, + }, + } + + for name := range newTypeMap(nil) { + for _, tc := range cases { + t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const message = "hello" + + onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if !shouldSucceed { + _, err := f.WriteAt([]byte(message), 0) + return err + } + + // Prepare for the call to readAt in the backend. + var output []byte // Saved by Do below. + backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + output = p + }).Return(len(message), nil) + + // Make the client call. + n, err := f.WriteAt([]byte(message), 0) + + // Sanity check result. + if err != nil { + return err + } + if n != len(message) { + t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) + } + if !bytes.Equal(output, []byte(message)) { + t.Fatalf("message incorrect, got %v, want %v", output, []byte(message)) + } + return nil // Success. + }) + }) + } + } +} + +func TestFSync(t *testing.T) { + for name := range newTypeMap(nil) { + for _, mode := range []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} { + t.Run(fmt.Sprintf("%s-%s", mode, name), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + onlyWorksOnOpenThings(h, t, name, root, mode, nil, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if shouldSucceed { + backend.EXPECT().FSync().Times(1) + } + return f.FSync() + }) + }) + } + } +} + +func TestFlush(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + _, backend, f := walkHelper(h, name, root) + defer f.Close() + + backend.EXPECT().Flush() + f.Flush() + }) + } +} + +// onlyWorksOnDirectories is a helper test method for operations that should +// only work on unopened directories, such as create, mkdir and symlink. +func onlyWorksOnDirectories(h *Harness, t *testing.T, name string, root p9.File, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { + // Walk to the file normally. + _, backend, f := walkHelper(h, name, root) + defer f.Close() + + // Only directories support mknod. + if !backend.Attr.Mode.IsDir() { + if err := fn(backend, f, false); err != syscall.EINVAL { + t.Errorf("operation got %v, wanted EINVAL", err) + } + return // Nothing else to do. + } + + // Should succeed. + if err := fn(backend, f, true); err != nil { + t.Fatalf("operation got %v, wanted nil", err) + } + + // Open the directory. + backend.EXPECT().Open(p9.ReadOnly).Times(1) + if _, _, _, err := f.Open(p9.ReadOnly); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + + // Should not work again. + if err := fn(backend, f, false); err != syscall.EINVAL { + t.Fatalf("operation got %v, wanted EINVAL", err) + } +} + +func TestCreate(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if !shouldSucceed { + _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 1, 2) + return err + } + + // If the create is going to succeed, then we + // need to create a new backend file, and we + // clone to ensure that we don't close the + // original. + _, newF, err := f.Walk(nil) + if err != nil { + t.Fatalf("clone got %v, wanted nil", err) + } + defer newF.Close() + newBackend := h.Pop(newF) + + // Run a regular FD test to validate that path. + fdTest(t, func(send *fd.FD) *fd.FD { + // Return the send FD on success. + newFile := h.NewFile()(backend) // New file with the parent backend. + newBackend.EXPECT().Create("new-file", p9.ReadWrite, p9.FileMode(0), p9.UID(1), p9.GID(2)).Return(send, newFile, p9.QID{}, uint32(0), nil) + + // Receive the fd back. + recv, _, _, _, err := newF.Create("new-file", p9.ReadWrite, 0, 1, 2) + if err != nil { + t.Fatalf("create got %v, wanted nil", err) + } + return recv + }) + + // The above will fail via normal test flow, so + // we can assume that it passed. + return nil + }) + }) + } +} + +func TestCreateInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + if _, _, _, _, err := root.Create(invalidName, p9.ReadWrite, 0, 0, 0); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +func TestMkdir(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if shouldSucceed { + backend.EXPECT().Mkdir("new-directory", p9.FileMode(0), p9.UID(1), p9.GID(2)) + } + _, err := f.Mkdir("new-directory", 0, 1, 2) + return err + }) + }) + } +} + +func TestMkdirInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + if _, err := root.Mkdir(invalidName, 0, 0, 0); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +func TestSymlink(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if shouldSucceed { + backend.EXPECT().Symlink("old-name", "new-name", p9.UID(1), p9.GID(2)) + } + _, err := f.Symlink("old-name", "new-name", 1, 2) + return err + }) + }) + } +} + +func TestSyminkInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + // We need only test for invalid names in the new name, + // the target can be an arbitrary string and we don't + // need to sanity check it. + if _, err := root.Symlink("old-name", invalidName, 0, 0); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +func TestLink(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if shouldSucceed { + backend.EXPECT().Link(gomock.Any(), "new-link") + } + return f.Link(f, "new-link") + }) + }) + } +} + +func TestLinkInvalid(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + for name := range newTypeMap(nil) { + for _, invalidName := range allInvalidNames(name) { + if err := root.Link(root, invalidName); err != syscall.EINVAL { + t.Errorf("got %v for name %q, want EINVAL", err, invalidName) + } + } + } +} + +func TestMknod(t *testing.T) { + for name := range newTypeMap(nil) { + t.Run(name, func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { + if shouldSucceed { + backend.EXPECT().Mknod("new-block-device", p9.FileMode(0), uint32(1), uint32(2), p9.UID(3), p9.GID(4)).Times(1) + } + _, err := f.Mknod("new-block-device", 0, 1, 2, 3, 4) + return err + }) + }) + } +} + +// concurrentFn is a specification of a concurrent operation. This is used to +// drive the concurrency tests below. +type concurrentFn struct { + name string + match func(p9.FileMode) bool + op func(h *Harness, backend *Mock, f p9.File, callback func()) +} + +func concurrentTest(t *testing.T, name string, fn1, fn2 concurrentFn, sameDir, expectedOkay bool) { + var ( + names1 []string + names2 []string + ) + if sameDir { + // Use the same file one directory up. + names1, names2 = []string{"one", name}, []string{"one", name} + } else { + // For different directories, just use siblings. + names1, names2 = []string{"one", name}, []string{"three", name} + } + + t.Run(fmt.Sprintf("%s(%v)+%s(%v)", fn1.name, names1, fn2.name, names2), func(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + // Walk to both files as given. + _, f1, err := root.Walk(names1) + if err != nil { + t.Fatalf("error walking, got %v, want nil", err) + } + defer f1.Close() + b1 := h.Pop(f1) + _, f2, err := root.Walk(names2) + if err != nil { + t.Fatalf("error walking, got %v, want nil", err) + } + defer f2.Close() + b2 := h.Pop(f2) + + // Are these a good match for the current test case? + if !fn1.match(b1.Attr.Mode) { + t.SkipNow() + } + if !fn2.match(b2.Attr.Mode) { + t.SkipNow() + } + + // Construct our "concurrency creator". + in1 := make(chan struct{}, 1) + in2 := make(chan struct{}, 1) + var top sync.WaitGroup + var fns sync.WaitGroup + defer top.Wait() + top.Add(2) // Accounting for below. + defer fns.Done() + fns.Add(1) // See line above; released before top.Wait. + go func() { + defer top.Done() + fn1.op(h, b1, f1, func() { + in1 <- struct{}{} + fns.Wait() + }) + }() + go func() { + defer top.Done() + fn2.op(h, b2, f2, func() { + in2 <- struct{}{} + fns.Wait() + }) + }() + + // Compute a reasonable timeout. If we expect the operation to hang, + // give it 10 milliseconds before we assert that it's fine. After all, + // there will be a lot of these tests. If we don't expect it to hang, + // give it a full minute, since the machine could be slow. + timeout := 10 * time.Millisecond + if expectedOkay { + timeout = 1 * time.Minute + } + + // Read the first channel. + var second chan struct{} + select { + case <-in1: + second = in2 + case <-in2: + second = in1 + } + + // Catch concurrency. + select { + case <-second: + // We finished successful. Is this good? Depends on the + // expected result. + if !expectedOkay { + t.Errorf("%q and %q proceeded concurrently!", fn1.name, fn2.name) + } + case <-time.After(timeout): + // Great, things did not proceed concurrently. Is that what we + // expected? + if expectedOkay { + t.Errorf("%q and %q hung concurrently!", fn1.name, fn2.name) + } + } + }) +} + +func randomFileName() string { + return fmt.Sprintf("%x", rand.Int63()) +} + +func TestConcurrency(t *testing.T) { + readExclusive := []concurrentFn{ + { + // N.B. We can't explicitly check WalkGetAttr behavior, + // but we rely on the fact that the internal code paths + // are the same. + name: "walk", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + // See the documentation of WalkCallback. + // Because walk is actually implemented by the + // mock, we need a special place for this + // callback. + // + // Note that a clone actually locks the parent + // node. So we walk from this node to test + // concurrent operations appropriately. + backend.WalkCallback = func() error { + callback() + return nil + } + f.Walk([]string{randomFileName()}) // Won't exist. + }, + }, + { + name: "fsync", + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Open(gomock.Any()) + backend.EXPECT().FSync().Do(func() { + callback() + }) + f.Open(p9.ReadOnly) // Required. + f.FSync() + }, + }, + { + name: "readdir", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Open(gomock.Any()) + backend.EXPECT().Readdir(gomock.Any(), gomock.Any()).Do(func(uint64, uint32) { + callback() + }) + f.Open(p9.ReadOnly) // Required. + f.Readdir(0, 1) + }, + }, + { + name: "readlink", + match: func(mode p9.FileMode) bool { return mode.IsSymlink() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Readlink().Do(func() { + callback() + }) + f.Readlink() + }, + }, + { + name: "connect", + match: func(mode p9.FileMode) bool { return mode.IsSocket() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Connect(gomock.Any()).Do(func(p9.ConnectFlags) { + callback() + }) + f.Connect(0) + }, + }, + { + name: "open", + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Open(gomock.Any()).Do(func(p9.OpenFlags) { + callback() + }) + f.Open(p9.ReadOnly) + }, + }, + { + name: "flush", + match: func(mode p9.FileMode) bool { return true }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Flush().Do(func() { + callback() + }) + f.Flush() + }, + }, + } + writeExclusive := []concurrentFn{ + { + // N.B. We can't really check getattr. But this is an + // extremely low-risk function, it seems likely that + // this check is paranoid anyways. + name: "setattr", + match: func(mode p9.FileMode) bool { return true }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().SetAttr(gomock.Any(), gomock.Any()).Do(func(p9.SetAttrMask, p9.SetAttr) { + callback() + }) + f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}) + }, + }, + { + name: "unlinkAt", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { + callback() + }) + f.UnlinkAt(randomFileName(), 0) + }, + }, + { + name: "mknod", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Mknod(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, uint32, uint32, p9.UID, p9.GID) { + callback() + }) + f.Mknod(randomFileName(), 0, 0, 0, 0, 0) + }, + }, + { + name: "link", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Link(gomock.Any(), gomock.Any()).Do(func(p9.File, string) { + callback() + }) + f.Link(f, randomFileName()) + }, + }, + { + name: "symlink", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Symlink(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, string, p9.UID, p9.GID) { + callback() + }) + f.Symlink(randomFileName(), randomFileName(), 0, 0) + }, + }, + { + name: "mkdir", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().Mkdir(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, p9.UID, p9.GID) { + callback() + }) + f.Mkdir(randomFileName(), 0, 0, 0) + }, + }, + { + name: "create", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + // Return an error for the creation operation, as this is the simplest. + backend.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, p9.QID{}, uint32(0), syscall.EINVAL).Do(func(string, p9.OpenFlags, p9.FileMode, p9.UID, p9.GID) { + callback() + }) + f.Create(randomFileName(), p9.ReadOnly, 0, 0, 0) + }, + }, + } + globalExclusive := []concurrentFn{ + { + name: "remove", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + // Remove operates on a locked parent. So we + // add a child, walk to it and call remove. + // Note that because this operation can operate + // concurrently with itself, we need to + // generate a random file name. + randomFile := randomFileName() + backend.AddChild(randomFile, h.NewFile()) + defer backend.RemoveChild(randomFile) + _, file, err := f.Walk([]string{randomFile}) + if err != nil { + h.t.Fatalf("walk got %v, want nil", err) + } + + // Remove is automatically translated to the parent. + backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { + callback() + }) + + // Remove is also a close. + file.(deprecatedRemover).Remove() + }, + }, + { + name: "rename", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + // Similarly to remove, because we need to + // operate on a child, we allow a walk. + randomFile := randomFileName() + backend.AddChild(randomFile, h.NewFile()) + defer backend.RemoveChild(randomFile) + _, file, err := f.Walk([]string{randomFile}) + if err != nil { + h.t.Fatalf("walk got %v, want nil", err) + } + defer file.Close() + fileBackend := h.Pop(file) + + // Rename is automatically translated to the parent. + backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { + callback() + }) + + // Attempt the rename. + fileBackend.EXPECT().Renamed(gomock.Any(), gomock.Any()) + file.Rename(f, randomFileName()) + }, + }, + { + name: "renameAt", + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + op: func(h *Harness, backend *Mock, f p9.File, callback func()) { + backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { + callback() + }) + + // Attempt the rename. There are no active fids + // with this name, so we don't need to expect + // Renamed hooks on anything. + f.RenameAt(randomFileName(), f, randomFileName()) + }, + }, + } + + for _, fn1 := range readExclusive { + for _, fn2 := range readExclusive { + for name := range newTypeMap(nil) { + // Everything should be able to proceed in parallel. + concurrentTest(t, name, fn1, fn2, true, true) + concurrentTest(t, name, fn1, fn2, false, true) + } + } + } + + for _, fn1 := range append(readExclusive, writeExclusive...) { + for _, fn2 := range writeExclusive { + for name := range newTypeMap(nil) { + // Only cross-directory functions should proceed in parallel. + concurrentTest(t, name, fn1, fn2, true, false) + concurrentTest(t, name, fn1, fn2, false, true) + } + } + } + + for _, fn1 := range append(append(readExclusive, writeExclusive...), globalExclusive...) { + for _, fn2 := range globalExclusive { + for name := range newTypeMap(nil) { + // Nothing should be able to run in parallel. + concurrentTest(t, name, fn1, fn2, true, false) + concurrentTest(t, name, fn1, fn2, false, false) + } + } + } +} + +func TestReadWriteConcurrent(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const ( + instances = 10 + iterations = 10000 + dataSize = 1024 + ) + var ( + dataSets [instances][dataSize]byte + backends [instances]*Mock + files [instances]p9.File + ) + + // Walk to the file normally. + for i := 0; i < instances; i++ { + _, backends[i], files[i] = walkHelper(h, "file", root) + defer files[i].Close() + } + + // Open the files. + for i := 0; i < instances; i++ { + backends[i].EXPECT().Open(p9.ReadWrite) + if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + } + + // Initialize random data for each instance. + for i := 0; i < instances; i++ { + if _, err := rand.Read(dataSets[i][:]); err != nil { + t.Fatalf("error initializing dataSet#%d, got %v", i, err) + } + } + + // Define our random read/write mechanism. + randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { + // Prepare the backend. + backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if n := copy(p, data); n != len(data) { + // Note that we have to assert the result here, as the Return statement + // below cannot be dynamic: it will be bound before this call is made. + h.t.Errorf("wanted length %d, got %d", len(data), n) + } + }).Return(len(data), nil) + + // Execute the read. + if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) + return // No sense doing check below. + } + if !bytes.Equal(test, data) { + t.Errorf("data integrity failed during read") // Not as expected. + } + } + randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { + // Prepare the backend. + backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if !bytes.Equal(p, data) { + h.t.Errorf("data integrity failed during write") // Not as expected. + } + }).Return(len(data), nil) + + // Execute the write. + if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) + } + } + randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { + test := make([]byte, len(data)) + for i := 0; i < n; i++ { + if rand.Intn(2) == 0 { + randRead(h, backend, f, data, test) + } else { + randWrite(h, backend, f, data) + } + } + } + + // Start reading and writing. + var wg sync.WaitGroup + for i := 0; i < instances; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) + }(i) + } + wg.Wait() +} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go new file mode 100644 index 000000000..dd8b01b6d --- /dev/null +++ b/pkg/p9/p9test/p9test.go @@ -0,0 +1,329 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package p9test provides standard mocks for p9. +package p9test + +import ( + "fmt" + "sync/atomic" + "syscall" + "testing" + + "github.com/golang/mock/gomock" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Harness is an attacher mock. +type Harness struct { + t *testing.T + mockCtrl *gomock.Controller + Attacher *MockAttacher + wg sync.WaitGroup + clientSocket *unet.Socket + mu sync.Mutex + created []*Mock +} + +// globalPath is a QID.Path Generator. +var globalPath uint64 + +// MakePath returns a globally unique path. +func MakePath() uint64 { + return atomic.AddUint64(&globalPath, 1) +} + +// Generator is a function that generates a new file. +type Generator func(parent *Mock) *Mock + +// Mock is a common mock element. +type Mock struct { + p9.DefaultWalkGetAttr + *MockFile + parent *Mock + closed bool + harness *Harness + QID p9.QID + Attr p9.Attr + children map[string]Generator + + // WalkCallback is a special function that will be called from within + // the walk context. This is needed for the concurrent tests within + // this package. + WalkCallback func() error +} + +// globalMu protects the children maps in all mocks. Note that this is not a +// particularly elegant solution, but because the test has walks from the root +// through to final nodes, we must share maps below, and it's easiest to simply +// protect against concurrent access globally. +var globalMu sync.RWMutex + +// AddChild adds a new child to the Mock. +func (m *Mock) AddChild(name string, generator Generator) { + globalMu.Lock() + defer globalMu.Unlock() + m.children[name] = generator +} + +// RemoveChild removes the child with the given name. +func (m *Mock) RemoveChild(name string) { + globalMu.Lock() + defer globalMu.Unlock() + delete(m.children, name) +} + +// Matches implements gomock.Matcher.Matches. +func (m *Mock) Matches(x interface{}) bool { + if om, ok := x.(*Mock); ok { + return m.QID.Path == om.QID.Path + } + return false +} + +// String implements gomock.Matcher.String. +func (m *Mock) String() string { + return fmt.Sprintf("Mock{Mode: 0x%x, QID.Path: %d}", m.Attr.Mode, m.QID.Path) +} + +// GetAttr returns the current attributes. +func (m *Mock) GetAttr(mask p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { + return m.QID, p9.AttrMaskAll(), m.Attr, nil +} + +// Walk supports clone and walking in directories. +func (m *Mock) Walk(names []string) ([]p9.QID, p9.File, error) { + if m.WalkCallback != nil { + if err := m.WalkCallback(); err != nil { + return nil, nil, err + } + } + if len(names) == 0 { + // Clone the file appropriately. + nm := m.harness.NewMock(m.parent, m.QID.Path, m.Attr) + nm.children = m.children // Inherit children. + return []p9.QID{nm.QID}, nm, nil + } else if len(names) != 1 { + m.harness.t.Fail() // Should not happen. + return nil, nil, syscall.EINVAL + } + + if m.Attr.Mode.IsDir() { + globalMu.RLock() + defer globalMu.RUnlock() + if fn, ok := m.children[names[0]]; ok { + // Generate the child. + nm := fn(m) + return []p9.QID{nm.QID}, nm, nil + } + // No child found. + return nil, nil, syscall.ENOENT + } + + // Call the underlying mock. + return m.MockFile.Walk(names) +} + +// WalkGetAttr calls the default implementation; this is a client-side optimization. +func (m *Mock) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { + return m.DefaultWalkGetAttr.WalkGetAttr(names) +} + +// Pop pops off the most recently created Mock and assert that this mock +// represents the same file passed in. If nil is passed in, no check is +// performed. +// +// Precondition: there must be at least one Mock or this will panic. +func (h *Harness) Pop(clientFile p9.File) *Mock { + h.mu.Lock() + defer h.mu.Unlock() + + if clientFile == nil { + // If no clientFile is provided, then we always return the last + // created file. The caller can safely use this as long as + // there is no concurrency. + m := h.created[len(h.created)-1] + h.created = h.created[:len(h.created)-1] + return m + } + + qid, _, _, err := clientFile.GetAttr(p9.AttrMaskAll()) + if err != nil { + // We do not expect this to happen. + panic(fmt.Sprintf("err during Pop: %v", err)) + } + + // Find the relevant file in our created list. We must scan the last + // from back to front to ensure that we favor the most recently + // generated file. + for i := len(h.created) - 1; i >= 0; i-- { + m := h.created[i] + if qid.Path == m.QID.Path { + // Copy and truncate. + copy(h.created[i:], h.created[i+1:]) + h.created = h.created[:len(h.created)-1] + return m + } + } + + // Unable to find relevant file. + panic(fmt.Sprintf("unable to locate file with QID %+v", qid.Path)) +} + +// NewMock returns a new base file. +func (h *Harness) NewMock(parent *Mock, path uint64, attr p9.Attr) *Mock { + m := &Mock{ + MockFile: NewMockFile(h.mockCtrl), + parent: parent, + harness: h, + QID: p9.QID{ + Type: p9.QIDType((attr.Mode & p9.FileModeMask) >> 12), + Path: path, + }, + Attr: attr, + } + + // Always ensure Close is after the parent's close. Note that this + // can't be done via a straight-forward After call, because the parent + // might change after initial creation. We ensure that this is true at + // close time. + m.EXPECT().Close().Return(nil).Times(1).Do(func() { + if m.parent != nil && m.parent.closed { + h.t.FailNow() + } + // Note that this should not be racy, as this operation should + // be protected by the Times(1) above first. + m.closed = true + }) + + // Remember what was created. + h.mu.Lock() + defer h.mu.Unlock() + h.created = append(h.created, m) + + return m +} + +// NewFile returns a new file mock. +// +// Note that ReadAt and WriteAt must be mocked separately. +func (h *Harness) NewFile() Generator { + return func(parent *Mock) *Mock { + return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeRegular}) + } +} + +// NewDirectory returns a new mock directory. +// +// Note that Mkdir, Link, Mknod, RenameAt, UnlinkAt and Readdir must be mocked +// separately. Walk is provided and children may be manipulated via AddChild +// and RemoveChild. After calling Walk remotely, one can use Pop to find the +// corresponding backend mock on the server side. +func (h *Harness) NewDirectory(contents map[string]Generator) Generator { + return func(parent *Mock) *Mock { + m := h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeDirectory}) + m.children = contents // Save contents. + return m + } +} + +// NewSymlink returns a new mock directory. +// +// Note that Readlink must be mocked separately. +func (h *Harness) NewSymlink() Generator { + return func(parent *Mock) *Mock { + return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSymlink}) + } +} + +// NewBlockDevice returns a new mock block device. +func (h *Harness) NewBlockDevice() Generator { + return func(parent *Mock) *Mock { + return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeBlockDevice}) + } +} + +// NewCharacterDevice returns a new mock character device. +func (h *Harness) NewCharacterDevice() Generator { + return func(parent *Mock) *Mock { + return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeCharacterDevice}) + } +} + +// NewNamedPipe returns a new mock named pipe. +func (h *Harness) NewNamedPipe() Generator { + return func(parent *Mock) *Mock { + return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeNamedPipe}) + } +} + +// NewSocket returns a new mock socket. +func (h *Harness) NewSocket() Generator { + return func(parent *Mock) *Mock { + return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSocket}) + } +} + +// Finish completes all checks and shuts down the server. +func (h *Harness) Finish() { + h.clientSocket.Shutdown() + h.wg.Wait() + h.mockCtrl.Finish() +} + +// NewHarness creates and returns a new test server. +// +// It should always be used as: +// +// h, c := NewHarness(t) +// defer h.Finish() +// +func NewHarness(t *testing.T) (*Harness, *p9.Client) { + // Create the mock. + mockCtrl := gomock.NewController(t) + h := &Harness{ + t: t, + mockCtrl: mockCtrl, + Attacher: NewMockAttacher(mockCtrl), + } + + // Make socket pair. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v wanted nil", err) + } + + // Start the server, synchronized on exit. + server := p9.NewServer(h.Attacher) + h.wg.Add(1) + go func() { + defer h.wg.Done() + server.Handle(serverSocket) + }() + + // Create the client. + client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) + if err != nil { + serverSocket.Close() + clientSocket.Close() + t.Fatalf("new client got %v, expected nil", err) + return nil, nil // Never hit. + } + + // Capture the client socket. + h.clientSocket = clientSocket + return h, client +} diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go new file mode 100644 index 000000000..72ef53313 --- /dev/null +++ b/pkg/p9/path_tree.go @@ -0,0 +1,222 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" +) + +// pathNode is a single node in a path traversal. +// +// These are shared by all fidRefs that point to the same path. +// +// Lock ordering: +// opMu +// childMu +// +// Two different pathNodes may only be locked if Server.renameMu is held for +// write, in which case they can be acquired in any order. +type pathNode struct { + // opMu synchronizes high-level, sematic operations, such as the + // simultaneous creation and deletion of a file. + // + // opMu does not directly protect any fields in pathNode. + opMu sync.RWMutex + + // childMu protects the fields below. + childMu sync.RWMutex + + // childNodes maps child path component names to their pathNode. + childNodes map[string]*pathNode + + // childRefs maps child path component names to all of the their + // references. + childRefs map[string]map[*fidRef]struct{} + + // childRefNames maps child references back to their path component + // name. + childRefNames map[*fidRef]string +} + +func newPathNode() *pathNode { + return &pathNode{ + childNodes: make(map[string]*pathNode), + childRefs: make(map[string]map[*fidRef]struct{}), + childRefNames: make(map[*fidRef]string), + } +} + +// forEachChildRef calls fn for each child reference. +func (p *pathNode) forEachChildRef(fn func(ref *fidRef, name string)) { + p.childMu.RLock() + defer p.childMu.RUnlock() + + for name, m := range p.childRefs { + for ref := range m { + fn(ref, name) + } + } +} + +// forEachChildNode calls fn for each child pathNode. +func (p *pathNode) forEachChildNode(fn func(pn *pathNode)) { + p.childMu.RLock() + defer p.childMu.RUnlock() + + for _, pn := range p.childNodes { + fn(pn) + } +} + +// pathNodeFor returns the path node for the given name, or a new one. +func (p *pathNode) pathNodeFor(name string) *pathNode { + p.childMu.RLock() + // Fast path, node already exists. + if pn, ok := p.childNodes[name]; ok { + p.childMu.RUnlock() + return pn + } + p.childMu.RUnlock() + + // Slow path, create a new pathNode for shared use. + p.childMu.Lock() + + // Re-check after re-lock. + if pn, ok := p.childNodes[name]; ok { + p.childMu.Unlock() + return pn + } + + pn := newPathNode() + p.childNodes[name] = pn + p.childMu.Unlock() + return pn +} + +// nameFor returns the name for the given fidRef. +// +// Precondition: addChild is called for ref before nameFor. +func (p *pathNode) nameFor(ref *fidRef) string { + p.childMu.RLock() + n, ok := p.childRefNames[ref] + p.childMu.RUnlock() + + if !ok { + // This should not happen, don't proceed. + panic(fmt.Sprintf("expected name for %+v, none found", ref)) + } + + return n +} + +// addChildLocked adds a child reference to p. +// +// Precondition: As addChild, plus childMu is locked for write. +func (p *pathNode) addChildLocked(ref *fidRef, name string) { + if n, ok := p.childRefNames[ref]; ok { + // This should not happen, don't proceed. + panic(fmt.Sprintf("unexpected fidRef %+v with path %q, wanted %q", ref, n, name)) + } + + p.childRefNames[ref] = name + + m, ok := p.childRefs[name] + if !ok { + m = make(map[*fidRef]struct{}) + p.childRefs[name] = m + } + + m[ref] = struct{}{} +} + +// addChild adds a child reference to p. +// +// Precondition: ref may only be added once at a time. +func (p *pathNode) addChild(ref *fidRef, name string) { + p.childMu.Lock() + p.addChildLocked(ref, name) + p.childMu.Unlock() +} + +// removeChild removes the given child. +// +// This applies only to an individual fidRef, which is not required to exist. +func (p *pathNode) removeChild(ref *fidRef) { + p.childMu.Lock() + + // This ref may not exist anymore. This can occur, e.g., in unlink, + // where a removeWithName removes the ref, and then a DecRef on the ref + // attempts to remove again. + if name, ok := p.childRefNames[ref]; ok { + m, ok := p.childRefs[name] + if !ok { + // This should not happen, don't proceed. + p.childMu.Unlock() + panic(fmt.Sprintf("name %s missing from childfidRefs", name)) + } + + delete(m, ref) + if len(m) == 0 { + delete(p.childRefs, name) + } + } + + delete(p.childRefNames, ref) + + p.childMu.Unlock() +} + +// addPathNodeFor adds an existing pathNode as the node for name. +// +// Preconditions: newName does not exist. +func (p *pathNode) addPathNodeFor(name string, pn *pathNode) { + p.childMu.Lock() + + if opn, ok := p.childNodes[name]; ok { + p.childMu.Unlock() + panic(fmt.Sprintf("unexpected pathNode %+v with path %q", opn, name)) + } + + p.childNodes[name] = pn + p.childMu.Unlock() +} + +// removeWithName removes all references with the given name. +// +// The provided function is executed after reference removal. The only method +// it may (transitively) call on this pathNode is addChildLocked. +// +// If a child pathNode for name exists, it is removed from this pathNode and +// returned by this function. Any operations on the removed tree must use this +// value. +func (p *pathNode) removeWithName(name string, fn func(ref *fidRef)) *pathNode { + p.childMu.Lock() + defer p.childMu.Unlock() + + if m, ok := p.childRefs[name]; ok { + for ref := range m { + delete(m, ref) + delete(p.childRefNames, ref) + fn(ref) + } + } + + // Return the original path node, if it exists. + origPathNode := p.childNodes[name] + delete(p.childNodes, name) + return origPathNode +} diff --git a/pkg/p9/server.go b/pkg/p9/server.go new file mode 100644 index 000000000..fdfa83648 --- /dev/null +++ b/pkg/p9/server.go @@ -0,0 +1,694 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "io" + "runtime/debug" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Server is a 9p2000.L server. +type Server struct { + // attacher provides the attach function. + attacher Attacher + + // pathTree is the full set of paths opened on this server. + // + // These may be across different connections, but rename operations + // must be serialized globally for safely. There is a single pathTree + // for the entire server, and not per connection. + pathTree *pathNode + + // renameMu is a global lock protecting rename operations. With this + // lock, we can be certain that any given rename operation can safely + // acquire two path nodes in any order, as all other concurrent + // operations acquire at most a single node. + renameMu sync.RWMutex +} + +// NewServer returns a new server. +func NewServer(attacher Attacher) *Server { + return &Server{ + attacher: attacher, + pathTree: newPathNode(), + } +} + +// connState is the state for a single connection. +type connState struct { + // server is the backing server. + server *Server + + // sendMu is the send lock. + sendMu sync.Mutex + + // conn is the connection. + conn *unet.Socket + + // fids is the set of active FIDs. + // + // This is used to find FIDs for files. + fidMu sync.Mutex + fids map[FID]*fidRef + + // tags is the set of active tags. + // + // The given channel is closed when the + // tag is finished with processing. + tagMu sync.Mutex + tags map[Tag]chan struct{} + + // messageSize is the maximum message size. The server does not + // do automatic splitting of messages. + messageSize uint32 + + // version is the agreed upon version X of 9P2000.L.Google.X. + // version 0 implies 9P2000.L. + version uint32 + + // -- below relates to the legacy handler -- + + // recvOkay indicates that a receive may start. + recvOkay chan bool + + // recvDone is signalled when a message is received. + recvDone chan error + + // sendDone is signalled when a send is finished. + sendDone chan error + + // -- below relates to the flipcall handler -- + + // channelMu protects below. + channelMu sync.Mutex + + // channelWg represents active workers. + channelWg sync.WaitGroup + + // channelAlloc allocates channel memory. + channelAlloc *flipcall.PacketWindowAllocator + + // channels are the set of initialized channels. + channels []*channel +} + +// fidRef wraps a node and tracks references. +type fidRef struct { + // server is the associated server. + server *Server + + // file is the associated File. + file File + + // refs is an active refence count. + // + // The node above will be closed only when refs reaches zero. + refs int64 + + // openedMu protects opened and openFlags. + openedMu sync.Mutex + + // opened indicates whether this has been opened already. + // + // This is updated in handlers.go. + opened bool + + // mode is the fidRef's mode from the walk. Only the type bits are + // valid, the permissions may change. This is used to sanity check + // operations on this element, and prevent walks across + // non-directories. + mode FileMode + + // openFlags is the mode used in the open. + // + // This is updated in handlers.go. + openFlags OpenFlags + + // pathNode is the current pathNode for this FID. + pathNode *pathNode + + // parent is the parent fidRef. We hold on to a parent reference to + // ensure that hooks, such as Renamed, can be executed safely by the + // server code. + // + // Note that parent cannot be changed without holding both the global + // rename lock and a writable lock on the associated pathNode for this + // fidRef. Holding either of these locks is sufficient to examine + // parent safely. + // + // The parent will be nil for root fidRefs, and non-nil otherwise. The + // method maybeParent can be used to return a cyclical reference, and + // isRoot should be used to check for root over looking at parent + // directly. + parent *fidRef + + // deleted indicates that the backing file has been deleted. We stop + // many operations at the API level if they are incompatible with a + // file that has already been unlinked. + deleted uint32 +} + +// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously. +func (f *fidRef) OpenFlags() (OpenFlags, bool) { + f.openedMu.Lock() + defer f.openedMu.Unlock() + return f.openFlags, f.opened +} + +// IncRef increases the references on a fid. +func (f *fidRef) IncRef() { + atomic.AddInt64(&f.refs, 1) +} + +// DecRef should be called when you're finished with a fid. +func (f *fidRef) DecRef() { + if atomic.AddInt64(&f.refs, -1) == 0 { + f.file.Close() + + // Drop the parent reference. + // + // Since this fidRef is guaranteed to be non-discoverable when + // the references reach zero, we don't need to worry about + // clearing the parent. + if f.parent != nil { + // If we've been previously deleted, this removing this + // ref is a no-op. That's expected. + f.parent.pathNode.removeChild(f) + f.parent.DecRef() + } + } +} + +// isDeleted returns true if this fidRef has been deleted. +func (f *fidRef) isDeleted() bool { + return atomic.LoadUint32(&f.deleted) != 0 +} + +// isRoot indicates whether this is a root fid. +func (f *fidRef) isRoot() bool { + return f.parent == nil +} + +// maybeParent returns a cyclic reference for roots, and the parent otherwise. +func (f *fidRef) maybeParent() *fidRef { + if f.parent != nil { + return f.parent + } + return f // Root has itself. +} + +// notifyDelete marks all fidRefs as deleted. +// +// Precondition: this must be called via safelyWrite or safelyGlobal. +func notifyDelete(pn *pathNode) { + // Call on all local references. + pn.forEachChildRef(func(ref *fidRef, _ string) { + atomic.StoreUint32(&ref.deleted, 1) + }) + + // Call on all subtrees. + pn.forEachChildNode(func(pn *pathNode) { + notifyDelete(pn) + }) +} + +// markChildDeleted marks all children below the given name as deleted. +// +// Precondition: this must be called via safelyWrite or safelyGlobal. +func (f *fidRef) markChildDeleted(name string) { + origPathNode := f.pathNode.removeWithName(name, func(ref *fidRef) { + atomic.StoreUint32(&ref.deleted, 1) + }) + + if origPathNode != nil { + // Mark all children as deleted. + notifyDelete(origPathNode) + } +} + +// notifyNameChange calls the relevant Renamed method on all nodes in the path, +// recursively. Note that this applies only for subtrees, as these +// notifications do not apply to the actual file whose name has changed. +// +// Precondition: this must be called via safelyGlobal. +func notifyNameChange(pn *pathNode) { + // Call on all local references. + pn.forEachChildRef(func(ref *fidRef, name string) { + ref.file.Renamed(ref.parent.file, name) + }) + + // Call on all subtrees. + pn.forEachChildNode(func(pn *pathNode) { + notifyNameChange(pn) + }) +} + +// renameChildTo renames the given child to the target. +// +// Precondition: this must be called via safelyGlobal. +func (f *fidRef) renameChildTo(oldName string, target *fidRef, newName string) { + target.markChildDeleted(newName) + origPathNode := f.pathNode.removeWithName(oldName, func(ref *fidRef) { + // N.B. DecRef can take f.pathNode's parent's childMu. This is + // allowed because renameMu is held for write via safelyGlobal. + ref.parent.DecRef() // Drop original reference. + ref.parent = target // Change parent. + ref.parent.IncRef() // Acquire new one. + if f.pathNode == target.pathNode { + target.pathNode.addChildLocked(ref, newName) + } else { + target.pathNode.addChild(ref, newName) + } + ref.file.Renamed(target.file, newName) + }) + + if origPathNode != nil { + // Replace the previous (now deleted) path node. + target.pathNode.addPathNodeFor(newName, origPathNode) + // Call Renamed on all children. + notifyNameChange(origPathNode) + } +} + +// safelyRead executes the given operation with the local path node locked. +// This implies that paths will not change during the operation. +func (f *fidRef) safelyRead(fn func() error) (err error) { + f.server.renameMu.RLock() + defer f.server.renameMu.RUnlock() + f.pathNode.opMu.RLock() + defer f.pathNode.opMu.RUnlock() + return fn() +} + +// safelyWrite executes the given operation with the local path node locked in +// a writable fashion. This implies some paths may change. +func (f *fidRef) safelyWrite(fn func() error) (err error) { + f.server.renameMu.RLock() + defer f.server.renameMu.RUnlock() + f.pathNode.opMu.Lock() + defer f.pathNode.opMu.Unlock() + return fn() +} + +// safelyGlobal executes the given operation with the global path lock held. +func (f *fidRef) safelyGlobal(fn func() error) (err error) { + f.server.renameMu.Lock() + defer f.server.renameMu.Unlock() + return fn() +} + +// LookupFID finds the given FID. +// +// You should call fid.DecRef when you are finished using the fid. +func (cs *connState) LookupFID(fid FID) (*fidRef, bool) { + cs.fidMu.Lock() + defer cs.fidMu.Unlock() + fidRef, ok := cs.fids[fid] + if ok { + fidRef.IncRef() + return fidRef, true + } + return nil, false +} + +// InsertFID installs the given FID. +// +// This fid starts with a reference count of one. If a FID exists in +// the slot already it is closed, per the specification. +func (cs *connState) InsertFID(fid FID, newRef *fidRef) { + cs.fidMu.Lock() + defer cs.fidMu.Unlock() + origRef, ok := cs.fids[fid] + if ok { + defer origRef.DecRef() + } + newRef.IncRef() + cs.fids[fid] = newRef +} + +// DeleteFID removes the given FID. +// +// This simply removes it from the map and drops a reference. +func (cs *connState) DeleteFID(fid FID) bool { + cs.fidMu.Lock() + defer cs.fidMu.Unlock() + fidRef, ok := cs.fids[fid] + if !ok { + return false + } + delete(cs.fids, fid) + fidRef.DecRef() + return true +} + +// StartTag starts handling the tag. +// +// False is returned if this tag is already active. +func (cs *connState) StartTag(t Tag) bool { + cs.tagMu.Lock() + defer cs.tagMu.Unlock() + _, ok := cs.tags[t] + if ok { + return false + } + cs.tags[t] = make(chan struct{}) + return true +} + +// ClearTag finishes handling a tag. +func (cs *connState) ClearTag(t Tag) { + cs.tagMu.Lock() + defer cs.tagMu.Unlock() + ch, ok := cs.tags[t] + if !ok { + // Should never happen. + panic("unused tag cleared") + } + delete(cs.tags, t) + + // Notify. + close(ch) +} + +// WaitTag waits for a tag to finish. +func (cs *connState) WaitTag(t Tag) { + cs.tagMu.Lock() + ch, ok := cs.tags[t] + cs.tagMu.Unlock() + if !ok { + return + } + + // Wait for close. + <-ch +} + +// initializeChannels initializes all channels. +// +// This is a no-op if channels are already initialized. +func (cs *connState) initializeChannels() (err error) { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + + // Initialize our channel allocator. + if cs.channelAlloc == nil { + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return err + } + cs.channelAlloc = alloc + } + + // Create all the channels. + for len(cs.channels) < channelsPerClient { + res := &channel{ + done: make(chan struct{}), + } + + res.desc, err = cs.channelAlloc.Allocate(channelSize) + if err != nil { + return err + } + if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil { + return err + } + + socks, err := fdchannel.NewConnectedSockets() + if err != nil { + res.data.Destroy() // Cleanup. + return err + } + res.fds.Init(socks[0]) + res.client = fd.New(socks[1]) + + cs.channels = append(cs.channels, res) + + // Start servicing the channel. + // + // When we call stop, we will close all the channels and these + // routines should finish. We need the wait group to ensure + // that active handlers are actually finished before cleanup. + cs.channelWg.Add(1) + go func() { // S/R-SAFE: Server side. + defer cs.channelWg.Done() + if err := res.service(cs); err != nil { + // Don't log flipcall.ShutdownErrors, which we expect to be + // returned during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("p9.channel.service: %v", err) + } + } + }() + } + + return nil +} + +// lookupChannel looks up the channel with given id. +// +// The function returns nil if no such channel is available. +func (cs *connState) lookupChannel(id uint32) *channel { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + if id >= uint32(len(cs.channels)) { + return nil + } + return cs.channels[id] +} + +// handle handles a single message. +func (cs *connState) handle(m message) (r message) { + defer func() { + if r == nil { + // Don't allow a panic to propagate. + recover() + + // Include a useful log message. + log.Warningf("panic in handler: %s", debug.Stack()) + + // Wrap in an EFAULT error; we don't really have a + // better way to describe this kind of error. It will + // usually manifest as a result of the test framework. + r = newErr(syscall.EFAULT) + } + }() + if handler, ok := m.(handler); ok { + // Call the message handler. + r = handler.handle(cs) + } else { + // Produce an ENOSYS error. + r = newErr(syscall.ENOSYS) + } + return +} + +// handleRequest handles a single request. +// +// The recvDone channel is signaled when recv is done (with a error if +// necessary). The sendDone channel is signaled with the result of the send. +func (cs *connState) handleRequest() { + messageSize := atomic.LoadUint32(&cs.messageSize) + if messageSize == 0 { + // Default or not yet negotiated. + messageSize = maximumLength + } + + // Receive a message. + tag, m, err := recv(cs.conn, messageSize, msgRegistry.get) + if errSocket, ok := err.(ErrSocket); ok { + // Connection problem; stop serving. + cs.recvDone <- errSocket.error + return + } + + // Signal receive is done. + cs.recvDone <- nil + + // Deal with other errors. + if err != nil && err != io.EOF { + // If it's not a connection error, but some other protocol error, + // we can send a response immediately. + cs.sendMu.Lock() + err := send(cs.conn, tag, newErr(err)) + cs.sendMu.Unlock() + cs.sendDone <- err + return + } + + // Try to start the tag. + if !cs.StartTag(tag) { + // Nothing we can do at this point; client is bogus. + log.Debugf("no valid tag [%05d]", tag) + cs.sendDone <- ErrNoValidMessage + return + } + + // Handle the message. + r := cs.handle(m) + + // Clear the tag before sending. That's because as soon as this hits + // the wire, the client can legally send the same tag. + cs.ClearTag(tag) + + // Send back the result. + cs.sendMu.Lock() + err = send(cs.conn, tag, r) + cs.sendMu.Unlock() + cs.sendDone <- err + + // Return the message to the cache. + msgRegistry.put(m) +} + +func (cs *connState) handleRequests() { + for range cs.recvOkay { + cs.handleRequest() + } +} + +func (cs *connState) stop() { + // Close all channels. + close(cs.recvOkay) + close(cs.recvDone) + close(cs.sendDone) + + // Free the channels. + cs.channelMu.Lock() + for _, ch := range cs.channels { + ch.Shutdown() + } + cs.channelWg.Wait() + for _, ch := range cs.channels { + ch.Close() + } + cs.channels = nil // Clear. + cs.channelMu.Unlock() + + // Free the channel memory. + if cs.channelAlloc != nil { + cs.channelAlloc.Destroy() + } + + // Close all remaining fids. + for fid, fidRef := range cs.fids { + delete(cs.fids, fid) + + // Drop final reference in the FID table. Note this should + // always close the file, since we've ensured that there are no + // handlers running via the wait for Pending => 0 below. + fidRef.DecRef() + } + + // Ensure the connection is closed. + cs.conn.Close() +} + +// service services requests concurrently. +func (cs *connState) service() error { + // Pending is the number of handlers that have finished receiving but + // not finished processing requests. These must be waiting on properly + // below. See the next comment for an explanation of the loop. + pending := 0 + + // Start the first request handler. + go cs.handleRequests() // S/R-SAFE: Irrelevant. + cs.recvOkay <- true + + // We loop and make sure there's always one goroutine waiting for a new + // request. We process all the data for a single request in one + // goroutine however, to ensure the best turnaround time possible. + for { + select { + case err := <-cs.recvDone: + if err != nil { + // Wait for pending handlers. + for i := 0; i < pending; i++ { + <-cs.sendDone + } + return nil + } + + // This handler is now pending. + pending++ + + // Kick the next receiver, or start a new handler + // if no receiver is currently waiting. + select { + case cs.recvOkay <- true: + default: + go cs.handleRequests() // S/R-SAFE: Irrelevant. + cs.recvOkay <- true + } + + case <-cs.sendDone: + // This handler is finished. + pending-- + + // Error sending a response? Nothing can be done. + // + // We don't terminate on a send error though, since + // we still have a pending receive. The error would + // have been logged above, we just ignore it here. + } + } +} + +// Handle handles a single connection. +func (s *Server) Handle(conn *unet.Socket) error { + cs := &connState{ + server: s, + conn: conn, + fids: make(map[FID]*fidRef), + tags: make(map[Tag]chan struct{}), + recvOkay: make(chan bool), + recvDone: make(chan error, 10), + sendDone: make(chan error, 10), + } + defer cs.stop() + return cs.service() +} + +// Serve handles requests from the bound socket. +// +// The passed serverSocket _must_ be created in packet mode. +func (s *Server) Serve(serverSocket *unet.ServerSocket) error { + var wg sync.WaitGroup + defer wg.Wait() + + for { + conn, err := serverSocket.Accept() + if err != nil { + // Something went wrong. + // + // Socket closed? + return err + } + + wg.Add(1) + go func(conn *unet.Socket) { // S/R-SAFE: Irrelevant. + s.Handle(conn) + wg.Done() + }(conn) + } +} diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go new file mode 100644 index 000000000..7cec0e86d --- /dev/null +++ b/pkg/p9/transport.go @@ -0,0 +1,345 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// ErrSocket is returned in cases of a socket issue. +// +// This may be treated differently than other errors. +type ErrSocket struct { + // error is the socket error. + error +} + +// ErrMessageTooLarge indicates the size was larger than reasonable. +type ErrMessageTooLarge struct { + size uint32 + msize uint32 +} + +// Error returns a sensible error. +func (e *ErrMessageTooLarge) Error() string { + return fmt.Sprintf("message too large for fixed buffer: size is %d, limit is %d", e.size, e.msize) +} + +// ErrNoValidMessage indicates no valid message could be decoded. +var ErrNoValidMessage = errors.New("buffer contained no valid message") + +const ( + // headerLength is the number of bytes required for a header. + headerLength uint32 = 7 + + // maximumLength is the largest possible message. + maximumLength uint32 = 1 << 20 + + // DefaultMessageSize is a sensible default. + DefaultMessageSize uint32 = 64 << 10 + + // initialBufferLength is the initial data buffer we allocate. + initialBufferLength uint32 = 64 +) + +var dataPool = sync.Pool{ + New: func() interface{} { + // These buffers are used for decoding without a payload. + return make([]byte, initialBufferLength) + }, +} + +// send sends the given message over the socket. +func send(s *unet.Socket, tag Tag, m message) error { + data := dataPool.Get().([]byte) + dataBuf := buffer{data: data[:0]} + + if log.IsLogging(log.Debug) { + log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) + } + + // Encode the message. The buffer will grow automatically. + m.encode(&dataBuf) + + // Get our vectors to send. + var hdr [headerLength]byte + vecs := make([][]byte, 0, 3) + vecs = append(vecs, hdr[:]) + if len(dataBuf.data) > 0 { + vecs = append(vecs, dataBuf.data) + } + totalLength := headerLength + uint32(len(dataBuf.data)) + + // Is there a payload? + if payloader, ok := m.(payloader); ok { + p := payloader.Payload() + if len(p) > 0 { + vecs = append(vecs, p) + totalLength += uint32(len(p)) + } + } + + // Construct the header. + headerBuf := buffer{data: hdr[:0]} + headerBuf.Write32(totalLength) + headerBuf.WriteMsgType(m.Type()) + headerBuf.WriteTag(tag) + + // Pack any files if necessary. + w := s.Writer(true) + if filer, ok := m.(filer); ok { + if f := filer.FilePayload(); f != nil { + defer f.Close() + // Pack the file into the message. + w.PackFDs(f.FD()) + } + } + + for n := 0; n < int(totalLength); { + cur, err := w.WriteVec(vecs) + if err != nil { + return ErrSocket{err} + } + n += cur + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(vecs[0]) <= cur-consumed { + consumed += len(vecs[0]) + vecs = vecs[1:] + } else { + vecs[0] = vecs[0][cur-consumed:] + break + } + } + + if n > 0 && n < int(totalLength) { + // Don't resend any control message. + w.UnpackFDs() + } + } + + // All set. + dataPool.Put(dataBuf.data) + return nil +} + +// lookupTagAndType looks up an existing message or creates a new one. +// +// This is called by recv after decoding the header. Any error returned will be +// propagating back to the caller. You may use messageByType directly as a +// lookupTagAndType function (by design). +type lookupTagAndType func(tag Tag, t MsgType) (message, error) + +// recv decodes a message from the socket. +// +// This is done in two parts, and is thus not safe for multiple callers. +// +// On a socket error, the special error type ErrSocket is returned. +// +// The tag value NoTag will always be returned if err is non-nil. +func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) { + // Read a header. + // + // Since the send above is atomic, we must always receive control + // messages along with the header. This means we need to be careful + // about closing FDs during errors to prevent leaks. + var hdr [headerLength]byte + r := s.Reader(true) + r.EnableFDs(1) + + n, err := r.ReadVec([][]byte{hdr[:]}) + if err != nil && (n == 0 || err != io.EOF) { + r.CloseFDs() + return NoTag, nil, ErrSocket{err} + } + + fds, err := r.ExtractFDs() + if err != nil { + return NoTag, nil, ErrSocket{err} + } + defer func() { + // Close anything left open. The case where + // fds are caught and used is handled below, + // and the fds variable will be set to nil. + for _, fd := range fds { + syscall.Close(fd) + } + }() + r.EnableFDs(0) + + // Continuing reading for a short header. + for n < int(headerLength) { + cur, err := r.ReadVec([][]byte{hdr[n:]}) + if err != nil && (cur == 0 || err != io.EOF) { + return NoTag, nil, ErrSocket{err} + } + n += cur + } + + // Decode the header. + headerBuf := buffer{data: hdr[:]} + size := headerBuf.Read32() + t := headerBuf.ReadMsgType() + tag := headerBuf.ReadTag() + if size < headerLength { + // The message is too small. + // + // See above: it's probably screwed. + return NoTag, nil, ErrSocket{ErrNoValidMessage} + } + if size > maximumLength || size > msize { + // The message is too big. + return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, msize}} + } + remaining := size - headerLength + + // Find our message to decode. + m, err := lookup(tag, t) + if err != nil { + // Throw away the contents of this message. + if remaining > 0 { + io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)}) + } + return tag, nil, err + } + + // Not yet initialized. + var dataBuf buffer + + // Read the rest of the payload. + // + // This requires some special care to ensure that the vectors all line + // up the way they should. We do this to minimize copying data around. + var vecs [][]byte + if payloader, ok := m.(payloader); ok { + fixedSize := payloader.FixedSize() + + // Do we need more than there is? + if fixedSize > remaining { + // This is not a valid message. + if remaining > 0 { + io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)}) + } + return NoTag, nil, ErrNoValidMessage + } + + if fixedSize != 0 { + // Pull a data buffer from the pool. + data := dataPool.Get().([]byte) + if int(fixedSize) > len(data) { + // Create a larger data buffer, ensuring + // sufficient capicity for the message. + data = make([]byte, fixedSize) + defer dataPool.Put(data) + dataBuf = buffer{data: data} + vecs = append(vecs, data) + } else { + // Limit the data buffer, and make sure it + // gets filled before the payload buffer. + defer dataPool.Put(data) + dataBuf = buffer{data: data[:fixedSize]} + vecs = append(vecs, data[:fixedSize]) + } + } + + // Include the payload. + p := payloader.Payload() + if p == nil || len(p) != int(remaining-fixedSize) { + p = make([]byte, remaining-fixedSize) + payloader.SetPayload(p) + } + if len(p) > 0 { + vecs = append(vecs, p) + } + } else if remaining != 0 { + // Pull a data buffer from the pool. + data := dataPool.Get().([]byte) + if int(remaining) > len(data) { + // Create a larger data buffer. + data = make([]byte, remaining) + defer dataPool.Put(data) + dataBuf = buffer{data: data} + vecs = append(vecs, data) + } else { + // Limit the data buffer. + defer dataPool.Put(data) + dataBuf = buffer{data: data[:remaining]} + vecs = append(vecs, data[:remaining]) + } + } + + if len(vecs) > 0 { + // Read the rest of the message. + // + // No need to handle a control message. + r := s.Reader(true) + for n := 0; n < int(remaining); { + cur, err := r.ReadVec(vecs) + if err != nil && (cur == 0 || err != io.EOF) { + return NoTag, nil, ErrSocket{err} + } + n += cur + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(vecs[0]) <= cur-consumed { + consumed += len(vecs[0]) + vecs = vecs[1:] + } else { + vecs[0] = vecs[0][cur-consumed:] + break + } + } + } + } + + // Decode the message data. + m.decode(&dataBuf) + if dataBuf.isOverrun() { + // No need to drain the socket. + return NoTag, nil, ErrNoValidMessage + } + + // Save the file, if any came out. + if filer, ok := m.(filer); ok && len(fds) > 0 { + // Set the file object. + filer.SetFilePayload(fd.New(fds[0])) + + // Close the rest. We support only one. + for i := 1; i < len(fds); i++ { + syscall.Close(fds[i]) + } + + // Don't close in the defer. + fds = nil + } + + if log.IsLogging(log.Debug) { + log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) + } + + // All set. + return tag, m, nil +} diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go new file mode 100644 index 000000000..38038abdf --- /dev/null +++ b/pkg/p9/transport_flipcall.go @@ -0,0 +1,243 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "runtime" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +// channelsPerClient is the number of channels to create per client. +// +// While the client and server will generally agree on this number, in reality +// it's completely up to the server. We simply define a minimum of 2, and a +// maximum of 4, and select the number of available processes as a tie-breaker. +// Note that we don't want the number of channels to be too large, because each +// will account for channelSize memory used, which can be large. +var channelsPerClient = func() int { + n := runtime.NumCPU() + if n < 2 { + return 2 + } + if n > 4 { + return 4 + } + return n +}() + +// channelSize is the channel size to create. +// +// We simply ensure that this is larger than the largest possible message size, +// plus the flipcall packet header, plus the two bytes we write below. +const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength) + +// channel is a fast IPC channel. +// +// The same object is used by both the server and client implementations. In +// general, the client will use only the send and recv methods. +type channel struct { + desc flipcall.PacketWindowDescriptor + data flipcall.Endpoint + fds fdchannel.Endpoint + buf buffer + + // -- client only -- + connected bool + active bool + + // -- server only -- + client *fd.FD + done chan struct{} +} + +// reset resets the channel buffer. +func (ch *channel) reset(sz uint32) { + ch.buf.data = ch.data.Data()[:sz] +} + +// service services the channel. +func (ch *channel) service(cs *connState) error { + rsz, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rsz > 0 { + m, err := ch.recv(nil, rsz) + if err != nil { + return err + } + r := cs.handle(m) + msgRegistry.put(m) + rsz, err = ch.send(r) + if err != nil { + return err + } + } + return nil // Done. +} + +// Shutdown shuts down the channel. +// +// This must be called before Close. +func (ch *channel) Shutdown() { + ch.data.Shutdown() +} + +// Close closes the channel. +// +// This must only be called once, and cannot return an error. Note that +// synchronization for this method is provided at a high-level, depending on +// whether it is the client or server. This cannot be called while there are +// active callers in either service or sendRecv. +// +// Precondition: the channel should be shutdown. +func (ch *channel) Close() error { + // Close all backing transports. + ch.fds.Destroy() + ch.data.Destroy() + if ch.client != nil { + ch.client.Close() + } + return nil +} + +// send sends the given message. +// +// The return value is the size of the received response. Not that in the +// server case, this is the size of the next request. +func (ch *channel) send(m message) (uint32, error) { + if log.IsLogging(log.Debug) { + log.Debugf("send [channel @%p] %s", ch, m.String()) + } + + // Send any file payload. + sentFD := false + if filer, ok := m.(filer); ok { + if f := filer.FilePayload(); f != nil { + if err := ch.fds.SendFD(f.FD()); err != nil { + return 0, err + } + f.Close() // Per sendRecvLegacy. + sentFD = true // To mark below. + } + } + + // Encode the message. + // + // Note that IPC itself encodes the length of messages, so we don't + // need to encode a standard 9P header. We write only the message type. + ch.reset(0) + + ch.buf.WriteMsgType(m.Type()) + if sentFD { + ch.buf.Write8(1) // Incoming FD. + } else { + ch.buf.Write8(0) // No incoming FD. + } + m.encode(&ch.buf) + ssz := uint32(len(ch.buf.data)) // Updated below. + + // Is there a payload? + if payloader, ok := m.(payloader); ok { + p := payloader.Payload() + copy(ch.data.Data()[ssz:], p) + ssz += uint32(len(p)) + } + + // Perform the one-shot communication. + return ch.data.SendRecv(ssz) +} + +// recv decodes a message that exists on the channel. +// +// If the passed r is non-nil, then the type must match or an error will be +// generated. If the passed r is nil, then a new message will be created and +// returned. +func (ch *channel) recv(r message, rsz uint32) (message, error) { + // Decode the response from the inline buffer. + ch.reset(rsz) + t := ch.buf.ReadMsgType() + hasFD := ch.buf.Read8() != 0 + if t == MsgRlerror { + // Change the message type. We check for this special case + // after decoding below, and transform into an error. + r = &Rlerror{} + } else if r == nil { + nr, err := msgRegistry.get(0, t) + if err != nil { + return nil, err + } + r = nr // New message. + } else if t != r.Type() { + // Not an error and not the expected response; propagate. + return nil, &ErrBadResponse{Got: t, Want: r.Type()} + } + + // Is there a payload? Copy from the latter portion. + if payloader, ok := r.(payloader); ok { + fs := payloader.FixedSize() + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } + ch.buf.data = ch.buf.data[:fs] + } + + r.decode(&ch.buf) + if ch.buf.isOverrun() { + // Nothing valid was available. + log.Debugf("recv [got %d bytes, needed more]", rsz) + return nil, ErrNoValidMessage + } + + // Read any FD result. + if hasFD { + if rfd, err := ch.fds.RecvFDNonblock(); err == nil { + f := fd.New(rfd) + if filer, ok := r.(filer); ok { + // Set the payload. + filer.SetFilePayload(f) + } else { + // Don't want the FD. + f.Close() + } + } else { + // The header bit was set but nothing came in. + log.Warningf("expected FD, got err: %v", err) + } + } + + // Log a message. + if log.IsLogging(log.Debug) { + log.Debugf("recv [channel @%p] %s", ch, r.String()) + } + + // Convert errors appropriately; see above. + if rlerr, ok := r.(*Rlerror); ok { + return r, syscall.Errno(rlerr.Error) + } + + return r, nil +} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go new file mode 100644 index 000000000..3668fcad7 --- /dev/null +++ b/pkg/p9/transport_test.go @@ -0,0 +1,231 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "io/ioutil" + "os" + "testing" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + MsgTypeBadEncode = iota + 252 + MsgTypeBadDecode + MsgTypeUnregistered +) + +func TestSendRecv(t *testing.T) { + server, client, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer server.Close() + defer client.Close() + + if err := send(client, Tag(1), &Tlopen{}); err != nil { + t.Fatalf("send got err %v expected nil", err) + } + + tag, m, err := recv(server, maximumLength, msgRegistry.get) + if err != nil { + t.Fatalf("recv got err %v expected nil", err) + } + if tag != Tag(1) { + t.Fatalf("got tag %v expected 1", tag) + } + if _, ok := m.(*Tlopen); !ok { + t.Fatalf("got message %v expected *Tlopen", m) + } +} + +// badDecode overruns on decode. +type badDecode struct{} + +func (*badDecode) decode(b *buffer) { b.markOverrun() } +func (*badDecode) encode(b *buffer) {} +func (*badDecode) Type() MsgType { return MsgTypeBadDecode } +func (*badDecode) String() string { return "badDecode{}" } + +func TestRecvOverrun(t *testing.T) { + server, client, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer server.Close() + defer client.Close() + + if err := send(client, Tag(1), &badDecode{}); err != nil { + t.Fatalf("send got err %v expected nil", err) + } + + if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil { + t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err) + } +} + +// unregistered is not registered on decode. +type unregistered struct{} + +func (*unregistered) decode(b *buffer) {} +func (*unregistered) encode(b *buffer) {} +func (*unregistered) Type() MsgType { return MsgTypeUnregistered } +func (*unregistered) String() string { return "unregistered{}" } + +func TestRecvInvalidType(t *testing.T) { + server, client, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer server.Close() + defer client.Close() + + if err := send(client, Tag(1), &unregistered{}); err != nil { + t.Fatalf("send got err %v expected nil", err) + } + + _, _, err = recv(server, maximumLength, msgRegistry.get) + if _, ok := err.(*ErrInvalidMsgType); !ok { + t.Fatalf("recv got err %v expected ErrInvalidMsgType", err) + } +} + +func TestSendRecvWithFile(t *testing.T) { + server, client, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer server.Close() + defer client.Close() + + // Create a tempfile. + osf, err := ioutil.TempFile("", "p9") + if err != nil { + t.Fatalf("tempfile got err %v expected nil", err) + } + os.Remove(osf.Name()) + f, err := fd.NewFromFile(osf) + osf.Close() + if err != nil { + t.Fatalf("unable to create file: %v", err) + } + + rlopen := &Rlopen{} + rlopen.SetFilePayload(f) + if err := send(client, Tag(1), rlopen); err != nil { + t.Fatalf("send got err %v expected nil", err) + } + + // Enable withFile. + tag, m, err := recv(server, maximumLength, msgRegistry.get) + if err != nil { + t.Fatalf("recv got err %v expected nil", err) + } + if tag != Tag(1) { + t.Fatalf("got tag %v expected 1", tag) + } + rlopen, ok := m.(*Rlopen) + if !ok { + t.Fatalf("got m %v expected *Rlopen", m) + } + if rlopen.File == nil { + t.Fatalf("got nil file expected non-nil") + } +} + +func TestRecvClosed(t *testing.T) { + server, client, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer server.Close() + client.Close() + + _, _, err = recv(server, maximumLength, msgRegistry.get) + if err == nil { + t.Fatalf("got err nil expected non-nil") + } + if _, ok := err.(ErrSocket); !ok { + t.Fatalf("got err %v expected ErrSocket", err) + } +} + +func TestSendClosed(t *testing.T) { + server, client, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + server.Close() + defer client.Close() + + err = send(client, Tag(1), &Tlopen{}) + if err == nil { + t.Fatalf("send got err nil expected non-nil") + } + if _, ok := err.(ErrSocket); !ok { + t.Fatalf("got err %v expected ErrSocket", err) + } +} + +func BenchmarkSendRecv(b *testing.B) { + server, client, err := unet.SocketPair(false) + if err != nil { + b.Fatalf("socketpair got err %v expected nil", err) + } + defer server.Close() + defer client.Close() + + // Exchange Rflush messages since these contain no data and therefore incur + // no additional marshaling overhead. + go func() { + for i := 0; i < b.N; i++ { + tag, m, err := recv(server, maximumLength, msgRegistry.get) + if err != nil { + b.Fatalf("recv got err %v expected nil", err) + } + if tag != Tag(1) { + b.Fatalf("got tag %v expected 1", tag) + } + if _, ok := m.(*Rflush); !ok { + b.Fatalf("got message %T expected *Rflush", m) + } + if err := send(server, Tag(2), &Rflush{}); err != nil { + b.Fatalf("send got err %v expected nil", err) + } + } + }() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := send(client, Tag(1), &Rflush{}); err != nil { + b.Fatalf("send got err %v expected nil", err) + } + tag, m, err := recv(client, maximumLength, msgRegistry.get) + if err != nil { + b.Fatalf("recv got err %v expected nil", err) + } + if tag != Tag(2) { + b.Fatalf("got tag %v expected 2", tag) + } + if _, ok := m.(*Rflush); !ok { + b.Fatalf("got message %v expected *Rflush", m) + } + } +} + +func init() { + msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} }) +} diff --git a/pkg/p9/version.go b/pkg/p9/version.go new file mode 100644 index 000000000..09cde9f5a --- /dev/null +++ b/pkg/p9/version.go @@ -0,0 +1,175 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "fmt" + "strconv" + "strings" +) + +const ( + // highestSupportedVersion is the highest supported version X in a + // version string of the format 9P2000.L.Google.X. + // + // Clients are expected to start requesting this version number and + // to continuously decrement it until a Tversion request succeeds. + highestSupportedVersion uint32 = 11 + + // lowestSupportedVersion is the lowest supported version X in a + // version string of the format 9P2000.L.Google.X. + // + // Clients are free to send a Tversion request at a version below this + // value but are expected to encounter an Rlerror in response. + lowestSupportedVersion uint32 = 0 + + // baseVersion is the base version of 9P that this package must always + // support. It is equivalent to 9P2000.L.Google.0. + baseVersion = "9P2000.L" +) + +// HighestVersionString returns the highest possible version string that a client +// may request or a server may support. +func HighestVersionString() string { + return versionString(highestSupportedVersion) +} + +// parseVersion parses a Tversion version string into a numeric version number +// if the version string is supported by p9. Otherwise returns (0, false). +// +// From Tversion(9P): "Version strings are defined such that, if the client string +// contains one or more period characters, the initial substring up to but not +// including any single period in the version string defines a version of the protocol." +// +// p9 intentionally diverges from this and always requires that the version string +// start with 9P2000.L to express that it is always compatible with 9P2000.L. The +// only supported versions extensions are of the format 9p2000.L.Google.X where X +// is an ever increasing version counter. +// +// Version 9P2000.L.Google.0 implies 9P2000.L. +// +// New versions must always be a strict superset of 9P2000.L. A version increase must +// define a predicate representing the feature extension introduced by that version. The +// predicate must be commented and should take the format: +// +// // VersionSupportsX returns true if version v supports X and must be checked when ... +// func VersionSupportsX(v int32) bool { +// ... +// ) +func parseVersion(str string) (uint32, bool) { + // Special case the base version which lacks the ".Google.X" suffix. This + // version always means version 0. + if str == baseVersion { + return 0, true + } + substr := strings.Split(str, ".") + if len(substr) != 4 { + return 0, false + } + if substr[0] != "9P2000" || substr[1] != "L" || substr[2] != "Google" || len(substr[3]) == 0 { + return 0, false + } + version, err := strconv.ParseUint(substr[3], 10, 32) + if err != nil { + return 0, false + } + return uint32(version), true +} + +// versionString formats a p9 version number into a Tversion version string. +func versionString(version uint32) string { + // Special case the base version so that clients expecting this string + // instead of the 9P2000.L.Google.0 equivalent get it. This is important + // for backwards compatibility with legacy servers that check for exactly + // the baseVersion and allow nothing else. + if version == 0 { + return baseVersion + } + return fmt.Sprintf("9P2000.L.Google.%d", version) +} + +// VersionSupportsTflushf returns true if version v supports the Tflushf message. +// This predicate must be checked by clients before attempting to make a Tflushf +// request. If this predicate returns false, then clients may safely no-op. +func VersionSupportsTflushf(v uint32) bool { + return v >= 1 +} + +// versionSupportsTwalkgetattr returns true if version v supports the +// Twalkgetattr message. This predicate must be checked by clients before +// attempting to make a Twalkgetattr request. +func versionSupportsTwalkgetattr(v uint32) bool { + return v >= 2 +} + +// versionSupportsTucreation returns true if version v supports the Tucreation +// messages (Tucreate, Tusymlink, Tumkdir, Tumknod). This predicate must be +// checked by clients before attempting to make a Tucreation request. +// If Tucreation messages are not supported, their non-UID supporting +// counterparts (Tlcreate, Tsymlink, Tmkdir, Tmknod) should be used. +func versionSupportsTucreation(v uint32) bool { + return v >= 3 +} + +// VersionSupportsConnect returns true if version v supports the Tlconnect +// message. This predicate must be checked by clients +// before attempting to make a Tlconnect request. If Tlconnect messages are not +// supported, Tlopen should be used. +func VersionSupportsConnect(v uint32) bool { + return v >= 4 +} + +// VersionSupportsAnonymous returns true if version v supports Tlconnect +// with the AnonymousSocket mode. This predicate must be checked by clients +// before attempting to use the AnonymousSocket Tlconnect mode. +func VersionSupportsAnonymous(v uint32) bool { + return v >= 5 +} + +// VersionSupportsMultiUser returns true if version v supports multi-user fake +// directory permissions and ID values. +func VersionSupportsMultiUser(v uint32) bool { + return v >= 6 +} + +// versionSupportsTallocate returns true if version v supports Allocate(). +func versionSupportsTallocate(v uint32) bool { + return v >= 7 +} + +// versionSupportsFlipcall returns true if version v supports IPC channels from +// the flipcall package. Note that these must be negotiated, but this version +// string indicates that such a facility exists. +func versionSupportsFlipcall(v uint32) bool { + return v >= 8 +} + +// VersionSupportsOpenTruncateFlag returns true if version v supports +// passing the OpenTruncate flag to Tlopen. +func VersionSupportsOpenTruncateFlag(v uint32) bool { + return v >= 9 +} + +// versionSupportsGetSetXattr returns true if version v supports +// the Tgetxattr and Tsetxattr messages. +func versionSupportsGetSetXattr(v uint32) bool { + return v >= 10 +} + +// versionSupportsListRemoveXattr returns true if version v supports +// the Tlistxattr and Tremovexattr messages. +func versionSupportsListRemoveXattr(v uint32) bool { + return v >= 11 +} diff --git a/pkg/p9/version_test.go b/pkg/p9/version_test.go new file mode 100644 index 000000000..291e8580e --- /dev/null +++ b/pkg/p9/version_test.go @@ -0,0 +1,145 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "testing" +) + +func TestVersionNumberEquivalent(t *testing.T) { + for i := uint32(0); i < 1024; i++ { + str := versionString(i) + version, ok := parseVersion(str) + if !ok { + t.Errorf("#%d: parseVersion(%q) failed, want success", i, str) + continue + } + if i != version { + t.Errorf("#%d: got version %d, want %d", i, i, version) + } + } +} + +func TestVersionStringEquivalent(t *testing.T) { + // There is one case where the version is not equivalent on purpose, + // that is 9P2000.L.Google.0. It is not equivalent because versionString + // must always return the more generic 9P2000.L for legacy servers that + // check for it. See net/9p/client.c. + str := "9P2000.L.Google.0" + version, ok := parseVersion(str) + if !ok { + t.Errorf("parseVersion(%q) failed, want success", str) + } + if got := versionString(version); got != "9P2000.L" { + t.Errorf("versionString(%d) got %q, want %q", version, got, "9P2000.L") + } + + for _, test := range []struct { + versionString string + }{ + { + versionString: "9P2000.L", + }, + { + versionString: "9P2000.L.Google.1", + }, + { + versionString: "9P2000.L.Google.347823894", + }, + } { + version, ok := parseVersion(test.versionString) + if !ok { + t.Errorf("parseVersion(%q) failed, want success", test.versionString) + continue + } + if got := versionString(version); got != test.versionString { + t.Errorf("versionString(%d) got %q, want %q", version, got, test.versionString) + } + } +} + +func TestParseVersion(t *testing.T) { + for _, test := range []struct { + versionString string + expectSuccess bool + expectedVersion uint32 + }{ + { + versionString: "9P", + expectSuccess: false, + }, + { + versionString: "9P.L", + expectSuccess: false, + }, + { + versionString: "9P200.L", + expectSuccess: false, + }, + { + versionString: "9P2000", + expectSuccess: false, + }, + { + versionString: "9P2000.L.Google.-1", + expectSuccess: false, + }, + { + versionString: "9P2000.L.Google.", + expectSuccess: false, + }, + { + versionString: "9P2000.L.Google.3546343826724305832", + expectSuccess: false, + }, + { + versionString: "9P2001.L", + expectSuccess: false, + }, + { + versionString: "9P2000.L", + expectSuccess: true, + expectedVersion: 0, + }, + { + versionString: "9P2000.L.Google.0", + expectSuccess: true, + expectedVersion: 0, + }, + { + versionString: "9P2000.L.Google.1", + expectSuccess: true, + expectedVersion: 1, + }, + } { + version, ok := parseVersion(test.versionString) + if ok != test.expectSuccess { + t.Errorf("parseVersion(%q) got (_, %v), want (_, %v)", test.versionString, ok, test.expectSuccess) + continue + } + if !test.expectSuccess { + continue + } + if version != test.expectedVersion { + t.Errorf("parseVersion(%q) got (%d, _), want (%d, _)", test.versionString, version, test.expectedVersion) + } + } +} + +func BenchmarkParseVersion(b *testing.B) { + for n := 0; n < b.N; n++ { + parseVersion("9P2000.L.Google.1") + } +} |