summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/fs/gofer/socket.go20
-rw-r--r--pkg/sentry/fs/host/socket.go67
-rw-r--r--pkg/sentry/fs/host/socket_test.go64
-rw-r--r--pkg/tcpip/transport/unix/unix.go4
4 files changed, 85 insertions, 70 deletions
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index 954000ef0..406756f5f 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -79,26 +79,33 @@ func (e *endpoint) BidirectionalConnect(ce unix.ConnectingEndpoint, returnConnec
// No lock ordering required as only the ConnectingEndpoint has a mutex.
ce.Lock()
- defer ce.Unlock()
// Check connecting state.
if ce.Connected() {
+ ce.Unlock()
return tcpip.ErrAlreadyConnected
}
if ce.Listening() {
+ ce.Unlock()
return tcpip.ErrInvalidEndpointState
}
hostFile, err := e.file.Connect(cf)
if err != nil {
+ ce.Unlock()
return tcpip.ErrConnectionRefused
}
- r, c, terr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path)
+ c, terr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path)
if terr != nil {
+ ce.Unlock()
return terr
}
- returnConnect(r, c)
+
+ returnConnect(c, c)
+ ce.Unlock()
+ c.Init()
+
return nil
}
@@ -109,14 +116,15 @@ func (e *endpoint) UnidirectionalConnect() (unix.ConnectedEndpoint, *tcpip.Error
return nil, tcpip.ErrConnectionRefused
}
- r, c, terr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path)
+ c, terr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path)
if terr != nil {
return nil, terr
}
+ c.Init()
// We don't need the receiver.
- r.CloseRecv()
- r.Release()
+ c.CloseRecv()
+ c.Release()
return c, nil
}
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 467633052..f4689f51f 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -286,26 +286,33 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil
}
-// NewConnectedEndpoint creates a new unix.Receiver and unix.ConnectedEndpoint
-// backed by a host FD that will pretend to be bound at a given sentry path.
-func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (unix.Receiver, unix.ConnectedEndpoint, *tcpip.Error) {
- if err := fdnotifier.AddFD(int32(file.FD()), queue); err != nil {
- return nil, nil, translateError(err)
- }
-
- e := &connectedEndpoint{path: path, queue: queue, file: file}
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by
+// a host FD that will pretend to be bound at a given sentry path.
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs
+// to be called twice because host.ConnectedEndpoint is both a
+// unix.Receiver and unix.ConnectedEndpoint.
+func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *tcpip.Error) {
+ e := &ConnectedEndpoint{path: path, queue: queue, file: file}
// AtomicRefCounters start off with a single reference. We need two.
e.ref.IncRef()
- return e, e, nil
+ return e, nil
+}
+
+// Init will do initialization required without holding other locks.
+func (c *ConnectedEndpoint) Init() {
+ if err := fdnotifier.AddFD(int32(c.file.FD()), c.queue); err != nil {
+ panic(err)
+ }
}
-// connectedEndpoint is a host FD backed implementation of
+// ConnectedEndpoint is a host FD backed implementation of
// unix.ConnectedEndpoint and unix.Receiver.
//
-// connectedEndpoint does not support save/restore for now.
-type connectedEndpoint struct {
+// ConnectedEndpoint does not support save/restore for now.
+type ConnectedEndpoint struct {
queue *waiter.Queue
path string
@@ -328,7 +335,7 @@ type connectedEndpoint struct {
}
// Send implements unix.ConnectedEndpoint.Send.
-func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) {
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.writeClosed {
@@ -341,20 +348,20 @@ func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess
}
// SendNotify implements unix.ConnectedEndpoint.SendNotify.
-func (c *connectedEndpoint) SendNotify() {}
+func (c *ConnectedEndpoint) SendNotify() {}
// CloseSend implements unix.ConnectedEndpoint.CloseSend.
-func (c *connectedEndpoint) CloseSend() {
+func (c *ConnectedEndpoint) CloseSend() {
c.mu.Lock()
c.writeClosed = true
c.mu.Unlock()
}
// CloseNotify implements unix.ConnectedEndpoint.CloseNotify.
-func (c *connectedEndpoint) CloseNotify() {}
+func (c *ConnectedEndpoint) CloseNotify() {}
// Writable implements unix.ConnectedEndpoint.Writable.
-func (c *connectedEndpoint) Writable() bool {
+func (c *ConnectedEndpoint) Writable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
if c.writeClosed {
@@ -364,18 +371,18 @@ func (c *connectedEndpoint) Writable() bool {
}
// Passcred implements unix.ConnectedEndpoint.Passcred.
-func (c *connectedEndpoint) Passcred() bool {
+func (c *ConnectedEndpoint) Passcred() bool {
// We don't support credential passing for host sockets.
return false
}
// GetLocalAddress implements unix.ConnectedEndpoint.GetLocalAddress.
-func (c *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil
}
// EventUpdate implements unix.ConnectedEndpoint.EventUpdate.
-func (c *connectedEndpoint) EventUpdate() {
+func (c *ConnectedEndpoint) EventUpdate() {
c.mu.RLock()
defer c.mu.RUnlock()
if c.file.FD() != -1 {
@@ -384,7 +391,7 @@ func (c *connectedEndpoint) EventUpdate() {
}
// Recv implements unix.Receiver.Recv.
-func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.readClosed {
@@ -397,24 +404,24 @@ func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
}
// close releases all resources related to the endpoint.
-func (c *connectedEndpoint) close() {
+func (c *ConnectedEndpoint) close() {
fdnotifier.RemoveFD(int32(c.file.FD()))
c.file.Close()
c.file = nil
}
// RecvNotify implements unix.Receiver.RecvNotify.
-func (c *connectedEndpoint) RecvNotify() {}
+func (c *ConnectedEndpoint) RecvNotify() {}
// CloseRecv implements unix.Receiver.CloseRecv.
-func (c *connectedEndpoint) CloseRecv() {
+func (c *ConnectedEndpoint) CloseRecv() {
c.mu.Lock()
c.readClosed = true
c.mu.Unlock()
}
// Readable implements unix.Receiver.Readable.
-func (c *connectedEndpoint) Readable() bool {
+func (c *ConnectedEndpoint) Readable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
if c.readClosed {
@@ -424,21 +431,21 @@ func (c *connectedEndpoint) Readable() bool {
}
// SendQueuedSize implements unix.Receiver.SendQueuedSize.
-func (c *connectedEndpoint) SendQueuedSize() int64 {
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
// SendQueuedSize isn't supported for host sockets because we don't allow the
// sentry to call ioctl(2).
return -1
}
// RecvQueuedSize implements unix.Receiver.RecvQueuedSize.
-func (c *connectedEndpoint) RecvQueuedSize() int64 {
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// RecvQueuedSize isn't supported for host sockets because we don't allow the
// sentry to call ioctl(2).
return -1
}
// SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize.
-func (c *connectedEndpoint) SendMaxQueueSize() int64 {
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
if err != nil {
return -1
@@ -447,7 +454,7 @@ func (c *connectedEndpoint) SendMaxQueueSize() int64 {
}
// RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize.
-func (c *connectedEndpoint) RecvMaxQueueSize() int64 {
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
if err != nil {
return -1
@@ -456,7 +463,7 @@ func (c *connectedEndpoint) RecvMaxQueueSize() int64 {
}
// Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release.
-func (c *connectedEndpoint) Release() {
+func (c *ConnectedEndpoint) Release() {
c.ref.DecRefWithDestructor(c.close)
}
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index 9b73c5173..8b752737d 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -31,11 +31,11 @@ import (
)
var (
- // Make sure that connectedEndpoint implements unix.ConnectedEndpoint.
- _ = unix.ConnectedEndpoint(new(connectedEndpoint))
+ // Make sure that ConnectedEndpoint implements unix.ConnectedEndpoint.
+ _ = unix.ConnectedEndpoint(new(ConnectedEndpoint))
- // Make sure that connectedEndpoint implements unix.Receiver.
- _ = unix.Receiver(new(connectedEndpoint))
+ // Make sure that ConnectedEndpoint implements unix.Receiver.
+ _ = unix.Receiver(new(ConnectedEndpoint))
)
func getFl(fd int) (uint32, error) {
@@ -198,28 +198,28 @@ func TestListen(t *testing.T) {
}
func TestSend(t *testing.T) {
- e := connectedEndpoint{writeClosed: true}
+ e := ConnectedEndpoint{writeClosed: true}
if _, _, err := e.Send(nil, unix.ControlMessages{}, tcpip.FullAddress{}); err != tcpip.ErrClosedForSend {
t.Errorf("Got %#v.Send() = %v, want = %v", e, err, tcpip.ErrClosedForSend)
}
}
func TestRecv(t *testing.T) {
- e := connectedEndpoint{readClosed: true}
+ e := ConnectedEndpoint{readClosed: true}
if _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != tcpip.ErrClosedForReceive {
t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, tcpip.ErrClosedForReceive)
}
}
func TestPasscred(t *testing.T) {
- e := connectedEndpoint{}
+ e := ConnectedEndpoint{}
if got, want := e.Passcred(), false; got != want {
t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
}
}
func TestGetLocalAddress(t *testing.T) {
- e := connectedEndpoint{path: "foo"}
+ e := ConnectedEndpoint{path: "foo"}
want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
if got, err := e.GetLocalAddress(); err != nil || got != want {
t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
@@ -227,7 +227,7 @@ func TestGetLocalAddress(t *testing.T) {
}
func TestQueuedSize(t *testing.T) {
- e := connectedEndpoint{}
+ e := ConnectedEndpoint{}
tests := []struct {
name string
f func() int64
@@ -244,14 +244,14 @@ func TestQueuedSize(t *testing.T) {
}
func TestReadable(t *testing.T) {
- e := connectedEndpoint{readClosed: true}
+ e := ConnectedEndpoint{readClosed: true}
if got, want := e.Readable(), true; got != want {
t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want)
}
}
func TestWritable(t *testing.T) {
- e := connectedEndpoint{writeClosed: true}
+ e := ConnectedEndpoint{writeClosed: true}
if got, want := e.Writable(), true; got != want {
t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want)
}
@@ -262,8 +262,8 @@ func TestRelease(t *testing.T) {
if err != nil {
t.Fatal("Creating socket:", err)
}
- c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
- want := &connectedEndpoint{queue: c.queue}
+ c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ want := &ConnectedEndpoint{queue: c.queue}
want.ref.DecRef()
fdnotifier.AddFD(int32(c.file.FD()), nil)
c.Release()
@@ -275,119 +275,119 @@ func TestRelease(t *testing.T) {
func TestClose(t *testing.T) {
type testCase struct {
name string
- cep *connectedEndpoint
+ cep *ConnectedEndpoint
addFD bool
f func()
- want *connectedEndpoint
+ want *ConnectedEndpoint
}
var tests []testCase
- // nil is the value used by connectedEndpoint to indicate a closed file.
+ // nil is the value used by ConnectedEndpoint to indicate a closed file.
// Non-nil files are used to check if the file gets closed.
f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
tests = append(tests, testCase{
name: "First CloseRecv",
cep: c,
addFD: false,
f: c.CloseRecv,
- want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
tests = append(tests, testCase{
name: "Second CloseRecv",
cep: c,
addFD: false,
f: c.CloseRecv,
- want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
tests = append(tests, testCase{
name: "First CloseSend",
cep: c,
addFD: false,
f: c.CloseSend,
- want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
tests = append(tests, testCase{
name: "Second CloseSend",
cep: c,
addFD: false,
f: c.CloseSend,
- want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
tests = append(tests, testCase{
name: "CloseSend then CloseRecv",
cep: c,
addFD: true,
f: c.CloseRecv,
- want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
tests = append(tests, testCase{
name: "CloseRecv then CloseSend",
cep: c,
addFD: true,
f: c.CloseSend,
- want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
tests = append(tests, testCase{
name: "Full close then CloseRecv",
cep: c,
addFD: false,
f: c.CloseRecv,
- want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
})
f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
t.Fatal("Creating socket:", err)
}
- c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
+ c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
tests = append(tests, testCase{
name: "Full close then CloseSend",
cep: c,
addFD: false,
f: c.CloseSend,
- want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
})
for _, test := range tests {
diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go
index 5fe37eb71..72c21a432 100644
--- a/pkg/tcpip/transport/unix/unix.go
+++ b/pkg/tcpip/transport/unix/unix.go
@@ -677,8 +677,8 @@ type baseEndpoint struct {
// EventRegister implements waiter.Waitable.EventRegister.
func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
- e.Lock()
e.Queue.EventRegister(we, mask)
+ e.Lock()
if e.connected != nil {
e.connected.EventUpdate()
}
@@ -687,8 +687,8 @@ func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
// EventUnregister implements waiter.Waitable.EventUnregister.
func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
- e.Lock()
e.Queue.EventUnregister(we)
+ e.Lock()
if e.connected != nil {
e.connected.EventUpdate()
}