summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9
diff options
context:
space:
mode:
authorAndrei Vagin <avagin@google.com>2019-10-02 13:00:07 -0700
committerGitHub <noreply@github.com>2019-10-02 13:00:07 -0700
commit9a875306dbabcf335a2abccc08119a1b67d0e51a (patch)
tree0f72c12e951a5eee7156df7a5d63351bc89befa6 /pkg/p9
parent38bc0b6b6addd25ceec4f66ef1af41c1e61e2985 (diff)
parent03ce4dd86c9acd6b6148f68d5d2cf025d8c254bb (diff)
Merge branch 'master' into pr_syscall_linux
Diffstat (limited to 'pkg/p9')
-rw-r--r--pkg/p9/BUILD6
-rw-r--r--pkg/p9/client.go273
-rw-r--r--pkg/p9/client_test.go50
-rw-r--r--pkg/p9/handlers.go53
-rw-r--r--pkg/p9/messages.go103
-rw-r--r--pkg/p9/p9.go2
-rw-r--r--pkg/p9/p9test/BUILD6
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go4
-rw-r--r--pkg/p9/server.go178
-rw-r--r--pkg/p9/transport.go5
-rw-r--r--pkg/p9/transport_flipcall.go263
-rw-r--r--pkg/p9/transport_test.go4
-rw-r--r--pkg/p9/version.go9
14 files changed, 957 insertions, 94 deletions
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index c6737bf97..f32244c69 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
@@ -19,11 +20,14 @@ go_library(
"pool.go",
"server.go",
"transport.go",
+ "transport_flipcall.go",
"version.go",
],
importpath = "gvisor.dev/gvisor/pkg/p9",
deps = [
"//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 7dc20aeef..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -20,6 +20,8 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -77,6 +79,47 @@ type Client struct {
// fidPool is the collection of available fids.
fidPool pool
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
// pending is the set of pending messages.
pending map[Tag]*response
pendingMu sync.Mutex
@@ -89,25 +132,12 @@ type Client struct {
// Whoever writes to this channel is permitted to call recv. When
// finished calling recv, this channel should be emptied.
recvr chan bool
-
- // messageSize is the maximum total size of a message.
- messageSize uint32
-
- // payloadSize is the maximum payload size of a read or write
- // request. For large reads and writes this means that the
- // read or write is broken up into buffer-size/payloadSize
- // requests.
- payloadSize uint32
-
- // version is the agreed upon version X of 9P2000.L.Google.X.
- // version 0 implies 9P2000.L.
- version uint32
}
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
return nil, ErrBadVersionString
}
for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion)
+ err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
// The server told us to try again with a lower version.
if err == syscall.EAGAIN {
@@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.version = version
break
}
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacy
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacy
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
// handleOne handles a single incoming message.
//
// This should only be called with the token from recvr. Note that the received
@@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
-// sendRecv performs a roundtrip message exchange.
+// sendRecvLegacy performs a roundtrip message exchange.
//
// This is called by internal functions.
-func (c *Client) sendRecv(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) error {
tag, ok := c.tagPool.Get()
if !ok {
return ErrOutOfTags
@@ -296,12 +479,62 @@ func (c *Client) sendRecv(t message, r message) error {
return nil
}
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacy(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ return err
+ }
+ }
+
+ // Send the message.
+ err := ch.sendRecv(c, t, r)
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return err
+}
+
// Version returns the negotiated 9P2000.L.Google version number.
func (c *Client) Version() uint32 {
return c.version
}
-// Close closes the underlying socket.
-func (c *Client) Close() error {
- return c.socket.Close()
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 87b2dd61e..29a0afadf 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -35,23 +35,23 @@ func TestVersion(t *testing.T) {
go s.Handle(serverSocket)
// NewClient does a Tversion exchange, so this is our test for success.
- c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString())
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
if err != nil {
t.Fatalf("got %v, expected nil", err)
}
// Check a bogus version string.
- if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a bogus version number.
- if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a too high version number.
- if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN {
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
t.Errorf("got %v expected %v", err, syscall.EAGAIN)
}
@@ -60,3 +60,45 @@ func TestVersion(t *testing.T) {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 999b4f684..ba9a55d6d 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message {
ref.opened = true
ref.openFlags = t.Flags
- return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
@@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
// Replace the FID reference.
cs.InsertFID(t.FID, newRef)
- return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
}
// handle implements handler.handle.
@@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message {
return newErr(err)
}
- return &Rlconnect{File: osFile}
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the given channel.
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending is a
+ // destructive operation between sendRecvLegacy (and now the newer
+ // channel send operations). Same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index fd9eb1c5d..ffdd7e8c6 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -64,6 +64,21 @@ type filer interface {
SetFilePayload(*fd.FD)
}
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
// Tversion is a version request.
type Tversion struct {
// MSize is the message size to use.
@@ -524,10 +539,7 @@ type Rlopen struct {
// IoUnit is the recommended I/O unit.
IoUnit uint32
- // File may be attached via the socket.
- //
- // This is an extension specific to this package.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType {
return MsgRlopen
}
-// FilePayload returns the file payload.
-func (r *Rlopen) FilePayload() *fd.FD {
- return r.File
-}
-
-// SetFilePayload sets the received file.
-func (r *Rlopen) SetFilePayload(file *fd.FD) {
- r.File = file
-}
-
// String implements fmt.Stringer.
func (r *Rlopen) String() string {
return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
@@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string {
// Rlconnect is a connect response.
type Rlconnect struct {
- // File is a host socket.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType {
return MsgRlconnect
}
-// FilePayload returns the file payload.
-func (r *Rlconnect) FilePayload() *fd.FD {
- return r.File
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// SetFilePayload sets the received file.
-func (r *Rlconnect) SetFilePayload(file *fd.FD) {
- r.File = file
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tchannel) Decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tchannel) Encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
}
// String implements fmt.Stringer.
-func (r *Rlconnect) String() string {
- return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+type Rchannel struct {
+ Offset uint64
+ Length uint64
+ filePayload
+}
+
+// Decode implements encoder.Decode.
+func (r *Rchannel) Decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rchannel) Encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
const maxCacheSize = 3
@@ -2356,4 +2409,6 @@ func init() {
msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index e12831dbd..25530adca 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -378,6 +378,8 @@ const (
MsgRlconnect = 137
MsgTallocate = 138
MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
)
// QIDType represents the file type for QIDs.
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 6e939a49a..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
package(licenses = ["notice"])
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 95846e5f7..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
@@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) {
}()
// Create the client.
- client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString())
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
if err != nil {
serverSocket.Close()
clientSocket.Close()
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b294efbb0..69c886a5d 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -21,6 +21,9 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -45,7 +48,6 @@ type Server struct {
}
// NewServer returns a new server.
-//
func NewServer(attacher Attacher) *Server {
return &Server{
attacher: attacher,
@@ -85,6 +87,8 @@ type connState struct {
// version 0 implies 9P2000.L.
version uint32
+ // -- below relates to the legacy handler --
+
// recvOkay indicates that a receive may start.
recvOkay chan bool
@@ -93,6 +97,20 @@ type connState struct {
// sendDone is signalled when a send is finished.
sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
}
// fidRef wraps a node and tracks references.
@@ -386,6 +404,99 @@ func (cs *connState) WaitTag(t Tag) {
<-ch
}
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ res.service(cs)
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel looks up the channel with given id.
+//
+// The function returns nil if no such channel is available.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+ if id >= uint32(len(cs.channels)) {
+ return nil
+ }
+ return cs.channels[id]
+}
+
+// handle handles a single message.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
// handleRequest handles a single request.
//
// The recvDone channel is signaled when recv is done (with a error if
@@ -428,41 +539,20 @@ func (cs *connState) handleRequest() {
}
// Handle the message.
- var r message // r is the response.
- defer func() {
- if r == nil {
- // Don't allow a panic to propagate.
- recover()
+ r := cs.handle(m)
- // Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
- // Wrap in an EFAULT error; we don't really have a
- // better way to describe this kind of error. It will
- // usually manifest as a result of the test framework.
- r = newErr(syscall.EFAULT)
- }
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
- // Clear the tag before sending. That's because as soon as this
- // hits the wire, the client can legally send another message
- // with the same tag.
- cs.ClearTag(tag)
-
- // Send back the result.
- cs.sendMu.Lock()
- err = send(cs.conn, tag, r)
- cs.sendMu.Unlock()
- cs.sendDone <- err
- }()
- if handler, ok := m.(handler); ok {
- // Call the message handler.
- r = handler.handle(cs)
- } else {
- // Produce an ENOSYS error.
- r = newErr(syscall.ENOSYS)
- }
+ // Return the message to the cache.
msgRegistry.put(m)
- m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
@@ -477,7 +567,27 @@ func (cs *connState) stop() {
close(cs.recvDone)
close(cs.sendDone)
- for _, fidRef := range cs.fids {
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
// Drop final reference in the FID table. Note this should
// always close the file, since we've ensured that there are no
// handlers running via the wait for Pending => 0 below.
@@ -510,7 +620,7 @@ func (cs *connState) service() error {
for i := 0; i < pending; i++ {
<-cs.sendDone
}
- return err
+ return nil
}
// This handler is now pending.
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 5648df589..6e8b4bbcd 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -54,7 +54,10 @@ const (
headerLength uint32 = 7
// maximumLength is the largest possible message.
- maximumLength uint32 = 4 * 1024 * 1024
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
// initialBufferLength is the initial data buffer we allocate.
initialBufferLength uint32 = 64
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..7cdf4ecc3
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,263 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2, and a
+// maximum of 4, and select the number of available processes as a tie-breaker.
+// Note that we don't want the number of channels to be too large, because each
+// will account for channelSize memory used, which can be large.
+var channelsPerClient = func() int {
+ n := runtime.NumCPU()
+ if n < 2 {
+ return 2
+ }
+ if n > 4 {
+ return 4
+ }
+ return n
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ desc flipcall.PacketWindowDescriptor
+ data flipcall.Endpoint
+ fds fdchannel.Endpoint
+ buf buffer
+
+ // -- client only --
+ connected bool
+ active bool
+
+ // -- server only --
+ client *fd.FD
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high-level, depending on
+// whether it is the client or server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Not that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, syscall.EIO // Map everything to EIO.
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.Encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ n, err := ch.data.SendRecv(ssz)
+ if err != nil {
+ if n > 0 {
+ return n, nil
+ }
+ return 0, syscall.EIO // See above.
+ }
+
+ return n, nil
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.Decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return nil, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
+
+// sendRecv sends the given message over the channel.
+//
+// This is used by the client.
+func (ch *channel) sendRecv(c *Client, m, r message) error {
+ rsz, err := ch.send(m)
+ if err != nil {
+ return err
+ }
+ _, err = ch.recv(r, rsz)
+ return err
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index cdb3bc841..2f50ff3ea 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) {
t.Fatalf("unable to create file: %v", err)
}
- if err := send(client, Tag(1), &Rlopen{File: f}); err != nil {
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
t.Fatalf("send got err %v expected nil", err)
}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index c2a2885ae..f1ffdd23a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 7
+ highestSupportedVersion uint32 = 8
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool {
func versionSupportsTallocate(v uint32) bool {
return v >= 7
}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}