From 1c2ecbb1a03ffaa3bcdb2ee69c879da5e7076fa5 Mon Sep 17 00:00:00 2001
From: Dean Deng <deandeng@google.com>
Date: Mon, 27 Apr 2020 16:00:39 -0700
Subject: Import host sockets.

The FileDescription implementation for hostfs sockets uses the standard Unix
socket implementation (unix.SocketVFS2), but is also tied to a hostfs dentry.

Updates #1672, #1476

PiperOrigin-RevId: 308716426
---
 pkg/sentry/fs/host/socket.go            |   4 +
 pkg/sentry/fs/host/socket_iovec.go      |   4 +
 pkg/sentry/fs/host/socket_unsafe.go     |   4 +
 pkg/sentry/fsimpl/host/BUILD            |  10 +
 pkg/sentry/fsimpl/host/host.go          |  57 +++--
 pkg/sentry/fsimpl/host/socket.go        | 397 ++++++++++++++++++++++++++++++++
 pkg/sentry/fsimpl/host/socket_iovec.go  | 113 +++++++++
 pkg/sentry/fsimpl/host/socket_unsafe.go | 101 ++++++++
 pkg/sentry/fsimpl/sockfs/sockfs.go      |  20 +-
 pkg/sentry/socket/unix/BUILD            |   1 +
 pkg/sentry/socket/unix/unix_vfs2.go     |  42 ++--
 11 files changed, 698 insertions(+), 55 deletions(-)
 create mode 100644 pkg/sentry/fsimpl/host/socket.go
 create mode 100644 pkg/sentry/fsimpl/host/socket_iovec.go
 create mode 100644 pkg/sentry/fsimpl/host/socket_unsafe.go

diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 06fc2d80a..b6e94583e 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -37,6 +37,8 @@ import (
 	"gvisor.dev/gvisor/pkg/waiter"
 )
 
+// LINT.IfChange
+
 // maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
 //
 // N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
@@ -388,3 +390,5 @@ func (c *ConnectedEndpoint) Release() {
 
 // CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
 func (c *ConnectedEndpoint) CloseUnread() {}
+
+// LINT.ThenChange(../../fsimpl/host/socket.go)
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index af6955675..5c18dbd5e 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -21,6 +21,8 @@ import (
 	"gvisor.dev/gvisor/pkg/syserror"
 )
 
+// LINT.IfChange
+
 // maxIovs is the maximum number of iovecs to pass to the host.
 var maxIovs = linux.UIO_MAXIOV
 
@@ -111,3 +113,5 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
 
 	return total, iovecs, nil, err
 }
+
+// LINT.ThenChange(../../fsimpl/host/socket_iovec.go)
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index f3bbed7ea..5d4f312cf 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,6 +19,8 @@ import (
 	"unsafe"
 )
 
+// LINT.IfChange
+
 // fdReadVec receives from fd to bufs.
 //
 // If the total length of bufs is > maxlen, fdReadVec will do a partial read
@@ -99,3 +101,5 @@ func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int6
 
 	return int64(n), length, err
 }
