diff options
-rw-r--r-- | pkg/sentry/fs/gofer/socket.go | 20 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 67 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket_test.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/transport/unix/unix.go | 4 |
4 files changed, 85 insertions, 70 deletions
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 954000ef0..406756f5f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -79,26 +79,33 @@ func (e *endpoint) BidirectionalConnect(ce unix.ConnectingEndpoint, returnConnec // No lock ordering required as only the ConnectingEndpoint has a mutex. ce.Lock() - defer ce.Unlock() // Check connecting state. if ce.Connected() { + ce.Unlock() return tcpip.ErrAlreadyConnected } if ce.Listening() { + ce.Unlock() return tcpip.ErrInvalidEndpointState } hostFile, err := e.file.Connect(cf) if err != nil { + ce.Unlock() return tcpip.ErrConnectionRefused } - r, c, terr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path) + c, terr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path) if terr != nil { + ce.Unlock() return terr } - returnConnect(r, c) + + returnConnect(c, c) + ce.Unlock() + c.Init() + return nil } @@ -109,14 +116,15 @@ func (e *endpoint) UnidirectionalConnect() (unix.ConnectedEndpoint, *tcpip.Error return nil, tcpip.ErrConnectionRefused } - r, c, terr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path) + c, terr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path) if terr != nil { return nil, terr } + c.Init() // We don't need the receiver. - r.CloseRecv() - r.Release() + c.CloseRecv() + c.Release() return c, nil } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 467633052..f4689f51f 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -286,26 +286,33 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil } -// NewConnectedEndpoint creates a new unix.Receiver and unix.ConnectedEndpoint -// backed by a host FD that will pretend to be bound at a given sentry path. -func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (unix.Receiver, unix.ConnectedEndpoint, *tcpip.Error) { - if err := fdnotifier.AddFD(int32(file.FD()), queue); err != nil { - return nil, nil, translateError(err) - } - - e := &connectedEndpoint{path: path, queue: queue, file: file} +// NewConnectedEndpoint creates a new ConnectedEndpoint backed by +// a host FD that will pretend to be bound at a given sentry path. +// +// The caller is responsible for calling Init(). Additionaly, Release needs +// to be called twice because host.ConnectedEndpoint is both a +// unix.Receiver and unix.ConnectedEndpoint. +func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *tcpip.Error) { + e := &ConnectedEndpoint{path: path, queue: queue, file: file} // AtomicRefCounters start off with a single reference. We need two. e.ref.IncRef() - return e, e, nil + return e, nil +} + +// Init will do initialization required without holding other locks. +func (c *ConnectedEndpoint) Init() { + if err := fdnotifier.AddFD(int32(c.file.FD()), c.queue); err != nil { + panic(err) + } } -// connectedEndpoint is a host FD backed implementation of +// ConnectedEndpoint is a host FD backed implementation of // unix.ConnectedEndpoint and unix.Receiver. // -// connectedEndpoint does not support save/restore for now. -type connectedEndpoint struct { +// ConnectedEndpoint does not support save/restore for now. +type ConnectedEndpoint struct { queue *waiter.Queue path string @@ -328,7 +335,7 @@ type connectedEndpoint struct { } // Send implements unix.ConnectedEndpoint.Send. -func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { +func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.writeClosed { @@ -341,20 +348,20 @@ func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess } // SendNotify implements unix.ConnectedEndpoint.SendNotify. -func (c *connectedEndpoint) SendNotify() {} +func (c *ConnectedEndpoint) SendNotify() {} // CloseSend implements unix.ConnectedEndpoint.CloseSend. -func (c *connectedEndpoint) CloseSend() { +func (c *ConnectedEndpoint) CloseSend() { c.mu.Lock() c.writeClosed = true c.mu.Unlock() } // CloseNotify implements unix.ConnectedEndpoint.CloseNotify. -func (c *connectedEndpoint) CloseNotify() {} +func (c *ConnectedEndpoint) CloseNotify() {} // Writable implements unix.ConnectedEndpoint.Writable. -func (c *connectedEndpoint) Writable() bool { +func (c *ConnectedEndpoint) Writable() bool { c.mu.RLock() defer c.mu.RUnlock() if c.writeClosed { @@ -364,18 +371,18 @@ func (c *connectedEndpoint) Writable() bool { } // Passcred implements unix.ConnectedEndpoint.Passcred. -func (c *connectedEndpoint) Passcred() bool { +func (c *ConnectedEndpoint) Passcred() bool { // We don't support credential passing for host sockets. return false } // GetLocalAddress implements unix.ConnectedEndpoint.GetLocalAddress. -func (c *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil } // EventUpdate implements unix.ConnectedEndpoint.EventUpdate. -func (c *connectedEndpoint) EventUpdate() { +func (c *ConnectedEndpoint) EventUpdate() { c.mu.RLock() defer c.mu.RUnlock() if c.file.FD() != -1 { @@ -384,7 +391,7 @@ func (c *connectedEndpoint) EventUpdate() { } // Recv implements unix.Receiver.Recv. -func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { +func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.readClosed { @@ -397,24 +404,24 @@ func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p } // close releases all resources related to the endpoint. -func (c *connectedEndpoint) close() { +func (c *ConnectedEndpoint) close() { fdnotifier.RemoveFD(int32(c.file.FD())) c.file.Close() c.file = nil } // RecvNotify implements unix.Receiver.RecvNotify. -func (c *connectedEndpoint) RecvNotify() {} +func (c *ConnectedEndpoint) RecvNotify() {} // CloseRecv implements unix.Receiver.CloseRecv. -func (c *connectedEndpoint) CloseRecv() { +func (c *ConnectedEndpoint) CloseRecv() { c.mu.Lock() c.readClosed = true c.mu.Unlock() } // Readable implements unix.Receiver.Readable. -func (c *connectedEndpoint) Readable() bool { +func (c *ConnectedEndpoint) Readable() bool { c.mu.RLock() defer c.mu.RUnlock() if c.readClosed { @@ -424,21 +431,21 @@ func (c *connectedEndpoint) Readable() bool { } // SendQueuedSize implements unix.Receiver.SendQueuedSize. -func (c *connectedEndpoint) SendQueuedSize() int64 { +func (c *ConnectedEndpoint) SendQueuedSize() int64 { // SendQueuedSize isn't supported for host sockets because we don't allow the // sentry to call ioctl(2). return -1 } // RecvQueuedSize implements unix.Receiver.RecvQueuedSize. -func (c *connectedEndpoint) RecvQueuedSize() int64 { +func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // RecvQueuedSize isn't supported for host sockets because we don't allow the // sentry to call ioctl(2). return -1 } // SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize. -func (c *connectedEndpoint) SendMaxQueueSize() int64 { +func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) if err != nil { return -1 @@ -447,7 +454,7 @@ func (c *connectedEndpoint) SendMaxQueueSize() int64 { } // RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize. -func (c *connectedEndpoint) RecvMaxQueueSize() int64 { +func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF) if err != nil { return -1 @@ -456,7 +463,7 @@ func (c *connectedEndpoint) RecvMaxQueueSize() int64 { } // Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release. -func (c *connectedEndpoint) Release() { +func (c *ConnectedEndpoint) Release() { c.ref.DecRefWithDestructor(c.close) } diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 9b73c5173..8b752737d 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -31,11 +31,11 @@ import ( ) var ( - // Make sure that connectedEndpoint implements unix.ConnectedEndpoint. - _ = unix.ConnectedEndpoint(new(connectedEndpoint)) + // Make sure that ConnectedEndpoint implements unix.ConnectedEndpoint. + _ = unix.ConnectedEndpoint(new(ConnectedEndpoint)) - // Make sure that connectedEndpoint implements unix.Receiver. - _ = unix.Receiver(new(connectedEndpoint)) + // Make sure that ConnectedEndpoint implements unix.Receiver. + _ = unix.Receiver(new(ConnectedEndpoint)) ) func getFl(fd int) (uint32, error) { @@ -198,28 +198,28 @@ func TestListen(t *testing.T) { } func TestSend(t *testing.T) { - e := connectedEndpoint{writeClosed: true} + e := ConnectedEndpoint{writeClosed: true} if _, _, err := e.Send(nil, unix.ControlMessages{}, tcpip.FullAddress{}); err != tcpip.ErrClosedForSend { t.Errorf("Got %#v.Send() = %v, want = %v", e, err, tcpip.ErrClosedForSend) } } func TestRecv(t *testing.T) { - e := connectedEndpoint{readClosed: true} + e := ConnectedEndpoint{readClosed: true} if _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != tcpip.ErrClosedForReceive { t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, tcpip.ErrClosedForReceive) } } func TestPasscred(t *testing.T) { - e := connectedEndpoint{} + e := ConnectedEndpoint{} if got, want := e.Passcred(), false; got != want { t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) } } func TestGetLocalAddress(t *testing.T) { - e := connectedEndpoint{path: "foo"} + e := ConnectedEndpoint{path: "foo"} want := tcpip.FullAddress{Addr: tcpip.Address("foo")} if got, err := e.GetLocalAddress(); err != nil || got != want { t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) @@ -227,7 +227,7 @@ func TestGetLocalAddress(t *testing.T) { } func TestQueuedSize(t *testing.T) { - e := connectedEndpoint{} + e := ConnectedEndpoint{} tests := []struct { name string f func() int64 @@ -244,14 +244,14 @@ func TestQueuedSize(t *testing.T) { } func TestReadable(t *testing.T) { - e := connectedEndpoint{readClosed: true} + e := ConnectedEndpoint{readClosed: true} if got, want := e.Readable(), true; got != want { t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want) } } func TestWritable(t *testing.T) { - e := connectedEndpoint{writeClosed: true} + e := ConnectedEndpoint{writeClosed: true} if got, want := e.Writable(), true; got != want { t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want) } @@ -262,8 +262,8 @@ func TestRelease(t *testing.T) { if err != nil { t.Fatal("Creating socket:", err) } - c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - want := &connectedEndpoint{queue: c.queue} + c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} + want := &ConnectedEndpoint{queue: c.queue} want.ref.DecRef() fdnotifier.AddFD(int32(c.file.FD()), nil) c.Release() @@ -275,119 +275,119 @@ func TestRelease(t *testing.T) { func TestClose(t *testing.T) { type testCase struct { name string - cep *connectedEndpoint + cep *ConnectedEndpoint addFD bool f func() - want *connectedEndpoint + want *ConnectedEndpoint } var tests []testCase - // nil is the value used by connectedEndpoint to indicate a closed file. + // nil is the value used by ConnectedEndpoint to indicate a closed file. // Non-nil files are used to check if the file gets closed. f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} + c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} tests = append(tests, testCase{ name: "First CloseRecv", cep: c, addFD: false, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} tests = append(tests, testCase{ name: "Second CloseRecv", cep: c, addFD: false, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} tests = append(tests, testCase{ name: "First CloseSend", cep: c, addFD: false, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} tests = append(tests, testCase{ name: "Second CloseSend", cep: c, addFD: false, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} tests = append(tests, testCase{ name: "CloseSend then CloseRecv", cep: c, addFD: true, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} tests = append(tests, testCase{ name: "CloseRecv then CloseSend", cep: c, addFD: true, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} tests = append(tests, testCase{ name: "Full close then CloseRecv", cep: c, addFD: false, f: c.CloseRecv, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { t.Fatal("Creating socket:", err) } - c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} + c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} tests = append(tests, testCase{ name: "Full close then CloseSend", cep: c, addFD: false, f: c.CloseSend, - want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, + want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, }) for _, test := range tests { diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go index 5fe37eb71..72c21a432 100644 --- a/pkg/tcpip/transport/unix/unix.go +++ b/pkg/tcpip/transport/unix/unix.go @@ -677,8 +677,8 @@ type baseEndpoint struct { // EventRegister implements waiter.Waitable.EventRegister. func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { - e.Lock() e.Queue.EventRegister(we, mask) + e.Lock() if e.connected != nil { e.connected.EventUpdate() } @@ -687,8 +687,8 @@ func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { // EventUnregister implements waiter.Waitable.EventUnregister. func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { - e.Lock() e.Queue.EventUnregister(we) + e.Lock() if e.connected != nil { e.connected.EventUpdate() } |