diff options
Diffstat (limited to 'pkg/p9')
-rw-r--r-- | pkg/p9/client.go | 171 | ||||
-rw-r--r-- | pkg/p9/p9test/BUILD | 2 | ||||
-rw-r--r-- | pkg/p9/p9test/p9test.go | 2 | ||||
-rw-r--r-- | pkg/p9/transport_flipcall.go | 13 |
4 files changed, 92 insertions, 96 deletions
diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 123f54e29..2412aa5e1 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -92,6 +92,10 @@ type Client struct { // version 0 implies 9P2000.L. version uint32 + // closedWg is marked as done when the Client.watch() goroutine, which is + // responsible for closing channels and the socket fd, returns. + closedWg sync.WaitGroup + // sendRecv is the transport function. // // This is determined dynamically based on whether or not the server @@ -104,17 +108,15 @@ type Client struct { // channelsMu protects channels. channelsMu sync.Mutex - // channelsWg is a wait group for active clients. + // channelsWg counts the number of channels for which channel.active == + // true. channelsWg sync.WaitGroup - // channels are the set of initialized IPCs channels. + // channels is the set of all initialized channels. channels []*channel - // inuse is set when the channels are actually in use. - // - // This is a fixed-size slice, and the entries will be nil when the - // corresponding channel is available. - inuse []*channel + // availableChannels is a FIFO of inactive channels. + availableChannels []*channel // -- below corresponds to sendRecvLegacy -- @@ -135,7 +137,7 @@ type Client struct { // NewClient creates a new client. It performs a Tversion exchange with // the server to assert that messageSize is ok to use. // -// You should not use the same socket for multiple clients. +// If NewClient succeeds, ownership of socket is transferred to the new Client. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. if messageSize <= msgRegistry.largestFixedSize { @@ -214,13 +216,6 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client if len(c.channels) >= 1 { // At least one channel created. c.sendRecv = c.sendRecvChannel - - // If we are using channels for communication, then we must poll - // for shutdown events on the main socket. If the socket happens - // to shutdown, then we will close the channels as well. This is - // necessary because channels can hang forever if the server dies - // while we're expecting a response. - go c.watch(socket) // S/R-SAFE: not relevant. } else { // Channel setup failed; fallback. c.sendRecv = c.sendRecvLegacy @@ -230,13 +225,20 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.sendRecv = c.sendRecvLegacy } + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + return c, nil } -// watch watches the given socket and calls Close on hang up events. +// watch watches the given socket and releases resources on hangup events. // // This is intended to be called as a goroutine. func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + events := []unix.PollFd{ unix.PollFd{ Fd: int32(socket.FD()), @@ -244,19 +246,49 @@ func (c *Client) watch(socket *unet.Socket) { }, } + // Wait for a shutdown event. for { - // Wait for a shutdown event. n, err := unix.Ppoll(events, nil, nil) - if n == 0 || err == syscall.EAGAIN { + if err == syscall.EINTR || err == syscall.EAGAIN { continue } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } break } - // Close everything down: this will kick all active clients off any - // pending requests. Note that Close must be safe to call concurrently, - // and multiple times (see Close below). - c.Close() + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() } // openChannel attempts to open a client channel. @@ -315,7 +347,7 @@ func (c *Client) openChannel(id int) error { c.channelsMu.Lock() defer c.channelsMu.Unlock() c.channels = append(c.channels, res) - c.inuse = append(c.inuse, nil) + c.availableChannels = append(c.availableChannels, res) return nil } @@ -449,23 +481,16 @@ func (c *Client) sendRecvLegacy(t message, r message) error { // sendRecvChannel uses channels to send a message. func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. c.channelsMu.Lock() - if len(c.channels) == 0 { - // No channel available. + if len(c.availableChannels) == 0 { c.channelsMu.Unlock() return c.sendRecvLegacy(t, r) } - - // Find the last used channel. - // - // Note that we must add one to the wait group while holding the - // channel mutex, in order for the Wait operation to be race-free - // below. The Wait operation shuts down all in use channels and - // waits for them to return, but must do so holding the mutex. - idx := len(c.channels) - 1 - ch := c.channels[idx] - c.channels = c.channels[:idx] - c.inuse[idx] = ch + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + ch.active = true c.channelsWg.Add(1) c.channelsMu.Unlock() @@ -473,8 +498,12 @@ func (c *Client) sendRecvChannel(t message, r message) error { if !ch.connected { ch.connected = true if err := ch.data.Connect(); err != nil { - // The channel is unusable, so don't return it. - ch.Close() + // The channel is unusable, so don't return it to + // c.availableChannels. However, we still have to mark it as + // inactive so c.watch() doesn't wait for it. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() c.channelsWg.Done() return err } @@ -482,24 +511,17 @@ func (c *Client) sendRecvChannel(t message, r message) error { // Send the message. err := ch.sendRecv(c, t, r) - if err != nil { - // On shutdown, we'll see ENOENT. This is a normal situation, and - // we shouldn't generate a spurious warning message in that case. - log.Debugf("error calling sendRecvChannel: %v", err) - } - c.channelsWg.Done() - // Return the channel. - // - // Note that we check the channel from the inuse slice here. This - // prevents a race where Close is called, which clears inuse, and - // means that we will not actually return the closed channel. + // Release the channel. c.channelsMu.Lock() - if c.inuse[idx] != nil { - c.channels = append(c.channels, ch) - c.inuse[idx] = nil + ch.active = false + // If c.availableChannels is nil, c.watch() has fired and we should not + // mark this channel as available. + if c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) } c.channelsMu.Unlock() + c.channelsWg.Done() return err } @@ -510,44 +532,9 @@ func (c *Client) Version() uint32 { } // Close closes the underlying socket and channels. -// -// Because Close may be called asynchronously from watch, it must be -// safe to call concurrently and multiple times. -func (c *Client) Close() error { - c.channelsMu.Lock() - defer c.channelsMu.Unlock() - - // Close all inactive channels. - for _, ch := range c.channels { - ch.Shutdown() - ch.Close() - } - // Close all active channels. - for _, ch := range c.inuse { - if ch != nil { - log.Debugf("shutting down active channel@%p...", ch) - ch.Shutdown() - } - } - - // Wait for active users. - c.channelsWg.Wait() - - // Close all previously active channels. - for i, ch := range c.inuse { - if ch != nil { - ch.Close() - - // Clear the inuse entry here so that it will not be returned - // to the channel slice, which is cleared below. See the - // comment at the end of sendRecvChannel. - c.inuse[i] = nil - } - } - c.channels = nil // Prevent use again. - - // Close the main socket. Note that operation is safe to be called - // multiple times, unlikely the channel Close operations above, which - // we are careful to ensure aren't called twice. - return c.socket.Close() +func (c *Client) Close() { + // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already + // been called (by c.watch()). + c.socket.Shutdown() + c.closedWg.Wait() } diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 1d34181e0..28707c0ca 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -77,7 +77,7 @@ go_library( go_test( name = "client_test", - size = "small", + size = "medium", srcs = ["client_test.go"], embed = [":p9test"], deps = [ diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 9d74638bb..4d3271b37 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator { // Finish completes all checks and shuts down the server. func (h *Harness) Finish() { - h.clientSocket.Close() + h.clientSocket.Shutdown() h.wg.Wait() h.mockCtrl.Finish() } diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index aebb54959..7cdf4ecc3 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -60,6 +60,7 @@ type channel struct { // -- client only -- connected bool + active bool // -- server only -- client *fd.FD @@ -197,10 +198,18 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { return nil, &ErrBadResponse{Got: t, Want: r.Type()} } - // Is there a payload? Set to the latter portion. + // Is there a payload? Copy from the latter portion. if payloader, ok := r.(payloader); ok { fs := payloader.FixedSize() - payloader.SetPayload(ch.buf.data[fs:]) + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } ch.buf.data = ch.buf.data[:fs] } |