+
+// LINT.ThenChange(../../fsimpl/host/socket_unsafe.go)
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 2dcb03a73..e1c56d89b 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -8,6 +8,9 @@ go_library(
         "control.go",
         "host.go",
         "ioctl_unsafe.go",
+        "socket.go",
+        "socket_iovec.go",
+        "socket_unsafe.go",
         "tty.go",
         "util.go",
         "util_unsafe.go",
@@ -16,6 +19,7 @@ go_library(
     deps = [
         "//pkg/abi/linux",
         "//pkg/context",
+        "//pkg/fdnotifier",
         "//pkg/log",
         "//pkg/refs",
         "//pkg/sentry/arch",
@@ -25,12 +29,18 @@ go_library(
         "//pkg/sentry/kernel/auth",
         "//pkg/sentry/memmap",
         "//pkg/sentry/socket/control",
+        "//pkg/sentry/socket/unix",
         "//pkg/sentry/socket/unix/transport",
         "//pkg/sentry/unimpl",
+        "//pkg/sentry/uniqueid",
         "//pkg/sentry/vfs",
         "//pkg/sync",
+        "//pkg/syserr",
         "//pkg/syserror",
+        "//pkg/tcpip",
+        "//pkg/unet",
         "//pkg/usermem",
+        "//pkg/waiter",
         "@org_golang_x_sys//unix:go_default_library",
     ],
 )
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 1e53b5c1b..05538ea42 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -17,7 +17,6 @@
 package host
 
 import (
-	"errors"
 	"fmt"
 	"math"
 	"syscall"
@@ -31,6 +30,7 @@ import (
 	"gvisor.dev/gvisor/pkg/sentry/hostfd"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
 	"gvisor.dev/gvisor/pkg/sentry/memmap"
+	unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
 	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/sync"
 	"gvisor.dev/gvisor/pkg/syserror"
@@ -156,7 +156,7 @@ type inode struct {
 
 // Note that these flags may become out of date, since they can be modified
 // on the host, e.g. with fcntl.
-func fileFlagsFromHostFD(fd int) (int, error) {
+func fileFlagsFromHostFD(fd int) (uint32, error) {
 	flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0)
 	if err != nil {
 		log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err)
@@ -164,7 +164,7 @@ func fileFlagsFromHostFD(fd int) (int, error) {
 	}
 	// TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags.
 	flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
-	return flags, nil
+	return uint32(flags), nil
 }
 
 // CheckPermissions implements kernfs.Inode.
@@ -361,6 +361,10 @@ func (i *inode) Destroy() {
 
 // Open implements kernfs.Inode.
 func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+	// Once created, we cannot re-open a socket fd through /proc/[pid]/fd/.
+	if i.Mode().FileType() == linux.S_IFSOCK {
+		return nil, syserror.ENXIO
+	}
 	return i.open(ctx, vfsd, rp.Mount())
 }
 
@@ -370,42 +374,45 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount) (*vfs.F
 		return nil, err
 	}
 	fileType := s.Mode & linux.FileTypeMask
+	flags, err := fileFlagsFromHostFD(i.hostFD)
+	if err != nil {
+		return nil, err
+	}
+
 	if fileType == syscall.S_IFSOCK {
 		if i.isTTY {
-			return nil, errors.New("cannot use host socket as TTY")
+			log.Warningf("cannot use host socket fd %d as TTY", i.hostFD)
+			return nil, syserror.ENOTTY
 		}
-		// TODO(gvisor.dev/issue/1672): support importing sockets.
-		return nil, errors.New("importing host sockets not supported")
+
+		ep, err := newEndpoint(ctx, i.hostFD)
+		if err != nil {
+			return nil, err
+		}
+		// Currently, we only allow Unix sockets to be imported.
+		return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d)
 	}
 
 	// TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that
 	// we don't allow importing arbitrary file types without proper support.
-	var (
-		vfsfd  *vfs.FileDescription
-		fdImpl vfs.FileDescriptionImpl
-	)
 	if i.isTTY {
 		fd := &ttyFD{
 			fileDescription: fileDescription{inode: i},
 			termios:         linux.DefaultSlaveTermios,
 		}
-		vfsfd = &fd.vfsfd
-		fdImpl = fd
-	} else {
-		// For simplicity, set offset to 0. Technically, we should
-		// only set to 0 on files that are not seekable (sockets, pipes, etc.),
-		// and use the offset from the host fd otherwise.
-		fd := &fileDescription{inode: i}
-		vfsfd = &fd.vfsfd
-		fdImpl = fd
-	}
-
-	flags, err := fileFlagsFromHostFD(i.hostFD)
-	if err != nil {
-		return nil, err
+		vfsfd := &fd.vfsfd
+		if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+			return nil, err
+		}
+		return vfsfd, nil
 	}
 
-	if err := vfsfd.Init(fdImpl, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+	// For simplicity, set offset to 0. Technically, we should
+	// only set to 0 on files that are not seekable (sockets, pipes, etc.),
+	// and use the offset from the host fd otherwise.
+	fd := &fileDescription{inode: i}
+	vfsfd := &fd.vfsfd
+	if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
 		return nil, err
 	}
 	return vfsfd, nil
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
new file mode 100644
index 000000000..39dd624a0
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+	"fmt"
+	"syscall"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/fdnotifier"
+	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/refs"
+	"gvisor.dev/gvisor/pkg/sentry/socket/control"
+	"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+	"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/syserr"
+	"gvisor.dev/gvisor/pkg/syserror"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/unet"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Create a new host-backed endpoint from the given fd.
+func newEndpoint(ctx context.Context, hostFD int) (transport.Endpoint, error) {
+	// Set up an external transport.Endpoint using the host fd.
+	addr := fmt.Sprintf("hostfd:[%d]", hostFD)
+	var q waiter.Queue
+	e, err := NewConnectedEndpoint(ctx, hostFD, &q, addr, true /* saveable */)
+	if err != nil {
+		return nil, err.ToError()
+	}
+	e.Init()
+	ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
+	return ep, nil
+}
+
+// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
+//
+// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
+const maxSendBufferSize = 8 << 20
+
+// ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and
+// transport.Receiver. It is backed by a host fd that was imported at sentry
+// startup. This fd is shared with a hostfs inode, which retains ownership of
+// it.
+//
+// ConnectedEndpoint is saveable, since we expect that the host will provide
+// the same fd upon restore.
+//
+// As of this writing, we only allow Unix sockets to be imported.
+//
+// +stateify savable
+type ConnectedEndpoint struct {
+	// ref keeps track of references to a ConnectedEndpoint.
+	ref refs.AtomicRefCount
+
+	// mu protects fd below.
+	mu sync.RWMutex `state:"nosave"`
+
+	// fd is the host fd backing this endpoint.
+	fd int
+
+	// addr is the address at which this endpoint is bound.
+	addr string
+
+	queue *waiter.Queue
+
+	// sndbuf is the size of the send buffer.
+	//
+	// N.B. When this is smaller than the host size, we present it via
+	// GetSockOpt and message splitting/rejection in SendMsg, but do not
+	// prevent lots of small messages from filling the real send buffer
+	// size on the host.
+	sndbuf int64 `state:"nosave"`
+
+	// stype is the type of Unix socket.
+	stype linux.SockType
+}
+
+// init performs initialization required for creating new ConnectedEndpoints and
+// for restoring them.
+func (c *ConnectedEndpoint) init() *syserr.Error {
+	family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+	if err != nil {
+		return syserr.FromError(err)
+	}
+
+	if family != syscall.AF_UNIX {
+		// We only allow Unix sockets.
+		return syserr.ErrInvalidEndpointState
+	}
+
+	stype, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+	if err != nil {
+		return syserr.FromError(err)
+	}
+
+	if err := syscall.SetNonblock(c.fd, true); err != nil {
+		return syserr.FromError(err)
+	}
+
+	sndbuf, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+	if err != nil {
+		return syserr.FromError(err)
+	}
+	if sndbuf > maxSendBufferSize {
+		log.Warningf("Socket send buffer too large: %d", sndbuf)
+		return syserr.ErrInvalidEndpointState
+	}
+
+	c.stype = linux.SockType(stype)
+	c.sndbuf = int64(sndbuf)
+
+	return nil
+}
+
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host fd
+// imported at sentry startup,
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewConnectedEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+	e := ConnectedEndpoint{
+		fd:    hostFD,
+		addr:  addr,
+		queue: queue,
+	}
+
+	if err := e.init(); err != nil {
+		return nil, err
+	}
+
+	// AtomicRefCounters start off with a single reference. We need two.
+	e.ref.IncRef()
+	e.ref.EnableLeakCheck("host.ConnectedEndpoint")
+	return &e, nil
+}
+
+// Init will do the initialization required without holding other locks.
+func (c *ConnectedEndpoint) Init() {
+	if err := fdnotifier.AddFD(int32(c.fd), c.queue); err != nil {
+		panic(err)
+	}
+}
+
+// Send implements transport.ConnectedEndpoint.Send.
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if !controlMessages.Empty() {
+		return 0, false, syserr.ErrInvalidEndpointState
+	}
+
+	// Since stream sockets don't preserve message boundaries, we can write
+	// only as much of the message as fits in the send buffer.
+	truncate := c.stype == linux.SOCK_STREAM
+
+	n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate)
+	if n < totalLen && err == nil {
+		// The host only returns a short write if it would otherwise
+		// block (and only for stream sockets).
+		err = syserror.EAGAIN
+	}
+	if n > 0 && err != syserror.EAGAIN {
+		// The caller may need to block to send more data, but
+		// otherwise there isn't anything that can be done about an
+		// error with a partial write.
+		err = nil
+	}
+
+	// There is no need for the callee to call SendNotify because fdWriteVec
+	// uses the host's sendmsg(2) and the host kernel's queue.
+	return n, false, syserr.FromError(err)
+}
+
+// SendNotify implements transport.ConnectedEndpoint.SendNotify.
+func (c *ConnectedEndpoint) SendNotify() {}
+
+// CloseSend implements transport.ConnectedEndpoint.CloseSend.
+func (c *ConnectedEndpoint) CloseSend() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if err := syscall.Shutdown(c.fd, syscall.SHUT_WR); err != nil {
+		// A well-formed UDS shutdown can't fail. See
+		// net/unix/af_unix.c:unix_shutdown.
+		panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+	}
+}
+
+// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
+func (c *ConnectedEndpoint) CloseNotify() {}
+
+// Writable implements transport.ConnectedEndpoint.Writable.
+func (c *ConnectedEndpoint) Writable() bool {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements transport.ConnectedEndpoint.Passcred.
+func (c *ConnectedEndpoint) Passcred() bool {
+	// We don't support credential passing for host sockets.
+	return false
+}
+
+// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+	return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil
+}
+
+// EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
+func (c *ConnectedEndpoint) EventUpdate() {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+	if c.fd != -1 {
+		fdnotifier.UpdateFD(int32(c.fd))
+	}
+}
+
+// Recv implements transport.Receiver.Recv.
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	var cm unet.ControlMessage
+	if numRights > 0 {
+		cm.EnableFDs(int(numRights))
+	}
+
+	// N.B. Unix sockets don't have a receive buffer, the send buffer
+	// serves both purposes.
+	rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf)
+	if rl > 0 && err != nil {
+		// We got some data, so all we need to do on error is return
+		// the data that we got. Short reads are fine, no need to
+		// block.
+		err = nil
+	}
+	if err != nil {
+		return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+	}
+
+	// There is no need for the callee to call RecvNotify because fdReadVec uses
+	// the host's recvmsg(2) and the host kernel's queue.
+
+	// Trim the control data if we received less than the full amount.
+	if cl < uint64(len(cm)) {
+		cm = cm[:cl]
+	}
+
+	// Avoid extra allocations in the case where there isn't any control data.
+	if len(cm) == 0 {
+		return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+	}
+
+	fds, err := cm.ExtractFDs()
+	if err != nil {
+		return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+	}
+
+	if len(fds) == 0 {
+		return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+	}
+	return rl, ml, control.NewVFS2(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+}
+
+// RecvNotify implements transport.Receiver.RecvNotify.
+func (c *ConnectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements transport.Receiver.CloseRecv.
+func (c *ConnectedEndpoint) CloseRecv() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if err := syscall.Shutdown(c.fd, syscall.SHUT_RD); err != nil {
+		// A well-formed UDS shutdown can't fail. See
+		// net/unix/af_unix.c:unix_shutdown.
+		panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+	}
+}
+
+// Readable implements transport.Receiver.Readable.
+func (c *ConnectedEndpoint) Readable() bool {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements transport.Receiver.SendQueuedSize.
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
+	// TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+	// sockets because we don't allow the sentry to call ioctl(2).
+	return -1
+}
+
+// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
+	// TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+	// sockets because we don't allow the sentry to call ioctl(2).
+	return -1
+}
+
+// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
+	return int64(c.sndbuf)
+}
+
+// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
+	// N.B. Unix sockets don't use the receive buffer. We'll claim it is
+	// the same size as the send buffer.
+	return int64(c.sndbuf)
+}
+
+func (c *ConnectedEndpoint) destroyLocked() {
+	fdnotifier.RemoveFD(int32(c.fd))
+	c.fd = -1
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (c *ConnectedEndpoint) Release() {
+	c.ref.DecRefWithDestructor(func() {
+		c.mu.Lock()
+		c.destroyLocked()
+		c.mu.Unlock()
+	})
+}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
+
+// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
+// passed through a gofer Unix socket. It is almost the same as
+// ConnectedEndpoint, with the following differences:
+// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
+// the same descriptor number across S/R.
+// - SCMConnectedEndpoint holds ownership of its fd and is responsible for
+// closing it.
+type SCMConnectedEndpoint struct {
+	ConnectedEndpoint
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (e *SCMConnectedEndpoint) Release() {
+	e.ref.DecRefWithDestructor(func() {
+		e.mu.Lock()
+		if err := syscall.Close(e.fd); err != nil {
+			log.Warningf("Failed to close host fd %d: %v", err)
+		}
+		e.destroyLocked()
+		e.mu.Unlock()
+	})
+}
+
+// NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that
+// was passed through a Unix socket.
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
+	e := SCMConnectedEndpoint{ConnectedEndpoint{
+		fd:    hostFD,
+		addr:  addr,
+		queue: queue,
+	}}
+
+	if err := e.init(); err != nil {
+		return nil, err
+	}
+
+	// AtomicRefCounters start off with a single reference. We need two.
+	e.ref.IncRef()
+	e.ref.EnableLeakCheck("host.SCMConnectedEndpoint")
+	return &e, nil
+}
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
new file mode 100644
index 000000000..584c247d2
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+	"syscall"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+	for _, d := range dst {
+		done := copy(d, src)
+		src = src[done:]
+		if len(src) == 0 {
+			break
+		}
+	}
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+	for _, s := range src {
+		done := copy(dst, s)
+		dst = dst[done:]
+		if len(dst) == 0 {
+			break
+		}
+	}
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) {
+	var iovsRequired int
+	for _, b := range bufs {
+		length += int64(len(b))
+		if len(b) > 0 {
+			iovsRequired++
+		}
+	}
+
+	stopLen := length
+	if length > maxlen {
+		if truncate {
+			stopLen = maxlen
+			err = syserror.EAGAIN
+		} else {
+			return 0, nil, nil, syserror.EMSGSIZE
+		}
+	}
+
+	if iovsRequired > maxIovs {
+		// The kernel will reject our call if we pass this many iovs.
+		// Use a single intermediate buffer instead.
+		b := make([]byte, stopLen)
+
+		return stopLen, []syscall.Iovec{{
+			Base: &b[0],
+			Len:  uint64(stopLen),
+		}}, b, err
+	}
+
+	var total int64
+	iovecs = make([]syscall.Iovec, 0, iovsRequired)
+	for i := range bufs {
+		l := len(bufs[i])
+		if l == 0 {
+			continue
+		}
+
+		stop := int64(l)
+		if total+stop > stopLen {
+			stop = stopLen - total
+		}
+
+		iovecs = append(iovecs, syscall.Iovec{
+			Base: &bufs[i][0],
+			Len:  uint64(stop),
+		})
+
+		total += stop
+		if total >= stopLen {
+			break
+		}
+	}
+
+	return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go
new file mode 100644
index 000000000..35ded24bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_unsafe.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+	"syscall"
+	"unsafe"
+)
+
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) {
+	flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
+	if peek {
+		flags |= syscall.MSG_PEEK
+	}
+
+	// Always truncate the receive buffer. All socket types will truncate
+	// received messages.
+	length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+	if err != nil && len(iovecs) == 0 {
+		// No partial write to do, return error immediately.
+		return 0, 0, 0, false, err
+	}
+
+	var msg syscall.Msghdr
+	if len(control) != 0 {
+		msg.Control = &control[0]
+		msg.Controllen = uint64(len(control))
+	}
+
+	if len(iovecs) != 0 {
+		msg.Iov = &iovecs[0]
+		msg.Iovlen = uint64(len(iovecs))
+	}
+
+	rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+	if e != 0 {
+		// N.B. prioritize the syscall error over the buildIovec error.
+		return 0, 0, 0, false, e
+	}
+	n := int64(rawN)
+
+	// Copy data back to bufs.
+	if intermediate != nil {
+		copyToMulti(bufs, intermediate)
+	}
+
+	controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
+	if n > length {
+		return length, n, msg.Controllen, controlTrunc, err
+	}
+
+	return n, n, msg.Controllen, controlTrunc, err
+}
+
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) {
+	length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+	if err != nil && len(iovecs) == 0 {
+		// No partial write to do, return error immediately.
+		return 0, length, err
+	}
+
+	// Copy data to intermediate buf.
+	if intermediate != nil {
+		copyFromMulti(intermediate, bufs)
+	}
+
+	var msg syscall.Msghdr
+	if len(iovecs) > 0 {
+		msg.Iov = &iovecs[0]
+		msg.Iovlen = uint64(len(iovecs))
+	}
+
+	n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+	if e != 0 {
+		// N.B. prioritize the syscall error over the buildIovec error.
+		return 0, length, e
+	}
+
+	return int64(n), length, err
+}
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index 5ce50625b..271134af8 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -73,26 +73,14 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr
 	return nil, syserror.ENXIO
 }
 
