From a8834fc555539bd6b0b46936c4a79817812658ff Mon Sep 17 00:00:00 2001
From: Adin Scannell
Date: Thu, 12 Sep 2019 23:36:18 -0700
Subject: Update p9 to support flipcall.

PiperOrigin-RevId: 268845090
---
 pkg/p9/transport_flipcall.go | 254 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 254 insertions(+)
 create mode 100644 pkg/p9/transport_flipcall.go

diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..aebb54959
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,254 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+	"runtime"
+	"syscall"
+
+	"gvisor.dev/gvisor/pkg/fd"
+	"gvisor.dev/gvisor/pkg/fdchannel"
+	"gvisor.dev/gvisor/pkg/flipcall"
+	"gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2 and a
+// maximum of 4, and use the number of available processors as a tie-breaker.
+// Note that we don't want the number of channels to be too large, because
+// each one accounts for channelSize bytes of memory, which can be large.
+var channelsPerClient = func() int {
+	n := runtime.NumCPU()
+	if n < 2 {
+		return 2
+	}
+	if n > 4 {
+		return 4
+	}
+	return n
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message
+// size, plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+	desc flipcall.PacketWindowDescriptor
+	data flipcall.Endpoint
+	fds  fdchannel.Endpoint
+	buf  buffer
+
+	// -- client only --
+	connected bool
+
+	// -- server only --
+	client *fd.FD
+	done   chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+	ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+	rsz, err := ch.data.RecvFirst()
+	if err != nil {
+		return err
+	}
+	for rsz > 0 {
+		m, err := ch.recv(nil, rsz)
+		if err != nil {
+			return err
+		}
+		r := cs.handle(m)
+		msgRegistry.put(m)
+		rsz, err = ch.send(r)
+		if err != nil {
+			return err
+		}
+	}
+	return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+	ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high level, depending on
+// whether it is the client or server.
+// This cannot be called while there are active callers in either service or
+// sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+	// Close all backing transports.
+	ch.fds.Destroy()
+	ch.data.Destroy()
+	if ch.client != nil {
+		ch.client.Close()
+	}
+	return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Note that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+	if log.IsLogging(log.Debug) {
+		log.Debugf("send [channel @%p] %s", ch, m.String())
+	}
+
+	// Send any file payload.
+	sentFD := false
+	if filer, ok := m.(filer); ok {
+		if f := filer.FilePayload(); f != nil {
+			if err := ch.fds.SendFD(f.FD()); err != nil {
+				return 0, syscall.EIO // Map everything to EIO.
+			}
+			f.Close()     // Per sendRecvLegacy.
+			sentFD = true // To mark below.
+		}
+	}
+
+	// Encode the message.
+	//
+	// Note that IPC itself encodes the length of messages, so we don't
+	// need to encode a standard 9P header. We write only the message type.
+	ch.reset(0)
+
+	ch.buf.WriteMsgType(m.Type())
+	if sentFD {
+		ch.buf.Write8(1) // Incoming FD.
+	} else {
+		ch.buf.Write8(0) // No incoming FD.
+	}
+	m.Encode(&ch.buf)
+	ssz := uint32(len(ch.buf.data)) // Updated below.
+
+	// Is there a payload?
+	if payloader, ok := m.(payloader); ok {
+		p := payloader.Payload()
+		copy(ch.data.Data()[ssz:], p)
+		ssz += uint32(len(p))
+	}
+
+	// Perform the one-shot communication.
+	n, err := ch.data.SendRecv(ssz)
+	if err != nil {
+		if n > 0 {
+			return n, nil
+		}
+		return 0, syscall.EIO // See above.
+	}
+
+	return n, nil
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+	// Decode the response from the inline buffer.
+	ch.reset(rsz)
+	t := ch.buf.ReadMsgType()
+	hasFD := ch.buf.Read8() != 0
+	if t == MsgRlerror {
+		// Change the message type. We check for this special case
+		// after decoding below, and transform into an error.
+		r = &Rlerror{}
+	} else if r == nil {
+		nr, err := msgRegistry.get(0, t)
+		if err != nil {
+			return nil, err
+		}
+		r = nr // New message.
+	} else if t != r.Type() {
+		// Not an error and not the expected response; propagate.
+		return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+	}
+
+	// Is there a payload? Set to the latter portion.
+	if payloader, ok := r.(payloader); ok {
+		fs := payloader.FixedSize()
+		payloader.SetPayload(ch.buf.data[fs:])
+		ch.buf.data = ch.buf.data[:fs]
+	}
+
+	r.Decode(&ch.buf)
+	if ch.buf.isOverrun() {
+		// Nothing valid was available.
+		log.Debugf("recv [got %d bytes, needed more]", rsz)
+		return nil, ErrNoValidMessage
+	}
+
+	// Read any FD result.
+	if hasFD {
+		if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+			f := fd.New(rfd)
+			if filer, ok := r.(filer); ok {
+				// Set the payload.
+				filer.SetFilePayload(f)
+			} else {
+				// Don't want the FD.
+				f.Close()
+			}
+		} else {
+			// The header bit was set but nothing came in.
+			log.Warningf("expected FD, got err: %v", err)
+		}
+	}
+
+	// Log a message.
+	if log.IsLogging(log.Debug) {
+		log.Debugf("recv [channel @%p] %s", ch, r.String())
+	}
+
+	// Convert errors appropriately; see above.
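+	//
+	// For example, a server that fails an open on a missing file replies
+	// with Rlerror{Error: uint32(syscall.ENOENT)}; the conversion below
+	// hands this back to the caller as the error value syscall.ENOENT.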
+	if rlerr, ok := r.(*Rlerror); ok {
+		return nil, syscall.Errno(rlerr.Error)
+	}
+
+	return r, nil
+}
+
+// sendRecv sends the given message over the channel.
+//
+// This is used by the client.
+func (ch *channel) sendRecv(c *Client, m, r message) error {
+	rsz, err := ch.send(m)
+	if err != nil {
+		return err
+	}
+	_, err = ch.recv(r, rsz)
+	return err
+}
--
cgit v1.2.3


From e9af227a61b836310fdd0c8543c31afe094af5ae Mon Sep 17 00:00:00 2001
From: Jamie Liu
Date: Thu, 19 Sep 2019 22:51:17 -0700
Subject: Fix p9 integration of flipcall.

- Do not call Rread.SetPayload(flipcall packet window) in p9.channel.recv().

- Ignore EINTR from ppoll() in p9.Client.watch().

- Clean up handling of client socket FD lifetimes so that p9.Client.watch()
  never ppoll()s a closed FD.

- Make p9test.Harness.Finish() call clientSocket.Shutdown() instead of
  clientSocket.Close() for the same reason.

- Rework channel reuse to avoid leaking channels in the following case
  (suppose we have two channels):

    sendRecvChannel
      len(channels) == 2 => idx = 1
      inuse[1] = ch0
    sendRecvChannel
      len(channels) == 1 => idx = 0
      inuse[0] = ch1
    inuse[1] = nil
    sendRecvChannel
      len(channels) == 1 => idx = 0
      inuse[0] = ch0
    inuse[0] = nil
    inuse[0] == nil => ch0 leaked

- Avoid deadlocking p9.Client.watch() by calling channelsWg.Wait() without
  holding channelsMu.

- Bump p9test:client_test size to medium.

PiperOrigin-RevId: 270200314
---
 pkg/p9/client.go             | 171 ++++++++++++++++++++-----------------------
 pkg/p9/p9test/BUILD          |   2 +-
 pkg/p9/p9test/p9test.go      |   2 +-
 pkg/p9/transport_flipcall.go |  13 +++-
 4 files changed, 92 insertions(+), 96 deletions(-)

diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 123f54e29..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -92,6 +92,10 @@ type Client struct {
 	// version 0 implies 9P2000.L.
 	version uint32
 
+	// closedWg is marked as done when the Client.watch() goroutine, which is
+	// responsible for closing channels and the socket fd, returns.
+	closedWg sync.WaitGroup
+
 	// sendRecv is the transport function.
 	//
 	// This is determined dynamically based on whether or not the server
@@ -104,17 +108,15 @@ type Client struct {
 	// channelsMu protects channels.
 	channelsMu sync.Mutex
 
-	// channelsWg is a wait group for active clients.
+	// channelsWg counts the number of channels for which channel.active ==
+	// true.
 	channelsWg sync.WaitGroup
 
-	// channels are the set of initialized IPC channels.
+	// channels is the set of all initialized channels.
 	channels []*channel
 
-	// inuse is set when the channels are actually in use.
-	//
-	// This is a fixed-size slice, and the entries will be nil when the
-	// corresponding channel is available.
-	inuse []*channel
+	// availableChannels is a FIFO of inactive channels.
+	availableChannels []*channel
 
 	// -- below corresponds to sendRecvLegacy --
 
@@ -135,7 +137,7 @@ type Client struct {
 // NewClient creates a new client. It performs a Tversion exchange with
 // the server to assert that messageSize is ok to use.
 //
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
 func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
 	// Need at least one byte of payload.
 	if messageSize <= msgRegistry.largestFixedSize {
@@ -214,13 +216,6 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
 	if len(c.channels) >= 1 {
 		// At least one channel created.
c.sendRecv = c.sendRecvChannel - - // If we are using channels for communication, then we must poll - // for shutdown events on the main socket. If the socket happens - // to shutdown, then we will close the channels as well. This is - // necessary because channels can hang forever if the server dies - // while we're expecting a response. - go c.watch(socket) // S/R-SAFE: not relevant. } else { // Channel setup failed; fallback. c.sendRecv = c.sendRecvLegacy @@ -230,13 +225,20 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.sendRecv = c.sendRecvLegacy } + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + return c, nil } -// watch watches the given socket and calls Close on hang up events. +// watch watches the given socket and releases resources on hangup events. // // This is intended to be called as a goroutine. func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + events := []unix.PollFd{ unix.PollFd{ Fd: int32(socket.FD()), @@ -244,19 +246,49 @@ func (c *Client) watch(socket *unet.Socket) { }, } + // Wait for a shutdown event. for { - // Wait for a shutdown event. n, err := unix.Ppoll(events, nil, nil) - if n == 0 || err == syscall.EAGAIN { + if err == syscall.EINTR || err == syscall.EAGAIN { continue } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } break } - // Close everything down: this will kick all active clients off any - // pending requests. Note that Close must be safe to call concurrently, - // and multiple times (see Close below). - c.Close() + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() } // openChannel attempts to open a client channel. @@ -315,7 +347,7 @@ func (c *Client) openChannel(id int) error { c.channelsMu.Lock() defer c.channelsMu.Unlock() c.channels = append(c.channels, res) - c.inuse = append(c.inuse, nil) + c.availableChannels = append(c.availableChannels, res) return nil } @@ -449,23 +481,16 @@ func (c *Client) sendRecvLegacy(t message, r message) error { // sendRecvChannel uses channels to send a message. func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. c.channelsMu.Lock() - if len(c.channels) == 0 { - // No channel available. + if len(c.availableChannels) == 0 { c.channelsMu.Unlock() return c.sendRecvLegacy(t, r) } - - // Find the last used channel. - // - // Note that we must add one to the wait group while holding the - // channel mutex, in order for the Wait operation to be race-free - // below. The Wait operation shuts down all in use channels and - // waits for them to return, but must do so holding the mutex. 
-	idx := len(c.channels) - 1
-	ch := c.channels[idx]
-	c.channels = c.channels[:idx]
-	c.inuse[idx] = ch
+	idx := len(c.availableChannels) - 1
+	ch := c.availableChannels[idx]
+	c.availableChannels = c.availableChannels[:idx]
+	ch.active = true
 	c.channelsWg.Add(1)
 	c.channelsMu.Unlock()
 
@@ -473,8 +498,12 @@ func (c *Client) sendRecvChannel(t message, r message) error {
 	if !ch.connected {
 		ch.connected = true
 		if err := ch.data.Connect(); err != nil {
-			// The channel is unusable, so don't return it.
-			ch.Close()
+			// The channel is unusable, so don't return it to
+			// c.availableChannels. However, we still have to mark it as
+			// inactive so c.watch() doesn't wait for it.
+			c.channelsMu.Lock()
+			ch.active = false
+			c.channelsMu.Unlock()
 			c.channelsWg.Done()
 			return err
 		}
@@ -482,24 +511,17 @@ func (c *Client) sendRecvChannel(t message, r message) error {
 
 	// Send the message.
 	err := ch.sendRecv(c, t, r)
-	if err != nil {
-		// On shutdown, we'll see ENOENT. This is a normal situation, and
-		// we shouldn't generate a spurious warning message in that case.
-		log.Debugf("error calling sendRecvChannel: %v", err)
-	}
-	c.channelsWg.Done()
 
-	// Return the channel.
-	//
-	// Note that we check the channel from the inuse slice here. This
-	// prevents a race where Close is called, which clears inuse, and
-	// means that we will not actually return the closed channel.
+	// Release the channel.
 	c.channelsMu.Lock()
-	if c.inuse[idx] != nil {
-		c.channels = append(c.channels, ch)
-		c.inuse[idx] = nil
+	ch.active = false
+	// If c.availableChannels is nil, c.watch() has fired and we should not
+	// mark this channel as available.
+	if c.availableChannels != nil {
+		c.availableChannels = append(c.availableChannels, ch)
 	}
 	c.channelsMu.Unlock()
+	c.channelsWg.Done()
 
 	return err
 }
@@ -510,44 +532,9 @@ func (c *Client) Version() uint32 {
 }
 
 // Close closes the underlying socket and channels.
-//
-// Because Close may be called asynchronously from watch, it must be
-// safe to call concurrently and multiple times.
-func (c *Client) Close() error {
-	c.channelsMu.Lock()
-	defer c.channelsMu.Unlock()
-
-	// Close all inactive channels.
-	for _, ch := range c.channels {
-		ch.Shutdown()
-		ch.Close()
-	}
-	// Close all active channels.
-	for _, ch := range c.inuse {
-		if ch != nil {
-			log.Debugf("shutting down active channel@%p...", ch)
-			ch.Shutdown()
-		}
-	}
-
-	// Wait for active users.
-	c.channelsWg.Wait()
-
-	// Close all previously active channels.
-	for i, ch := range c.inuse {
-		if ch != nil {
-			ch.Close()
-
-			// Clear the inuse entry here so that it will not be returned
-			// to the channel slice, which is cleared below. See the
-			// comment at the end of sendRecvChannel.
-			c.inuse[i] = nil
-		}
-	}
-	c.channels = nil // Prevent use again.
-
-	// Close the main socket. Note that this operation is safe to call
-	// multiple times, unlike the channel Close operations above, which
-	// we are careful to ensure aren't called twice.
-	return c.socket.Close()
+func (c *Client) Close() {
+	// unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+	// been called (by c.watch()).
+ c.socket.Shutdown() + c.closedWg.Wait() } diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 1d34181e0..28707c0ca 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -77,7 +77,7 @@ go_library( go_test( name = "client_test", - size = "small", + size = "medium", srcs = ["client_test.go"], embed = [":p9test"], deps = [ diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 9d74638bb..4d3271b37 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator { // Finish completes all checks and shuts down the server. func (h *Harness) Finish() { - h.clientSocket.Close() + h.clientSocket.Shutdown() h.wg.Wait() h.mockCtrl.Finish() } diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index aebb54959..7cdf4ecc3 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -60,6 +60,7 @@ type channel struct { // -- client only -- connected bool + active bool // -- server only -- client *fd.FD @@ -197,10 +198,18 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { return nil, &ErrBadResponse{Got: t, Want: r.Type()} } - // Is there a payload? Set to the latter portion. + // Is there a payload? Copy from the latter portion. if payloader, ok := r.(payloader); ok { fs := payloader.FixedSize() - payloader.SetPayload(ch.buf.data[fs:]) + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } ch.buf.data = ch.buf.data[:fs] } -- cgit v1.2.3
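
For orientation, here is a minimal client-side usage sketch of the API as it
stands after both patches above. It is an illustration, not part of either
patch: the fd number, message size, attach name, and requested version level
are assumptions, while unet.NewSocket, p9.NewClient, Client.Attach, and
Client.Close are existing gVisor APIs.

    package main

    import (
    	"log"

    	"gvisor.dev/gvisor/pkg/p9"
    	"gvisor.dev/gvisor/pkg/unet"
    )

    func main() {
    	// Assumption: fd 3 is a connected Unix domain socket to a 9P
    	// server, e.g. inherited from a parent process.
    	sock, err := unet.NewSocket(3)
    	if err != nil {
    		log.Fatalf("unet.NewSocket: %v", err)
    	}

    	// NewClient performs the Tversion exchange. Per the first patch, it
    	// also opens flipcall channels when the negotiated version supports
    	// them, installing sendRecvChannel as the transport and falling
    	// back to sendRecvLegacy otherwise. Per the second patch, ownership
    	// of sock passes to the Client on success. The ".Google.9"
    	// extension level is assumed here to be recent enough for channel
    	// support; see p9/version.go for the actual level.
    	client, err := p9.NewClient(sock, 1024*1024, "9P2000.L.Google.9")
    	if err != nil {
    		log.Fatalf("p9.NewClient: %v", err)
    	}
    	// Per the second patch, Close shuts down the socket to wake the
    	// watch goroutine, then blocks on closedWg until watch has shut
    	// down and closed all channels and closed the socket fd.
    	defer client.Close()

    	// Ordinary 9P traffic transparently uses a flipcall channel when
    	// one is available.
    	root, err := client.Attach("/")
    	if err != nil {
    		log.Fatalf("Attach: %v", err)
    	}
    	defer root.Close()
    }

The design point worth noting in the second patch: a released channel is
re-appended to availableChannels only while availableChannels is non-nil, so
once watch() observes a hangup and nils out the slice, released channels stay
out of circulation; and because watch() calls channelsWg.Wait() without
holding channelsMu, a channel still blocked in sendRecv can be marked
inactive, avoiding the deadlock described in the commit message.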