summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/fs/host/socket.go62
-rw-r--r--pkg/sentry/fs/host/socket_test.go156
2 files changed, 28 insertions, 190 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 3ed137006..e4ec0f62c 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -15,6 +15,7 @@
package host
import (
+ "fmt"
"sync"
"syscall"
@@ -51,20 +52,6 @@ type ConnectedEndpoint struct {
// ref keeps track of references to a connectedEndpoint.
ref refs.AtomicRefCount
- // mu protects fd, readClosed and writeClosed.
- mu sync.RWMutex `state:"nosave"`
-
- // file is an *fd.FD containing the FD backing this endpoint. It must be
- // set to nil if it has been closed.
- file *fd.FD `state:"nosave"`
-
- // readClosed is true if the FD has read shutdown or if it has been closed.
- readClosed bool
-
- // writeClosed is true if the FD has write shutdown or if it has been
- // closed.
- writeClosed bool
-
// If srfd >= 0, it is the host FD that file was imported from.
srfd int `state:"wait"`
@@ -78,6 +65,13 @@ type ConnectedEndpoint struct {
// prevent lots of small messages from filling the real send buffer
// size on the host.
sndbuf int `state:"nosave"`
+
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // file is an *fd.FD containing the FD backing this endpoint. It must be
+ // set to nil if it has been closed.
+ file *fd.FD `state:"nosave"`
}
// init performs initialization required for creating new ConnectedEndpoints and
@@ -208,9 +202,6 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.writeClosed {
- return 0, false, syserr.ErrClosedForSend
- }
if !controlMessages.Empty() {
return 0, false, syserr.ErrInvalidEndpointState
@@ -244,8 +235,13 @@ func (c *ConnectedEndpoint) SendNotify() {}
// CloseSend implements transport.ConnectedEndpoint.CloseSend.
func (c *ConnectedEndpoint) CloseSend() {
c.mu.Lock()
- c.writeClosed = true
- c.mu.Unlock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
}
// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
@@ -255,9 +251,7 @@ func (c *ConnectedEndpoint) CloseNotify() {}
func (c *ConnectedEndpoint) Writable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.writeClosed {
- return true
- }
+
return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0
}
@@ -285,9 +279,6 @@ func (c *ConnectedEndpoint) EventUpdate() {
func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.readClosed {
- return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive
- }
var cm unet.ControlMessage
if numRights > 0 {
@@ -344,31 +335,34 @@ func (c *ConnectedEndpoint) RecvNotify() {}
// CloseRecv implements transport.Receiver.CloseRecv.
func (c *ConnectedEndpoint) CloseRecv() {
c.mu.Lock()
- c.readClosed = true
- c.mu.Unlock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
}
// Readable implements transport.Receiver.Readable.
func (c *ConnectedEndpoint) Readable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.readClosed {
- return true
- }
+
return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0
}
// SendQueuedSize implements transport.Receiver.SendQueuedSize.
func (c *ConnectedEndpoint) SendQueuedSize() int64 {
- // SendQueuedSize isn't supported for host sockets because we don't allow the
- // sentry to call ioctl(2).
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
return -1
}
// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
- // RecvQueuedSize isn't supported for host sockets because we don't allow the
- // sentry to call ioctl(2).
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
return -1
}
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index 06392a65a..bc3ce5627 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -198,20 +198,6 @@ func TestListen(t *testing.T) {
}
}
-func TestSend(t *testing.T) {
- e := ConnectedEndpoint{writeClosed: true}
- if _, _, err := e.Send(nil, transport.ControlMessages{}, tcpip.FullAddress{}); err != syserr.ErrClosedForSend {
- t.Errorf("Got %#v.Send() = %v, want = %v", e, err, syserr.ErrClosedForSend)
- }
-}
-
-func TestRecv(t *testing.T) {
- e := ConnectedEndpoint{readClosed: true}
- if _, _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive {
- t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, syserr.ErrClosedForReceive)
- }
-}
-
func TestPasscred(t *testing.T) {
e := ConnectedEndpoint{}
if got, want := e.Passcred(), false; got != want {
@@ -244,20 +230,6 @@ func TestQueuedSize(t *testing.T) {
}
}
-func TestReadable(t *testing.T) {
- e := ConnectedEndpoint{readClosed: true}
- if got, want := e.Readable(), true; got != want {
- t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want)
- }
-}
-
-func TestWritable(t *testing.T) {
- e := ConnectedEndpoint{writeClosed: true}
- if got, want := e.Writable(), true; got != want {
- t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want)
- }
-}
-
func TestRelease(t *testing.T) {
f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
@@ -272,131 +244,3 @@ func TestRelease(t *testing.T) {
t.Errorf("got = %#v, want = %#v", c, want)
}
}
-
-func TestClose(t *testing.T) {
- type testCase struct {
- name string
- cep *ConnectedEndpoint
- addFD bool
- f func()
- want *ConnectedEndpoint
- }
-
- var tests []testCase
-
- // nil is the value used by ConnectedEndpoint to indicate a closed file.
- // Non-nil files are used to check if the file gets closed.
-
- f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
- tests = append(tests, testCase{
- name: "First CloseRecv",
- cep: c,
- addFD: false,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
- tests = append(tests, testCase{
- name: "Second CloseRecv",
- cep: c,
- addFD: false,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
- tests = append(tests, testCase{
- name: "First CloseSend",
- cep: c,
- addFD: false,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
- tests = append(tests, testCase{
- name: "Second CloseSend",
- cep: c,
- addFD: false,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
- tests = append(tests, testCase{
- name: "CloseSend then CloseRecv",
- cep: c,
- addFD: true,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
- tests = append(tests, testCase{
- name: "CloseRecv then CloseSend",
- cep: c,
- addFD: true,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
- tests = append(tests, testCase{
- name: "Full close then CloseRecv",
- cep: c,
- addFD: false,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
- tests = append(tests, testCase{
- name: "Full close then CloseSend",
- cep: c,
- addFD: false,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- for _, test := range tests {
- if test.addFD {
- fdnotifier.AddFD(int32(test.cep.file.FD()), nil)
- }
- if test.f(); !reflect.DeepEqual(test.cep, test.want) {
- t.Errorf("%s: got = %#v, want = %#v", test.name, test.cep, test.want)
- }
- }
-}