summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/p9')
-rw-r--r--pkg/p9/client.go171
-rw-r--r--pkg/p9/p9test/BUILD2
-rw-r--r--pkg/p9/p9test/p9test.go2
-rw-r--r--pkg/p9/transport_flipcall.go13
4 files changed, 92 insertions, 96 deletions
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 123f54e29..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -92,6 +92,10 @@ type Client struct {
// version 0 implies 9P2000.L.
version uint32
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
// sendRecv is the transport function.
//
// This is determined dynamically based on whether or not the server
@@ -104,17 +108,15 @@ type Client struct {
// channelsMu protects channels.
channelsMu sync.Mutex
- // channelsWg is a wait group for active clients.
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
channelsWg sync.WaitGroup
- // channels are the set of initialized IPCs channels.
+ // channels is the set of all initialized channels.
channels []*channel
- // inuse is set when the channels are actually in use.
- //
- // This is a fixed-size slice, and the entries will be nil when the
- // corresponding channel is available.
- inuse []*channel
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
// -- below corresponds to sendRecvLegacy --
@@ -135,7 +137,7 @@ type Client struct {
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -214,13 +216,6 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
if len(c.channels) >= 1 {
// At least one channel created.
c.sendRecv = c.sendRecvChannel
-
- // If we are using channels for communication, then we must poll
- // for shutdown events on the main socket. If the socket happens
- // to shutdown, then we will close the channels as well. This is
- // necessary because channels can hang forever if the server dies
- // while we're expecting a response.
- go c.watch(socket) // S/R-SAFE: not relevant.
} else {
// Channel setup failed; fallback.
c.sendRecv = c.sendRecvLegacy
@@ -230,13 +225,20 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.sendRecv = c.sendRecvLegacy
}
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
-// watch watches the given socket and calls Close on hang up events.
+// watch watches the given socket and releases resources on hangup events.
//
// This is intended to be called as a goroutine.
func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
events := []unix.PollFd{
unix.PollFd{
Fd: int32(socket.FD()),
@@ -244,19 +246,49 @@ func (c *Client) watch(socket *unet.Socket) {
},
}
+ // Wait for a shutdown event.
for {
- // Wait for a shutdown event.
n, err := unix.Ppoll(events, nil, nil)
- if n == 0 || err == syscall.EAGAIN {
+ if err == syscall.EINTR || err == syscall.EAGAIN {
continue
}
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
break
}
- // Close everything down: this will kick all active clients off any
- // pending requests. Note that Close must be safe to call concurrently,
- // and multiple times (see Close below).
- c.Close()
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
}
// openChannel attempts to open a client channel.
@@ -315,7 +347,7 @@ func (c *Client) openChannel(id int) error {
c.channelsMu.Lock()
defer c.channelsMu.Unlock()
c.channels = append(c.channels, res)
- c.inuse = append(c.inuse, nil)
+ c.availableChannels = append(c.availableChannels, res)
return nil
}
@@ -449,23 +481,16 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
// sendRecvChannel uses channels to send a message.
func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
c.channelsMu.Lock()
- if len(c.channels) == 0 {
- // No channel available.
+ if len(c.availableChannels) == 0 {
c.channelsMu.Unlock()
return c.sendRecvLegacy(t, r)
}
-
- // Find the last used channel.
- //
- // Note that we must add one to the wait group while holding the
- // channel mutex, in order for the Wait operation to be race-free
- // below. The Wait operation shuts down all in use channels and
- // waits for them to return, but must do so holding the mutex.
- idx := len(c.channels) - 1
- ch := c.channels[idx]
- c.channels = c.channels[:idx]
- c.inuse[idx] = ch
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
c.channelsWg.Add(1)
c.channelsMu.Unlock()
@@ -473,8 +498,12 @@ func (c *Client) sendRecvChannel(t message, r message) error {
if !ch.connected {
ch.connected = true
if err := ch.data.Connect(); err != nil {
- // The channel is unusable, so don't return it.
- ch.Close()
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
c.channelsWg.Done()
return err
}
@@ -482,24 +511,17 @@ func (c *Client) sendRecvChannel(t message, r message) error {
// Send the message.
err := ch.sendRecv(c, t, r)
- if err != nil {
- // On shutdown, we'll see ENOENT. This is a normal situation, and
- // we shouldn't generate a spurious warning message in that case.
- log.Debugf("error calling sendRecvChannel: %v", err)
- }
- c.channelsWg.Done()
- // Return the channel.
- //
- // Note that we check the channel from the inuse slice here. This
- // prevents a race where Close is called, which clears inuse, and
- // means that we will not actually return the closed channel.
+ // Release the channel.
c.channelsMu.Lock()
- if c.inuse[idx] != nil {
- c.channels = append(c.channels, ch)
- c.inuse[idx] = nil
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
}
c.channelsMu.Unlock()
+ c.channelsWg.Done()
return err
}
@@ -510,44 +532,9 @@ func (c *Client) Version() uint32 {
}
// Close closes the underlying socket and channels.
-//
-// Because Close may be called asynchronously from watch, it must be
-// safe to call concurrently and multiple times.
-func (c *Client) Close() error {
- c.channelsMu.Lock()
- defer c.channelsMu.Unlock()
-
- // Close all inactive channels.
- for _, ch := range c.channels {
- ch.Shutdown()
- ch.Close()
- }
- // Close all active channels.
- for _, ch := range c.inuse {
- if ch != nil {
- log.Debugf("shutting down active channel@%p...", ch)
- ch.Shutdown()
- }
- }
-
- // Wait for active users.
- c.channelsWg.Wait()
-
- // Close all previously active channels.
- for i, ch := range c.inuse {
- if ch != nil {
- ch.Close()
-
- // Clear the inuse entry here so that it will not be returned
- // to the channel slice, which is cleared below. See the
- // comment at the end of sendRecvChannel.
- c.inuse[i] = nil
- }
- }
- c.channels = nil // Prevent use again.
-
- // Close the main socket. Note that operation is safe to be called
- // multiple times, unlikely the channel Close operations above, which
- // we are careful to ensure aren't called twice.
- return c.socket.Close()
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 1d34181e0..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 9d74638bb..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index aebb54959..7cdf4ecc3 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -60,6 +60,7 @@ type channel struct {
// -- client only --
connected bool
+ active bool
// -- server only --
client *fd.FD
@@ -197,10 +198,18 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
return nil, &ErrBadResponse{Got: t, Want: r.Type()}
}
- // Is there a payload? Set to the latter portion.
+ // Is there a payload? Copy from the latter portion.
if payloader, ok := r.(payloader); ok {
fs := payloader.FixedSize()
- payloader.SetPayload(ch.buf.data[fs:])
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
ch.buf.data = ch.buf.data[:fs]
}