-// InitSocket initializes a socket FileDescription, with a corresponding
-// Dentry in mnt.
-//
-// fd should be the FileDescription associated with socketImpl, i.e. its first
-// field. mnt should be the global socket mount, Kernel.socketMount.
-func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error {
-	fsimpl := mnt.Filesystem().Impl()
-	fs := fsimpl.(*kernfs.Filesystem)
-
+// NewDentry constructs and returns a sockfs dentry.
+func NewDentry(creds *auth.Credentials, ino uint64) *vfs.Dentry {
 	// File mode matches net/socket.c:sock_alloc.
 	filemode := linux.FileMode(linux.S_IFSOCK | 0600)
 	i := &inode{}
-	i.Init(creds, fs.NextIno(), filemode)
+	i.Init(creds, ino, filemode)
 
 	d := &kernfs.Dentry{}
 	d.Init(i)
-
-	opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true}
-	if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil {
-		return err
-	}
-	return nil
+	return d.VFSDentry()
 }
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index de2cc4bdf..941a91097 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -21,6 +21,7 @@ go_library(
         "//pkg/sentry/device",
         "//pkg/sentry/fs",
         "//pkg/sentry/fs/fsutil",
+        "//pkg/sentry/fsimpl/kernfs",
         "//pkg/sentry/fsimpl/sockfs",
         "//pkg/sentry/kernel",
         "//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 3e54d49c4..315dc274f 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -19,6 +19,7 @@ import (
 	"gvisor.dev/gvisor/pkg/context"
 	"gvisor.dev/gvisor/pkg/fspath"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
 	"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
 	"gvisor.dev/gvisor/pkg/sentry/kernel"
 	"gvisor.dev/gvisor/pkg/sentry/socket/control"
@@ -42,30 +43,44 @@ type SocketVFS2 struct {
 	socketOpsCommon
 }
 
-// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket.
-func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
-	sock := NewFDImpl(ep, stype)
-	vfsfd := &sock.vfsfd
-	if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil {
+// NewSockfsFile creates a new socket file in the global sockfs mount and
+// returns a corresponding file description.
+func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+	mnt := t.Kernel().SocketMount()
+	fs := mnt.Filesystem().Impl().(*kernfs.Filesystem)
+	d := sockfs.NewDentry(t.Credentials(), fs.NextIno())
+
+	fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d)
+	if err != nil {
 		return nil, syserr.FromError(err)
 	}
-	return vfsfd, nil
+	return fd, nil
 }
 
