summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/p9')
-rw-r--r--pkg/p9/client.go171
-rw-r--r--pkg/p9/p9test/BUILD2
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go2
-rw-r--r--pkg/p9/transport_flipcall.go13
5 files changed, 187 insertions, 96 deletions
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 123f54e29..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -92,6 +92,10 @@ type Client struct {
// version 0 implies 9P2000.L.
version uint32
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
// sendRecv is the transport function.
//
// This is determined dynamically based on whether or not the server
@@ -104,17 +108,15 @@ type Client struct {
// channelsMu protects channels.
channelsMu sync.Mutex
- // channelsWg is a wait group for active clients.
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
channelsWg sync.WaitGroup
- // channels are the set of initialized IPCs channels.
+ // channels is the set of all initialized channels.
channels []*channel
- // inuse is set when the channels are actually in use.
- //
- // This is a fixed-size slice, and the entries will be nil when the
- // corresponding channel is available.
- inuse []*channel
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
// -- below corresponds to sendRecvLegacy --
@@ -135,7 +137,7 @@ type Client struct {
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -214,13 +216,6 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
if len(c.channels) >= 1 {
// At least one channel created.
c.sendRecv = c.sendRecvChannel
-
- // If we are using channels for communication, then we must poll
- // for shutdown events on the main socket. If the socket happens
- // to shutdown, then we will close the channels as well. This is
- // necessary because channels can hang forever if the server dies
- // while we're expecting a response.
- go c.watch(socket) // S/R-SAFE: not relevant.
} else {
// Channel setup failed; fallback.
c.sendRecv = c.sendRecvLegacy
@@ -230,13 +225,20 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.sendRecv = c.sendRecvLegacy
}
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
-// watch watches the given socket and calls Close on hang up events.
+// watch watches the given socket and releases resources on hangup events.
//
// This is intended to be called as a goroutine.
func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
events := []unix.PollFd{
unix.PollFd{
Fd: int32(socket.FD()),
@@ -244,19 +246,49 @@ func (c *Client) watch(socket *unet.Socket) {
},
}
+ // Wait for a shutdown event.
for {
- // Wait for a shutdown event.
n, err := unix.Ppoll(events, nil, nil)
- if n == 0 || err == syscall.EAGAIN {
+ if err == syscall.EINTR || err == syscall.EAGAIN {
continue
}
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
break
}
- // Close everything down: this will kick all active clients off any
- // pending requests. Note that Close must be safe to call concurrently,
- // and multiple times (see Close below).
- c.Close()
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
}
// openChannel attempts to open a client channel.
@@ -315,7 +347,7 @@ func (c *Client) openChannel(id int) error {
c.channelsMu.Lock()
defer c.channelsMu.Unlock()
c.channels = append(c.channels, res)
- c.inuse = append(c.inuse, nil)
+ c.availableChannels = append(c.availableChannels, res)
return nil
}
@@ -449,23 +481,16 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
// sendRecvChannel uses channels to send a message.
func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
c.channelsMu.Lock()
- if len(c.channels) == 0 {
- // No channel available.
+ if len(c.availableChannels) == 0 {
c.channelsMu.Unlock()
return c.sendRecvLegacy(t, r)
}
-
- // Find the last used channel.
- //
- // Note that we must add one to the wait group while holding the
- // channel mutex, in order for the Wait operation to be race-free
- // below. The Wait operation shuts down all in use channels and
- // waits for them to return, but must do so holding the mutex.
- idx := len(c.channels) - 1
- ch := c.channels[idx]
- c.channels = c.channels[:idx]
- c.inuse[idx] = ch
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
c.channelsWg.Add(1)
c.channelsMu.Unlock()
@@ -473,8 +498,12 @@ func (c *Client) sendRecvChannel(t message, r message) error {
if !ch.connected {
ch.connected = true
if err := ch.data.Connect(); err != nil {
- // The channel is unusable, so don't return it.
- ch.Close()
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
c.channelsWg.Done()
return err
}
@@ -482,24 +511,17 @@ func (c *Client) sendRecvChannel(t message, r message) error {
// Send the message.
err := ch.sendRecv(c, t, r)
- if err != nil {
- // On shutdown, we'll see ENOENT. This is a normal situation, and
- // we shouldn't generate a spurious warning message in that case.
- log.Debugf("error calling sendRecvChannel: %v", err)
- }
- c.channelsWg.Done()
- // Return the channel.
- //
- // Note that we check the channel from the inuse slice here. This
- // prevents a race where Close is called, which clears inuse, and
- // means that we will not actually return the closed channel.
+ // Release the channel.
c.channelsMu.Lock()
- if c.inuse[idx] != nil {
- c.channels = append(c.channels, ch)
- c.inuse[idx] = nil
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
}
c.channelsMu.Unlock()
+ c.channelsWg.Done()
return err
}
@@ -510,44 +532,9 @@ func (c *Client) Version() uint32 {
}
// Close closes the underlying socket and channels.
-//
-// Because Close may be called asynchronously from watch, it must be
-// safe to call concurrently and multiple times.
-func (c *Client) Close() error {
- c.channelsMu.Lock()
- defer c.channelsMu.Unlock()
-
- // Close all inactive channels.
- for _, ch := range c.channels {
- ch.Shutdown()
- ch.Close()
- }
- // Close all active channels.
- for _, ch := range c.inuse {
- if ch != nil {
- log.Debugf("shutting down active channel@%p...", ch)
- ch.Shutdown()
- }
- }
-
- // Wait for active users.
- c.channelsWg.Wait()
-
- // Close all previously active channels.
- for i, ch := range c.inuse {
- if ch != nil {
- ch.Close()
-
- // Clear the inuse entry here so that it will not be returned
- // to the channel slice, which is cleared below. See the
- // comment at the end of sendRecvChannel.
- c.inuse[i] = nil
- }
- }
- c.channels = nil // Prevent use again.
-
- // Close the main socket. Note that operation is safe to be called
- // multiple times, unlikely the channel Close operations above, which
- // we are careful to ensure aren't called twice.
- return c.socket.Close()
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 1d34181e0..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 9d74638bb..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index aebb54959..7cdf4ecc3 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -60,6 +60,7 @@ type channel struct {
// -- client only --
connected bool
+ active bool
// -- server only --
client *fd.FD
@@ -197,10 +198,18 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
return nil, &ErrBadResponse{Got: t, Want: r.Type()}
}
- // Is there a payload? Set to the latter portion.
+ // Is there a payload? Copy from the latter portion.
if payloader, ok := r.(payloader); ok {
fs := payloader.FixedSize()
- payloader.SetPayload(ch.buf.data[fs:])
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
ch.buf.data = ch.buf.data[:fs]
}