diff options
Diffstat (limited to 'pkg/flipcall')
-rw-r--r-- | pkg/flipcall/BUILD | 4 | ||||
-rw-r--r-- | pkg/flipcall/ctrl_futex.go | 32 | ||||
-rw-r--r-- | pkg/flipcall/flipcall.go | 72 | ||||
-rw-r--r-- | pkg/flipcall/flipcall_example_test.go | 4 | ||||
-rw-r--r-- | pkg/flipcall/flipcall_test.go | 204 | ||||
-rw-r--r-- | pkg/flipcall/flipcall_unsafe.go | 28 | ||||
-rw-r--r-- | pkg/flipcall/futex_linux.go | 23 |
7 files changed, 300 insertions, 67 deletions
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index bd1d614b6..5643d5f26 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -18,6 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", + "//third_party/gvsync", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 865b6f640..8390915a2 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error { *ep.dataLen() = w.Len() // Return control to the client. + raceBecomeInactive() if err := ep.futexSwitchToPeer(); err != nil { return err } @@ -121,7 +122,16 @@ func (ep *Endpoint) enterFutexWait() error { } func (ep *Endpoint) exitFutexWait() { - atomic.AddInt32(&ep.ctrl.state, -epsBlocked) + switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps { + case 0: + return + case epsShutdown: + // ep.ctrlShutdown() was called while we were blocked, so we are + // repsonsible for indicating connection shutdown. + ep.shutdownConn() + default: + panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked)) + } } func (ep *Endpoint) ctrlShutdown() { @@ -142,5 +152,25 @@ func (ep *Endpoint) ctrlShutdown() { break } } + } else { + // There is no blocked thread, so we are responsible for indicating + // connection shutdown. + ep.shutdownConn() + } +} + +func (ep *Endpoint) shutdownConn() { + switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs { + case ep.activeState: + if err := ep.futexWakeConnState(1); err != nil { + log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err) + } + case ep.inactiveState: + // The peer is currently active and will detect shutdown when it tries + // to update the connection state. + case csShutdown: + // The peer also called Endpoint.Shutdown(). + default: + log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs) } } diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 5c9212c33..386cee42c 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -42,11 +42,6 @@ type Endpoint struct { // dataCap is immutable. dataCap uint32 - // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the - // Endpoint has acknowledged shutdown initiated by the peer. shutdown is - // accessed using atomic memory operations. - shutdown uint32 - // activeState is csClientActive if this is a client Endpoint and // csServerActive if this is a server Endpoint. activeState uint32 @@ -55,9 +50,27 @@ type Endpoint struct { // csClientActive if this is a server Endpoint. inactiveState uint32 + // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the + // Endpoint has acknowledged shutdown initiated by the peer. shutdown is + // accessed using atomic memory operations. + shutdown uint32 + ctrl endpointControlImpl } +// EndpointSide indicates which side of a connection an Endpoint belongs to. +type EndpointSide int + +const ( + // ClientSide indicates that an Endpoint is a client (initially-active; + // first method call should be Connect). + ClientSide EndpointSide = iota + + // ServerSide indicates that an Endpoint is a server (initially-inactive; + // first method call should be RecvFirst.) + ServerSide +) + // Init must be called on zero-value Endpoints before first use. If it // succeeds, ep.Destroy() must be called once the Endpoint is no longer in use. // @@ -65,7 +78,17 @@ type Endpoint struct { // Endpoint. FD may differ between Endpoints if they are in different // processes, but must represent the same file. The packet window must // initially be filled with zero bytes. -func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error { +func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error { + switch side { + case ClientSide: + ep.activeState = csClientActive + ep.inactiveState = csServerActive + case ServerSide: + ep.activeState = csServerActive + ep.inactiveState = csClientActive + default: + return fmt.Errorf("invalid EndpointSide: %v", side) + } if pwd.Length < pageSize { return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize) } @@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err } ep.packet = m ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes) - // These will be overwritten by ep.Connect() for client Endpoints. - ep.activeState = csServerActive - ep.inactiveState = csClientActive if err := ep.ctrlInit(opts...); err != nil { ep.unmapPacket() return err @@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err // NewEndpoint is a convenience function that returns an initialized Endpoint // allocated on the heap. -func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { +func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { var ep Endpoint - if err := ep.Init(pwd, opts...); err != nil { + if err := ep.Init(side, pwd, opts...); err != nil { return nil, err } return &ep, nil @@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() { } // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), -// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not -// wait for concurrent calls to return. The effect of Shutdown on the peer -// Endpoint is unspecified. Successive calls to Shutdown have no effect. +// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer +// Endpoint, to unblock and return errors. It does not wait for concurrent +// calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -152,28 +172,31 @@ const ( // The client is, by definition, initially active, so this must be 0. csClientActive = 0 csServerActive = 1 + csShutdown = 2 ) -// Connect designates ep as a client Endpoint and blocks until the peer -// Endpoint has called Endpoint.RecvFirst(). +// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst(). // -// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and -// ep.SendLast() have never been called. +// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), +// ep.SendRecv(), and ep.SendLast() have never been called. func (ep *Endpoint) Connect() error { - ep.activeState = csClientActive - ep.inactiveState = csServerActive - return ep.ctrlConnect() + err := ep.ctrlConnect() + if err == nil { + raceBecomeActive() + } + return err } // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then // returns the datagram length specified by that call. // -// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never -// been called. +// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and +// ep.SendLast() have never been called. func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -200,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then // they can only shoot themselves in the foot. *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlRoundTrip(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -222,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) } *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlWakeLast(); err != nil { return err } diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index edb6a8bef..8d88b845d 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -38,12 +38,12 @@ func Example() { panic(err) } var clientEP Endpoint - if err := clientEP.Init(pwd); err != nil { + if err := clientEP.Init(ClientSide, pwd); err != nil { panic(err) } defer clientEP.Destroy() var serverEP Endpoint - if err := serverEP.Init(pwd); err != nil { + if err := serverEP.Init(ServerSide, pwd); err != nil { panic(err) } defer serverEP.Destroy() diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index da9d736ab..168a487ec 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi c.pwa.Destroy() tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) } - if err := c.clientEP.Init(pwd, clientOpts...); err != nil { + if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil { c.pwa.Destroy() tb.Fatalf("failed to create client Endpoint: %v", err) } - if err := c.serverEP.Init(pwd, serverOpts...); err != nil { + if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil { c.pwa.Destroy() c.clientEP.Destroy() tb.Fatalf("failed to create server Endpoint: %v", err) @@ -62,17 +62,30 @@ func (c *testConnection) destroy() { } func testSendRecv(t *testing.T, c *testConnection) { + // This shared variable is used to confirm that synchronization between + // flipcall endpoints is visible to the Go race detector. + state := 0 var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() t.Logf("server Endpoint waiting for packet 1") if _, err := c.serverEP.RecvFirst(); err != nil { - t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return + } + state++ + if state != 2 { + t.Errorf("shared state counter: got %d, wanted 2", state) } t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") if _, err := c.serverEP.SendRecv(0); err != nil { - t.Fatalf("server Endpoint.SendRecv() failed: %v", err) + t.Errorf("server Endpoint.SendRecv() failed: %v", err) + return + } + state++ + if state != 4 { + t.Errorf("shared state counter: got %d, wanted 4", state) } t.Logf("server Endpoint got packet 3") }() @@ -87,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } + state++ + if state != 1 { + t.Errorf("shared state counter: got %d, wanted 1", state) + } t.Logf("client Endpoint sending packet 1 and waiting for packet 2") if _, err := c.clientEP.SendRecv(0); err != nil { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } + state++ + if state != 3 { + t.Errorf("shared state counter: got %d, wanted 3", state) + } t.Logf("client Endpoint got packet 2, sending packet 3") if err := c.clientEP.SendLast(0); err != nil { t.Fatalf("client Endpoint.SendLast() failed: %v", err) @@ -105,7 +126,30 @@ func TestSendRecv(t *testing.T) { testSendRecv(t, c) } -func testShutdownConnect(t *testing.T, c *testConnection) { +func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + if err := c.clientEP.Connect(); err == nil { + t.Errorf("client Endpoint.Connect() succeeded unexpectedly") + } +} + +func TestShutdownBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, false) +} + +func TestShutdownBeforeConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, true) +} + +func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var clientRun sync.WaitGroup clientRun.Add(1) go func() { @@ -115,44 +159,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) { } }() time.Sleep(time.Second) // to allow c.clientEP.Connect() to block - c.clientEP.Shutdown() + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } clientRun.Wait() } -func TestShutdownConnect(t *testing.T) { +func TestShutdownDuringConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringConnect(t, c, false) +} + +func TestShutdownDuringConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringConnect(t, c, true) +} + +func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } +} + +func TestShutdownBeforeRecvFirstLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeRecvFirst(t, c, false) +} + +func TestShutdownBeforeRecvFirstRemote(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownConnect(t, c) + testShutdownBeforeRecvFirst(t, c, true) } -func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) { +func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() - _, err := c.serverEP.RecvFirst() - if err == nil { + if _, err := c.serverEP.RecvFirst(); err == nil { t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") } }() time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownRecvFirstBeforeConnect(t *testing.T) { +func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstBeforeConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownRecvFirstBeforeConnect(t, c) + testShutdownDuringRecvFirstBeforeConnect(t, c, true) } -func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { +func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() if _, err := c.serverEP.RecvFirst(); err == nil { - t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly") + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") } }() defer func() { @@ -164,23 +250,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownRecvFirstAfterConnect(t *testing.T) { +func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, true) +} + +func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + } + // At this point, the client must be blocked in c.clientEP.SendRecv(). + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + if _, err := c.clientEP.SendRecv(0); err == nil { + t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly") + } +} + +func TestShutdownDuringClientSendRecvLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringClientSendRecv(t, c, false) +} + +func TestShutdownDuringClientSendRecvRemote(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownRecvFirstAfterConnect(t, c) + testShutdownDuringClientSendRecv(t, c, true) } -func testShutdownSendRecv(t *testing.T, c *testConnection) { +func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() if _, err := c.serverEP.RecvFirst(); err != nil { - t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } if _, err := c.serverEP.SendRecv(0); err == nil { t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") @@ -199,14 +337,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } time.Sleep(time.Second) // to allow serverEP.SendRecv() to block - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownSendRecv(t *testing.T) { +func TestShutdownDuringServerSendRecvLocal(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownSendRecv(t, c) + testShutdownDuringServerSendRecv(t, c, false) +} + +func TestShutdownDuringServerSendRecvRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringServerSendRecv(t, c, true) } func benchmarkSendRecv(b *testing.B, c *testConnection) { @@ -218,15 +366,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) { return } if _, err := c.serverEP.RecvFirst(); err != nil { - b.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + b.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } for i := 1; i < b.N; i++ { if _, err := c.serverEP.SendRecv(0); err != nil { - b.Fatalf("server Endpoint.SendRecv() failed: %v", err) + b.Errorf("server Endpoint.SendRecv() failed: %v", err) + return } } if err := c.serverEP.SendLast(0); err != nil { - b.Fatalf("server Endpoint.SendLast() failed: %v", err) + b.Errorf("server Endpoint.SendLast() failed: %v", err) } }() defer func() { diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 7c8977893..a37952637 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -17,17 +17,19 @@ package flipcall import ( "reflect" "unsafe" + + "gvisor.dev/gvisor/third_party/gvsync" ) -// Packets consist of an 8-byte header followed by an arbitrarily-sized +// Packets consist of a 16-byte header followed by an arbitrarily-sized // datagram. The header consists of: // // - A 4-byte native-endian connection state. // // - A 4-byte native-endian datagram length in bytes. +// +// - 8 reserved bytes. const ( - sizeofUint32 = unsafe.Sizeof(uint32(0)) - // PacketHeaderBytes is the size of a flipcall packet header in bytes. The // maximum datagram size supported by a flipcall connection is equal to the // length of the packet window minus PacketHeaderBytes. @@ -35,7 +37,7 @@ const ( // PacketHeaderBytes is exported to support its use in constant // expressions. Non-constant expressions may prefer to use // PacketWindowLengthForDataCap(). - PacketHeaderBytes = 2 * sizeofUint32 + PacketHeaderBytes = 16 ) func (ep *Endpoint) connState() *uint32 { @@ -43,7 +45,7 @@ func (ep *Endpoint) connState() *uint32 { } func (ep *Endpoint) dataLen() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32)) + return (*uint32)((unsafe.Pointer)(ep.packet + 4)) } // Data returns the datagram part of ep's packet window as a byte slice. @@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte { bsReflect.Cap = int(ep.dataCap) return bs } + +// ioSync is a dummy variable used to indicate synchronization to the Go race +// detector. Compare syscall.ioSync. +var ioSync int64 + +func raceBecomeActive() { + if gvsync.RaceEnabled { + gvsync.RaceAcquire((unsafe.Pointer)(&ioSync)) + } +} + +func raceBecomeInactive() { + if gvsync.RaceEnabled { + gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + } +} diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index e7dd812b3..b127a2bbb 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon func (ep *Endpoint) futexSwitchToPeer() error { // Update connection state to indicate that the peer should be active. if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { - return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState())) + switch cs := atomic.LoadUint32(ep.connState()); cs { + case csShutdown: + return shutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) + } } // Wake the peer's Endpoint.futexSwitchFromPeer(). @@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error { case ep.activeState: return nil case ep.inactiveState: - // Continue to FUTEX_WAIT. + if ep.isShutdownLocally() { + return shutdownError{} + } + if err := ep.futexWaitConnState(ep.inactiveState); err != nil { + return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) + } + continue + case csShutdown: + return shutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } - if ep.isShutdownLocally() { - return shutdownError{} - } - if err := ep.futexWaitConnState(ep.inactiveState); err != nil { - return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) - } } } |