-// NewFDImpl creates and returns a new SocketVFS2.
-func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 {
+// NewFileDescription creates and returns a socket file description
+// corresponding to the given mount and dentry.
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) {
 	// You can create AF_UNIX, SOCK_RAW sockets. They're the same as
 	// SOCK_DGRAM and don't require CAP_NET_RAW.
 	if stype == linux.SOCK_RAW {
 		stype = linux.SOCK_DGRAM
 	}
 
-	return &SocketVFS2{
+	sock := &SocketVFS2{
 		socketOpsCommon: socketOpsCommon{
 			ep:    ep,
 			stype: stype,
 		},
 	}
+	vfsfd := &sock.vfsfd
+	if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
+		DenyPRead:         true,
+		DenyPWrite:        true,
+		UseDentryMetadata: true,
+	}); err != nil {
+		return nil, err
+	}
+	return vfsfd, nil
 }
 
 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
@@ -112,8 +127,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
 		}
 	}
 
-	// We expect this to be a FileDescription here.
-	ns, err := NewVFS2File(t, ep, s.stype)
+	ns, err := NewSockfsFile(t, ep, s.stype)
 	if err != nil {
 		return 0, nil, 0, err
 	}
@@ -307,7 +321,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int)
 		return nil, syserr.ErrInvalidArgument
 	}
 
-	f, err := NewVFS2File(t, ep, stype)
+	f, err := NewSockfsFile(t, ep, stype)
 	if err != nil {
 		ep.Close()
 		return nil, err
@@ -331,13 +345,13 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*
 
 	// Create the endpoints and sockets.
 	ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
-	s1, err := NewVFS2File(t, ep1, stype)
+	s1, err := NewSockfsFile(t, ep1, stype)
 	if err != nil {
 		ep1.Close()
 		ep2.Close()
 		return nil, nil, err
 	}
-	s2, err := NewVFS2File(t, ep2, stype)
+	s2, err := NewSockfsFile(t, ep2, stype)
 	if err != nil {
 		s1.DecRef()
 		ep2.Close()
-- 
cgit v1.2.3