From f295e26b8abe395eaf1d4bee9a792a79b34d156f Mon Sep 17 00:00:00 2001 From: Brian Geffon Date: Wed, 16 May 2018 13:06:23 -0700 Subject: Release mutex in BidirectionalConnect to avoid deadlock. When doing a BidirectionalConnect we don't need to continue holding the ConnectingEndpoint's mutex when creating the NewConnectedEndpoint as it was held during the Connect. Additionally, we're not holding the baseEndpoint mutex while Unregistering an event. PiperOrigin-RevId: 196875557 Change-Id: Ied4ceed89de883121c6cba81bc62aa3a8549b1e9 --- pkg/sentry/fs/gofer/socket.go | 20 ++++++++---- pkg/sentry/fs/host/socket.go | 67 +++++++++++++++++++++------------------ pkg/sentry/fs/host/socket_test.go | 64 ++++++++++++++++++------------------- pkg/tcpip/transport/unix/unix.go | 4 +-- 4 files changed, 85 insertions(+), 70 deletions(-) diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 954000ef0..406756f5f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -79,26 +79,33 @@ func (e *endpoint) BidirectionalConnect(ce unix.ConnectingEndpoint, returnConnec // No lock ordering required as only the ConnectingEndpoint has a mutex. ce.Lock() - defer ce.Unlock() // Check connecting state. 
if ce.Connected() { + ce.Unlock() return tcpip.ErrAlreadyConnected } if ce.Listening() { + ce.Unlock() return tcpip.ErrInvalidEndpointState } hostFile, err := e.file.Connect(cf) if err != nil { + ce.Unlock() return tcpip.ErrConnectionRefused } - r, c, terr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path) + c, terr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path) if terr != nil { + ce.Unlock() return terr } - returnConnect(r, c) + + returnConnect(c, c) + ce.Unlock() + c.Init() + return nil } @@ -109,14 +116,15 @@ func (e *endpoint) UnidirectionalConnect() (unix.ConnectedEndpoint, *tcpip.Error return nil, tcpip.ErrConnectionRefused } - r, c, terr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path) + c, terr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path) if terr != nil { return nil, terr } + c.Init() // We don't need the receiver. - r.CloseRecv() - r.Release() + c.CloseRecv() + c.Release() return c, nil } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 467633052..f4689f51f 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -286,26 +286,33 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil } -// NewConnectedEndpoint creates a new unix.Receiver and unix.ConnectedEndpoint -// backed by a host FD that will pretend to be bound at a given sentry path. -func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (unix.Receiver, unix.ConnectedEndpoint, *tcpip.Error) { - if err := fdnotifier.AddFD(int32(file.FD()), queue); err != nil { - return nil, nil, translateError(err) - } - - e := &connectedEndpoint{path: path, queue: queue, file: file} +// NewConnectedEndpoint creates a new ConnectedEndpoint backed by +// a host FD that will pretend to be bound at a given sentry path. +// +// The caller is responsible for calling Init(). 
Additionally, Release needs +// to be called twice because host.ConnectedEndpoint is both a +// unix.Receiver and unix.ConnectedEndpoint. +func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *tcpip.Error) { + e := &ConnectedEndpoint{path: path, queue: queue, file: file} // AtomicRefCounters start off with a single reference. We need two. e.ref.IncRef() - return e, e, nil + return e, nil +} + +// Init will do initialization required without holding other locks. +func (c *ConnectedEndpoint) Init() { + if err := fdnotifier.AddFD(int32(c.file.FD()), c.queue); err != nil { + panic(err) + } } -// connectedEndpoint is a host FD backed implementation of +// ConnectedEndpoint is a host FD backed implementation of // unix.ConnectedEndpoint and unix.Receiver. // -// connectedEndpoint does not support save/restore for now. -type connectedEndpoint struct { +// ConnectedEndpoint does not support save/restore for now. +type ConnectedEndpoint struct { queue *waiter.Queue path string @@ -328,7 +335,7 @@ type connectedEndpoint struct { } // Send implements unix.ConnectedEndpoint.Send. -func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { +func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.writeClosed { @@ -341,20 +348,20 @@ func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess } // SendNotify implements unix.ConnectedEndpoint.SendNotify. -func (c *connectedEndpoint) SendNotify() {} +func (c *ConnectedEndpoint) SendNotify() {} // CloseSend implements unix.ConnectedEndpoint.CloseSend. -func (c *connectedEndpoint) CloseSend() { +func (c *ConnectedEndpoint) CloseSend() { c.mu.Lock() c.writeClosed = true c.mu.Unlock() } // CloseNotify implements unix.ConnectedEndpoint.CloseNotify. 
-func (c *connectedEndpoint) CloseNotify() {} +func (c *ConnectedEndpoint) CloseNotify() {} // Writable implements unix.ConnectedEndpoint.Writable. -func (c *connectedEndpoint) Writable() bool { +func (c *ConnectedEndpoint) Writable() bool { c.mu.RLock() defer c.mu.RUnlock() if c.writeClosed { @@ -364,18 +371,18 @@ func (c *connectedEndpoint) Writable() bool { } // Passcred implements unix.ConnectedEndpoint.Passcred. -func (c *connectedEndpoint) Passcred() bool { +func (c *ConnectedEndpoint) Passcred() bool { // We don't support credential passing for host sockets. return false } // GetLocalAddress implements unix.ConnectedEndpoint.GetLocalAddress. -func (c *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil } // EventUpdate implements unix.ConnectedEndpoint.EventUpdate. -func (c *connectedEndpoint) EventUpdate() { +func (c *ConnectedEndpoint) EventUpdate() { c.mu.RLock() defer c.mu.RUnlock() if c.file.FD() != -1 { @@ -384,7 +391,7 @@ func (c *connectedEndpoint) EventUpdate() { } // Recv implements unix.Receiver.Recv. -func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { +func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.readClosed { @@ -397,24 +404,24 @@ func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p } // close releases all resources related to the endpoint. -func (c *connectedEndpoint) close() { +func (c *ConnectedEndpoint) close() { fdnotifier.RemoveFD(int32(c.file.FD())) c.file.Close() c.file = nil } // RecvNotify implements unix.Receiver.RecvNotify. 
-func (c *connectedEndpoint) RecvNotify() {} +func (c *ConnectedEndpoint) RecvNotify() {} // CloseRecv implements unix.Receiver.CloseRecv. -func (c *connectedEndpoint) CloseRecv() { +func (c *ConnectedEndpoint) CloseRecv() { c.mu.Lock() c.readClosed = true c.mu.Unlock() } // Readable implements unix.Receiver.Readable. -func (c *connectedEndpoint) Readable() bool { +func (c *ConnectedEndpoint) Readable() bool { c.mu.RLock() defer c.mu.RUnlock() if c.readClosed { @@ -424,21 +431,21 @@ func (c *connectedEndpoint) Readable() bool { } // SendQueuedSize implements unix.Receiver.SendQueuedSize. -func (c *connectedEndpoint) SendQueuedSize() int64 { +func (c *ConnectedEndpoint) SendQueuedSize() int64 { // SendQueuedSize isn't supported for host sockets because we don't allow the // sentry to call ioctl(2). return -1 } // RecvQueuedSize implements unix.Receiver.RecvQueuedSize. -func (c *connectedEndpoint) RecvQueuedSize() int64 { +func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // RecvQueuedSize isn't supported for host sockets because we don't allow the // sentry to call ioctl(2). return -1 } // SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize. -func (c *connectedEndpoint) SendMaxQueueSize() int64 { +func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) if err != nil { return -1 @@ -447,7 +454,7 @@ func (c *connectedEndpoint) SendMaxQueueSize() int64 { } // RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize. -func (c *connectedEndpoint) RecvMaxQueueSize() int64 { +func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF) if err != nil { return -1 @@ -456,7 +463,7 @@ func (c *connectedEndpoint) RecvMaxQueueSize() int64 { } // Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release. 
-func (c *connectedEndpoint) Release() { +func (c *ConnectedEndpoint) Release() { c.ref.DecRefWithDestructor(c.close) } diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 9b73c5173..8b752737d 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -31,11 +31,11 @@ import ( ) var ( - // Make sure that connectedEndpoint implements unix.ConnectedEndpoint. - _ = unix.ConnectedEndpoint(new(connectedEndpoint)) + // Make sure that ConnectedEndpoint implements unix.ConnectedEndpoint. + _ = unix.ConnectedEndpoint(new(ConnectedEndpoint)) - // Make sure that connectedEndpoint implements unix.Receiver. - _ = unix.Receiver(new(connectedEndpoint)) + // Make sure that ConnectedEndpoint implements unix.Receiver. + _ = unix.Receiver(new(ConnectedEndpoint)) ) func getFl(fd int) (uint32, error) { @@ -198,28 +198,28 @@ func TestListen(t *testing.T) { } func TestSend(t *testing.T) { - e := connectedEndpoint{writeClosed: true} + e := ConnectedEndpoint{writeClosed: true} if _, _, err := e.Send(nil, unix.ControlMessages{}, tcpip.FullAddress{}); err != tcpip.ErrClosedForSend { t.Errorf("Got %#v.Send() = %v, want = %v", e, err, tcpip.ErrClosedForSend) } } func TestRecv(t *testing.T) { - e := connectedEndpoint{readClosed: true} + e := ConnectedEndpoint{readClosed: true} if _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != tcpip.ErrClosedForReceive { t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, tcpip.ErrClosedForReceive) } } func TestPasscred(t *testing.T) { - e := connectedEndpoint{} + e := ConnectedEndpoint{} if got, want := e.Passcred(), false; got != want { t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) } } func TestGetLocalAddress(t *testing.T) { - e := connectedEndpoint{path: "foo"} + e := ConnectedEndpoint{path: "foo"} want := tcpip.FullAddress{Addr: tcpip.Address("foo")} if got, err := e.GetLocalAddress(); err != nil || got != want { t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want 
= %#v, %v", e, got, err, want, nil) @@ -227,7 +227,7 @@ func TestGetLocalAddress(t *testing.T) { } func TestQueuedSize(t *testing.T) { - e := connectedEndpoint{} + e := ConnectedEndpoint{} tests := []struct { name string f func() int64 @@ -244,14 +244,14 @@ func TestQueuedSize(t *testing.T) { } func TestReadable(t *testing.T) { - e := connectedEndpoint{readClosed: true} + e := ConnectedEndpoint{readClosed: true} if got, want := e.Readable(), true; got != want { t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want) } } func TestWritable(t *testing.T) { - e := connectedEndpoint{writeClosed: true} + e := ConnectedEndpoint{writeClosed: true} if got, want := e.Writable(), true; got != want { t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want) } @@ -262,8 +262,8 @@ func TestRelease(t *testing.T) { if err != nil { t.Fatal("Creating socket:", err) } - c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - want := &connectedEndpoint{queue: c.queue} + c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} + want := &ConnectedEndpoint{queue: c.queue} want.ref.DecRef() fdnotifier.AddFD(int32(c.file.FD()), nil) c.Release() @@ -275,119 +275,119 @@ func TestRelease(t *testing.T) { func TestClose(t *testing.T) { type testCase struct { name string - cep *connectedEndpoint + cep *ConnectedEndpoint addFD bool f func() - want *connectedEndpoint + want *ConnectedEndpoint } var tests []testCase - // nil is the value used by connectedEndpoint to indicate a closed file. + // nil is the value used by ConnectedEndpoint to indicate a closed file. // Non-nil files are used to check if the file gets closed. 
f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} + c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} tests = append(tests, testCase{ name: "First CloseRecv", cep: c, addFD: false, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} tests = append(tests, testCase{ name: "Second CloseRecv", cep: c, addFD: false, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} tests = append(tests, testCase{ name: "First CloseSend", cep: c, addFD: false, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} tests 
= append(tests, testCase{ name: "Second CloseSend", cep: c, addFD: false, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} tests = append(tests, testCase{ name: "CloseSend then CloseRecv", cep: c, addFD: true, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} tests = append(tests, testCase{ name: "CloseRecv then CloseSend", cep: c, addFD: true, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} tests = append(tests, testCase{ name: "Full close then CloseRecv", cep: c, addFD: false, f: c.CloseRecv, - want: &connectedEndpoint{queue: 
c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} tests = append(tests, testCase{ name: "Full close then CloseSend", cep: c, addFD: false, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) for _, test := range tests { diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go index 5fe37eb71..72c21a432 100644 --- a/pkg/tcpip/transport/unix/unix.go +++ b/pkg/tcpip/transport/unix/unix.go @@ -677,8 +677,8 @@ type baseEndpoint struct { // EventRegister implements waiter.Waitable.EventRegister. func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { - e.Lock() e.Queue.EventRegister(we, mask) + e.Lock() if e.connected != nil { e.connected.EventUpdate() } @@ -687,8 +687,8 @@ func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { // EventUnregister implements waiter.Waitable.EventUnregister. func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { - e.Lock() e.Queue.EventUnregister(we) + e.Lock() if e.connected != nil { e.connected.EventUpdate() } -- cgit v1.2.3