diff options
Diffstat (limited to 'pkg/sentry')
55 files changed, 2154 insertions, 237 deletions
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index a949fffbf..548898aaa 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -59,13 +59,13 @@ go_library( "//pkg/sentry/limits", "//pkg/sentry/memmap", "//pkg/sentry/platform", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/ashmem/BUILD b/pkg/sentry/fs/ashmem/BUILD index dc893d22f..44ef82e64 100644 --- a/pkg/sentry/fs/ashmem/BUILD +++ b/pkg/sentry/fs/ashmem/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", - "//pkg/tcpip/transport/unix", ], ) diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index a42c03e98..27fea0019 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -26,9 +26,9 @@ import ( "gvisor.googlesource.com/gvisor/pkg/refs" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) type globalDirentMap struct { @@ -800,7 +800,7 @@ func (d *Dirent) CreateDirectory(ctx context.Context, root *Dirent, name string, } // Bind satisfies the InodeOperations interface; otherwise same as GetFile. -func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data unix.BoundEndpoint, perms FilePermissions) (*Dirent, error) { +func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data transport.BoundEndpoint, perms FilePermissions) (*Dirent, error) { var childDir *Dirent err := d.genericCreate(ctx, root, name, func() error { var e error diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 3512bae6f..6834e1272 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -87,11 +87,11 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/safemem", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", "//pkg/syserror", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 3479f2fad..3acc32752 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -19,9 +19,9 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/memmap" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -254,7 +254,7 @@ func (InodeNotDirectory) CreateDirectory(context.Context, *fs.Inode, string, fs. } // Bind implements fs.InodeOperations.Bind. -func (InodeNotDirectory) Bind(context.Context, *fs.Inode, string, unix.BoundEndpoint, fs.FilePermissions) (*fs.Dirent, error) { +func (InodeNotDirectory) Bind(context.Context, *fs.Inode, string, transport.BoundEndpoint, fs.FilePermissions) (*fs.Dirent, error) { return nil, syserror.ENOTDIR } @@ -277,7 +277,7 @@ func (InodeNotDirectory) RemoveDirectory(context.Context, *fs.Inode, string) err type InodeNotSocket struct{} // BoundEndpoint implements fs.InodeOperations.BoundEndpoint. -func (InodeNotSocket) BoundEndpoint(*fs.Inode, string) unix.BoundEndpoint { +func (InodeNotSocket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint { return nil } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index cb17339c9..cef01829a 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -41,10 +41,10 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/safemem", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/transport/unix", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index bec9680f8..0bf7881da 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -22,8 +22,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/device" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) // Lookup loads an Inode at name into a Dirent based on the session's cache @@ -180,7 +180,7 @@ func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s } // Bind implements InodeOperations.Bind. -func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, ep unix.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) { +func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) { if i.session().endpoints == nil { return nil, syscall.EOPNOTSUPP } diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 49d27ee88..4e2293398 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -24,7 +24,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/device" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/unet" ) @@ -36,23 +36,23 @@ type endpointMaps struct { // direntMap links sockets to their dirents. // It is filled concurrently with the keyMap and is stored upon save. // Before saving, this map is used to populate the pathMap. - direntMap map[unix.BoundEndpoint]*fs.Dirent + direntMap map[transport.BoundEndpoint]*fs.Dirent // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets. // It is not stored during save because the inode ID may change upon restore. - keyMap map[device.MultiDeviceKey]unix.BoundEndpoint `state:"nosave"` + keyMap map[device.MultiDeviceKey]transport.BoundEndpoint `state:"nosave"` // pathMap links the sockets to their paths. // It is filled before saving from the direntMap and is stored upon save. // Upon restore, this map is used to re-populate the keyMap. - pathMap map[unix.BoundEndpoint]string + pathMap map[transport.BoundEndpoint]string } // add adds the endpoint to the maps. // A reference is taken on the dirent argument. // // Precondition: maps must have been locked with 'lock'. -func (e *endpointMaps) add(key device.MultiDeviceKey, d *fs.Dirent, ep unix.BoundEndpoint) { +func (e *endpointMaps) add(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) { e.keyMap[key] = ep d.IncRef() e.direntMap[ep] = d @@ -81,7 +81,7 @@ func (e *endpointMaps) lock() func() { // get returns the endpoint mapped to the given key. // // Precondition: maps must have been locked for reading. -func (e *endpointMaps) get(key device.MultiDeviceKey) unix.BoundEndpoint { +func (e *endpointMaps) get(key device.MultiDeviceKey) transport.BoundEndpoint { return e.keyMap[key] } @@ -285,9 +285,9 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF // newEndpointMaps creates a new endpointMaps. func newEndpointMaps() *endpointMaps { return &endpointMaps{ - direntMap: make(map[unix.BoundEndpoint]*fs.Dirent), - keyMap: make(map[device.MultiDeviceKey]unix.BoundEndpoint), - pathMap: make(map[unix.BoundEndpoint]string), + direntMap: make(map[transport.BoundEndpoint]*fs.Dirent), + keyMap: make(map[device.MultiDeviceKey]transport.BoundEndpoint), + pathMap: make(map[transport.BoundEndpoint]string), } } @@ -341,7 +341,7 @@ func (s *session) fillPathMap() error { func (s *session) restoreEndpointMaps(ctx context.Context) error { // When restoring, only need to create the keyMap because the dirent and path // maps got stored through the save. - s.endpoints.keyMap = make(map[device.MultiDeviceKey]unix.BoundEndpoint) + s.endpoints.keyMap = make(map[device.MultiDeviceKey]transport.BoundEndpoint) if err := s.fillKeyMap(ctx); err != nil { return fmt.Errorf("failed to insert sockets into endpoint map: %v", err) } @@ -349,6 +349,6 @@ func (s *session) restoreEndpointMaps(ctx context.Context) error { // Re-create pathMap because it can no longer be trusted as socket paths can // change while process continues to run. Empty pathMap will be re-filled upon // next save. - s.endpoints.pathMap = make(map[unix.BoundEndpoint]string) + s.endpoints.pathMap = make(map[transport.BoundEndpoint]string) return nil } diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 0190bc006..d072da624 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -19,13 +19,13 @@ import ( "gvisor.googlesource.com/gvisor/pkg/p9" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/host" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) -// BoundEndpoint returns a gofer-backed unix.BoundEndpoint. -func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) unix.BoundEndpoint { +// BoundEndpoint returns a gofer-backed transport.BoundEndpoint. +func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.BoundEndpoint { if !fs.IsSocket(i.fileState.sattr) { return nil } @@ -45,7 +45,7 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) unix.Bound return &endpoint{inode, i.fileState.file.file, path} } -// endpoint is a Gofer-backed unix.BoundEndpoint. +// endpoint is a Gofer-backed transport.BoundEndpoint. // // An endpoint's lifetime is the time between when InodeOperations.BoundEndpoint() // is called and either BoundEndpoint.BidirectionalConnect or @@ -61,20 +61,20 @@ type endpoint struct { path string } -func unixSockToP9(t unix.SockType) (p9.ConnectFlags, bool) { +func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { switch t { - case unix.SockStream: + case transport.SockStream: return p9.StreamSocket, true - case unix.SockSeqpacket: + case transport.SockSeqpacket: return p9.SeqpacketSocket, true - case unix.SockDgram: + case transport.SockDgram: return p9.DgramSocket, true } return 0, false } // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. -func (e *endpoint) BidirectionalConnect(ce unix.ConnectingEndpoint, returnConnect func(unix.Receiver, unix.ConnectedEndpoint)) *tcpip.Error { +func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *tcpip.Error { cf, ok := unixSockToP9(ce.Type()) if !ok { return tcpip.ErrConnectionRefused @@ -113,8 +113,9 @@ func (e *endpoint) BidirectionalConnect(ce unix.ConnectingEndpoint, returnConnec return nil } -// UnidirectionalConnect implements unix.BoundEndpoint.UnidirectionalConnect. -func (e *endpoint) UnidirectionalConnect() (unix.ConnectedEndpoint, *tcpip.Error) { +// UnidirectionalConnect implements +// transport.BoundEndpoint.UnidirectionalConnect. +func (e *endpoint) UnidirectionalConnect() (transport.ConnectedEndpoint, *tcpip.Error) { hostFile, err := e.file.Connect(p9.DgramSocket) if err != nil { return nil, tcpip.ErrConnectionRefused @@ -134,7 +135,7 @@ func (e *endpoint) UnidirectionalConnect() (unix.ConnectedEndpoint, *tcpip.Error return c, nil } -// Release implements unix.BoundEndpoint.Release. +// Release implements transport.BoundEndpoint.Release. func (e *endpoint) Release() { e.inode.DecRef() } diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 5ada32ee1..4f264a024 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -42,13 +42,13 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/transport/unix", "//pkg/unet", "//pkg/waiter", "//pkg/waiter/fdnotifier", @@ -72,10 +72,10 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/socket", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/tcpip", - "//pkg/tcpip/transport/unix", "//pkg/waiter", "//pkg/waiter/fdnotifier", ], diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go index d2b007ab2..d2e34a69d 100644 --- a/pkg/sentry/fs/host/control.go +++ b/pkg/sentry/fs/host/control.go @@ -20,7 +20,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) type scmRights struct { @@ -45,13 +45,13 @@ func (c *scmRights) Files(ctx context.Context, max int) control.RightsFiles { return rf } -// Clone implements unix.RightsControlMessage.Clone. -func (c *scmRights) Clone() unix.RightsControlMessage { +// Clone implements transport.RightsControlMessage.Clone. +func (c *scmRights) Clone() transport.RightsControlMessage { // Host rights never need to be cloned. return nil } -// Release implements unix.RightsControlMessage.Release. +// Release implements transport.RightsControlMessage.Release. func (c *scmRights) Release() { for _, fd := range c.fds { syscall.Close(fd) diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index c2e8ba62f..e32497203 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -27,8 +27,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" "gvisor.googlesource.com/gvisor/pkg/sentry/memmap" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -310,12 +310,12 @@ func (i *inodeOperations) Rename(ctx context.Context, oldParent *fs.Inode, oldNa } // Bind implements fs.InodeOperations.Bind. -func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data unix.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) { +func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) { return nil, syserror.EOPNOTSUPP } // BoundEndpoint implements fs.InodeOperations.BoundEndpoint. -func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) unix.BoundEndpoint { +func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.BoundEndpoint { return nil } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index e454b6fe5..0eb267c00 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -25,12 +25,12 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control" unixsocket "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/unet" "gvisor.googlesource.com/gvisor/pkg/waiter" "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier" @@ -42,7 +42,7 @@ import ( const maxSendBufferSize = 8 << 20 // ConnectedEndpoint is a host FD backed implementation of -// unix.ConnectedEndpoint and unix.Receiver. +// transport.ConnectedEndpoint and transport.Receiver. // // +stateify savable type ConnectedEndpoint struct { @@ -70,7 +70,7 @@ type ConnectedEndpoint struct { srfd int `state:"wait"` // stype is the type of Unix socket. - stype unix.SockType + stype transport.SockType // sndbuf is the size of the send buffer. // @@ -112,7 +112,7 @@ func (c *ConnectedEndpoint) init() *tcpip.Error { return tcpip.ErrInvalidEndpointState } - c.stype = unix.SockType(stype) + c.stype = transport.SockType(stype) c.sndbuf = sndbuf return nil @@ -122,8 +122,8 @@ func (c *ConnectedEndpoint) init() *tcpip.Error { // that will pretend to be bound at a given sentry path. // // The caller is responsible for calling Init(). Additionaly, Release needs to -// be called twice because ConnectedEndpoint is both a unix.Receiver and -// unix.ConnectedEndpoint. +// be called twice because ConnectedEndpoint is both a transport.Receiver and +// transport.ConnectedEndpoint. func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *tcpip.Error) { e := ConnectedEndpoint{ path: path, @@ -168,7 +168,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F e.Init() - ep := unix.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) + ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) return unixsocket.NewWithDirent(ctx, d, ep, flags), nil } @@ -200,13 +200,13 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) e.srfd = srfd e.Init() - ep := unix.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) + ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) return unixsocket.New(ctx, ep), nil } -// Send implements unix.ConnectedEndpoint.Send. -func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { +// Send implements transport.ConnectedEndpoint.Send. +func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.writeClosed { @@ -219,7 +219,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. - truncate := c.stype == unix.SockStream + truncate := c.stype == transport.SockStream n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) if n < totalLen && err == nil { @@ -239,20 +239,20 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess return n, false, translateError(err) } -// SendNotify implements unix.ConnectedEndpoint.SendNotify. +// SendNotify implements transport.ConnectedEndpoint.SendNotify. func (c *ConnectedEndpoint) SendNotify() {} -// CloseSend implements unix.ConnectedEndpoint.CloseSend. +// CloseSend implements transport.ConnectedEndpoint.CloseSend. func (c *ConnectedEndpoint) CloseSend() { c.mu.Lock() c.writeClosed = true c.mu.Unlock() } -// CloseNotify implements unix.ConnectedEndpoint.CloseNotify. +// CloseNotify implements transport.ConnectedEndpoint.CloseNotify. func (c *ConnectedEndpoint) CloseNotify() {} -// Writable implements unix.ConnectedEndpoint.Writable. +// Writable implements transport.ConnectedEndpoint.Writable. func (c *ConnectedEndpoint) Writable() bool { c.mu.RLock() defer c.mu.RUnlock() @@ -262,18 +262,18 @@ func (c *ConnectedEndpoint) Writable() bool { return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0 } -// Passcred implements unix.ConnectedEndpoint.Passcred. +// Passcred implements transport.ConnectedEndpoint.Passcred. func (c *ConnectedEndpoint) Passcred() bool { // We don't support credential passing for host sockets. return false } -// GetLocalAddress implements unix.ConnectedEndpoint.GetLocalAddress. +// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress. func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil } -// EventUpdate implements unix.ConnectedEndpoint.EventUpdate. +// EventUpdate implements transport.ConnectedEndpoint.EventUpdate. func (c *ConnectedEndpoint) EventUpdate() { c.mu.RLock() defer c.mu.RUnlock() @@ -282,12 +282,12 @@ func (c *ConnectedEndpoint) EventUpdate() { } } -// Recv implements unix.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { +// Recv implements transport.Receiver.Recv. +func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.readClosed { - return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive + return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive } var cm unet.ControlMessage @@ -305,7 +305,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p err = nil } if err != nil { - return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, translateError(err) + return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, translateError(err) } // There is no need for the callee to call RecvNotify because fdReadVec uses @@ -318,16 +318,16 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p // Avoid extra allocations in the case where there isn't any control data. if len(cm) == 0 { - return rl, ml, unix.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil + return rl, ml, transport.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil } fds, err := cm.ExtractFDs() if err != nil { - return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, translateError(err) + return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, translateError(err) } if len(fds) == 0 { - return rl, ml, unix.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil + return rl, ml, transport.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil } return rl, ml, control.New(nil, nil, newSCMRights(fds)), tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil } @@ -339,17 +339,17 @@ func (c *ConnectedEndpoint) close() { c.file = nil } -// RecvNotify implements unix.Receiver.RecvNotify. +// RecvNotify implements transport.Receiver.RecvNotify. func (c *ConnectedEndpoint) RecvNotify() {} -// CloseRecv implements unix.Receiver.CloseRecv. +// CloseRecv implements transport.Receiver.CloseRecv. func (c *ConnectedEndpoint) CloseRecv() { c.mu.Lock() c.readClosed = true c.mu.Unlock() } -// Readable implements unix.Receiver.Readable. +// Readable implements transport.Receiver.Readable. func (c *ConnectedEndpoint) Readable() bool { c.mu.RLock() defer c.mu.RUnlock() @@ -359,33 +359,33 @@ func (c *ConnectedEndpoint) Readable() bool { return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0 } -// SendQueuedSize implements unix.Receiver.SendQueuedSize. +// SendQueuedSize implements transport.Receiver.SendQueuedSize. func (c *ConnectedEndpoint) SendQueuedSize() int64 { // SendQueuedSize isn't supported for host sockets because we don't allow the // sentry to call ioctl(2). return -1 } -// RecvQueuedSize implements unix.Receiver.RecvQueuedSize. +// RecvQueuedSize implements transport.Receiver.RecvQueuedSize. func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // RecvQueuedSize isn't supported for host sockets because we don't allow the // sentry to call ioctl(2). return -1 } -// SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize. +// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize. func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { return int64(c.sndbuf) } -// RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize. +// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize. func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { // N.B. Unix sockets don't use the receive buffer. We'll claim it is // the same size as the send buffer. return int64(c.sndbuf) } -// Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release. +// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release. func (c *ConnectedEndpoint) Release() { c.ref.DecRefWithDestructor(c.close) } diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 8b752737d..1c6f9ddb1 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -22,20 +22,20 @@ import ( "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/sentry/context/contexttest" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier" ) var ( - // Make sure that ConnectedEndpoint implements unix.ConnectedEndpoint. - _ = unix.ConnectedEndpoint(new(ConnectedEndpoint)) + // Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint. + _ = transport.ConnectedEndpoint(new(ConnectedEndpoint)) - // Make sure that ConnectedEndpoint implements unix.Receiver. - _ = unix.Receiver(new(ConnectedEndpoint)) + // Make sure that ConnectedEndpoint implements transport.Receiver. + _ = transport.Receiver(new(ConnectedEndpoint)) ) func getFl(fd int) (uint32, error) { @@ -199,7 +199,7 @@ func TestListen(t *testing.T) { func TestSend(t *testing.T) { e := ConnectedEndpoint{writeClosed: true} - if _, _, err := e.Send(nil, unix.ControlMessages{}, tcpip.FullAddress{}); err != tcpip.ErrClosedForSend { + if _, _, err := e.Send(nil, transport.ControlMessages{}, tcpip.FullAddress{}); err != tcpip.ErrClosedForSend { t.Errorf("Got %#v.Send() = %v, want = %v", e, err, tcpip.ErrClosedForSend) } } diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index db7240dca..409c81a97 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -22,8 +22,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/lock" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/memmap" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) // Inode is a file system object that can be simultaneously referenced by different @@ -223,7 +223,7 @@ func (i *Inode) Rename(ctx context.Context, oldParent *Dirent, renamed *Dirent, } // Bind calls i.InodeOperations.Bind with i as the directory. -func (i *Inode) Bind(ctx context.Context, name string, data unix.BoundEndpoint, perm FilePermissions) (*Dirent, error) { +func (i *Inode) Bind(ctx context.Context, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { if i.overlay != nil { return overlayBind(ctx, i.overlay, name, data, perm) } @@ -231,7 +231,7 @@ func (i *Inode) Bind(ctx context.Context, name string, data unix.BoundEndpoint, } // BoundEndpoint calls i.InodeOperations.BoundEndpoint with i as the Inode. -func (i *Inode) BoundEndpoint(path string) unix.BoundEndpoint { +func (i *Inode) BoundEndpoint(path string) transport.BoundEndpoint { if i.overlay != nil { return overlayBoundEndpoint(i.overlay, path) } diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go index 952f9704d..3ee3de10e 100644 --- a/pkg/sentry/fs/inode_operations.go +++ b/pkg/sentry/fs/inode_operations.go @@ -20,8 +20,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/context" ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/memmap" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -146,7 +146,7 @@ type InodeOperations interface { // Implementations must ensure that name does not already exist. // // The caller must ensure that this operation is permitted. - Bind(ctx context.Context, dir *Inode, name string, data unix.BoundEndpoint, perm FilePermissions) (*Dirent, error) + Bind(ctx context.Context, dir *Inode, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) // BoundEndpoint returns the socket endpoint at path stored in // or generated by an Inode. @@ -160,7 +160,7 @@ type InodeOperations interface { // generally implies that this Inode was created via CreateSocket. // // If there is no socket endpoint available, nil will be returned. - BoundEndpoint(inode *Inode, path string) unix.BoundEndpoint + BoundEndpoint(inode *Inode, path string) transport.BoundEndpoint // GetFile returns a new open File backed by a Dirent and FileFlags. // It may block as long as it is done with ctx. diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index e18e095a0..cf698a4da 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -20,8 +20,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) func overlayHasWhiteout(parent *Inode, name string) bool { @@ -356,7 +356,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena return nil } -func overlayBind(ctx context.Context, o *overlayEntry, name string, data unix.BoundEndpoint, perm FilePermissions) (*Dirent, error) { +func overlayBind(ctx context.Context, o *overlayEntry, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { o.copyMu.RLock() defer o.copyMu.RUnlock() // We do not support doing anything exciting with sockets unless there @@ -383,7 +383,7 @@ func overlayBind(ctx context.Context, o *overlayEntry, name string, data unix.Bo return NewDirent(newOverlayInode(ctx, entry, inode.MountSource), name), nil } -func overlayBoundEndpoint(o *overlayEntry, path string) unix.BoundEndpoint { +func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint { o.copyMu.RLock() defer o.copyMu.RUnlock() diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 5230157fe..a93ad6240 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -23,9 +23,9 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/safemem", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index d8333194b..075e13b01 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -20,9 +20,9 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) // CreateOps represents operations to create different file types. @@ -37,7 +37,7 @@ type CreateOps struct { NewSymlink func(ctx context.Context, dir *fs.Inode, target string) (*fs.Inode, error) // NewBoundEndpoint creates a new socket. - NewBoundEndpoint func(ctx context.Context, dir *fs.Inode, ep unix.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error) + NewBoundEndpoint func(ctx context.Context, dir *fs.Inode, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error) // NewFifo creates a new fifo. NewFifo func(ctx context.Context, dir *fs.Inode, perm fs.FilePermissions) (*fs.Inode, error) @@ -314,7 +314,7 @@ func (d *Dir) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, p } // Bind implements fs.InodeOperations.Bind. -func (d *Dir) Bind(ctx context.Context, dir *fs.Inode, name string, ep unix.BoundEndpoint, perms fs.FilePermissions) (*fs.Dirent, error) { +func (d *Dir) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Dirent, error) { if d.CreateOps == nil || d.CreateOps.NewBoundEndpoint == nil { return nil, ErrDenied } diff --git a/pkg/sentry/fs/ramfs/ramfs.go b/pkg/sentry/fs/ramfs/ramfs.go index 1028b5f1d..83cbcab23 100644 --- a/pkg/sentry/fs/ramfs/ramfs.go +++ b/pkg/sentry/fs/ramfs/ramfs.go @@ -26,9 +26,9 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/memmap" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -279,7 +279,7 @@ func (*Entry) CreateDirectory(context.Context, *fs.Inode, string, fs.FilePermiss } // Bind is not supported by default. -func (*Entry) Bind(context.Context, *fs.Inode, string, unix.BoundEndpoint, fs.FilePermissions) (*fs.Dirent, error) { +func (*Entry) Bind(context.Context, *fs.Inode, string, transport.BoundEndpoint, fs.FilePermissions) (*fs.Dirent, error) { return nil, ErrInvalidOp } diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go index 93427a1ff..9ac00eb18 100644 --- a/pkg/sentry/fs/ramfs/socket.go +++ b/pkg/sentry/fs/ramfs/socket.go @@ -17,7 +17,7 @@ package ramfs import ( "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) // Socket represents a socket. @@ -27,17 +27,17 @@ type Socket struct { Entry // ep is the bound endpoint. - ep unix.BoundEndpoint + ep transport.BoundEndpoint } // InitSocket initializes a socket. -func (s *Socket) InitSocket(ctx context.Context, ep unix.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions) { +func (s *Socket) InitSocket(ctx context.Context, ep transport.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions) { s.InitEntry(ctx, owner, perms) s.ep = ep } // BoundEndpoint returns the socket data. -func (s *Socket) BoundEndpoint(*fs.Inode, string) unix.BoundEndpoint { +func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint { // ramfs only supports stored sentry internal sockets. Only gofer sockets // care about the path argument. return s.ep diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index cfe11ab02..277583113 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -25,9 +25,9 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/safemem", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 10cb5451d..38be6db46 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -22,9 +22,9 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/pipe" "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usage" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) var fsInfo = fs.Info{ @@ -104,7 +104,7 @@ func (d *Dir) newCreateOps() *ramfs.CreateOps { NewSymlink: func(ctx context.Context, dir *fs.Inode, target string) (*fs.Inode, error) { return NewSymlink(ctx, target, fs.FileOwnerFromContext(ctx), dir.MountSource), nil }, - NewBoundEndpoint: func(ctx context.Context, dir *fs.Inode, socket unix.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error) { + NewBoundEndpoint: func(ctx context.Context, dir *fs.Inode, socket transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error) { return NewSocket(ctx, socket, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil }, NewFifo: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) { @@ -160,7 +160,7 @@ type Socket struct { } // NewSocket returns a new socket with the provided permissions. -func NewSocket(ctx context.Context, socket unix.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { +func NewSocket(ctx context.Context, socket transport.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { s := &Socket{} s.InitSocket(ctx, socket, owner, perms) return fs.NewInode(s, msrc, fs.StableAttr{ diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 3c446eef4..d4dd20e30 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -26,9 +26,9 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index c6f39fce3..7c0c0b0c1 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -26,9 +26,9 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -215,7 +215,7 @@ func (d *dirInodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, } // Bind implements fs.InodeOperations.Bind. -func (d *dirInodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data unix.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) { +func (d *dirInodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) { return nil, syserror.EPERM } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 31ad96612..acc61cb09 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -156,6 +156,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/safemem", "//pkg/sentry/socket/netlink/port", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/time", "//pkg/sentry/uniqueid", "//pkg/sentry/usage", @@ -166,7 +167,6 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index d6d1d341d..45088c988 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -19,12 +19,12 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/refs" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) // +stateify savable type abstractEndpoint struct { - ep unix.BoundEndpoint + ep transport.BoundEndpoint wr *refs.WeakRef name string ns *AbstractSocketNamespace @@ -56,14 +56,14 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { } } -// A boundEndpoint wraps a unix.BoundEndpoint to maintain a reference on its -// backing object. +// A boundEndpoint wraps a transport.BoundEndpoint to maintain a reference on +// its backing object. type boundEndpoint struct { - unix.BoundEndpoint + transport.BoundEndpoint rc refs.RefCounter } -// Release implements unix.BoundEndpoint.Release. +// Release implements transport.BoundEndpoint.Release. func (e *boundEndpoint) Release() { e.rc.DecRef() e.BoundEndpoint.Release() @@ -71,7 +71,7 @@ func (e *boundEndpoint) Release() { // BoundEndpoint retrieves the endpoint bound to the given name. The return // value is nil if no endpoint was bound. -func (a *AbstractSocketNamespace) BoundEndpoint(name string) unix.BoundEndpoint { +func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndpoint { a.mu.Lock() defer a.mu.Unlock() @@ -93,7 +93,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) unix.BoundEndpoint // // When the last reference managed by rc is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(name string, ep unix.BoundEndpoint, rc refs.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a320fca0b..3a8044b5f 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -16,9 +16,9 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/kdefs", "//pkg/sentry/kernel/time", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/tcpip", - "//pkg/tcpip/transport/unix", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index c4874fdfb..d3a63f15f 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -18,8 +18,8 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/kdefs", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", - "//pkg/tcpip/transport/unix", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index c31182e69..db97e95f2 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -24,16 +24,16 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) const maxInt = int(^uint(0) >> 1) // SCMCredentials represents a SCM_CREDENTIALS socket control message. type SCMCredentials interface { - unix.CredentialsControlMessage + transport.CredentialsControlMessage // Credentials returns properly namespaced values for the sender's pid, uid // and gid. @@ -42,7 +42,7 @@ type SCMCredentials interface { // SCMRights represents a SCM_RIGHTS socket control message. type SCMRights interface { - unix.RightsControlMessage + transport.RightsControlMessage // Files returns up to max RightsFiles. Files(ctx context.Context, max int) RightsFiles @@ -81,8 +81,8 @@ func (fs *RightsFiles) Files(ctx context.Context, max int) RightsFiles { return rf } -// Clone implements unix.RightsControlMessage.Clone. -func (fs *RightsFiles) Clone() unix.RightsControlMessage { +// Clone implements transport.RightsControlMessage.Clone. +func (fs *RightsFiles) Clone() transport.RightsControlMessage { nfs := append(RightsFiles(nil), *fs...) for _, nf := range nfs { nf.IncRef() @@ -90,7 +90,7 @@ func (fs *RightsFiles) Clone() unix.RightsControlMessage { return &nfs } -// Release implements unix.RightsControlMessage.Release. +// Release implements transport.RightsControlMessage.Release. func (fs *RightsFiles) Release() { for _, f := range *fs { f.DecRef() @@ -156,8 +156,8 @@ func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SC return &scmCredentials{t, kuid, kgid}, nil } -// Equals implements unix.CredentialsControlMessage.Equals. -func (c *scmCredentials) Equals(oc unix.CredentialsControlMessage) bool { +// Equals implements transport.CredentialsControlMessage.Equals. +func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool { if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc { return true } @@ -301,7 +301,7 @@ func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { } // Parse parses a raw socket control message into portable objects. -func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.ControlMessages, error) { +func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { var ( fds linux.ControlMessageRights @@ -311,20 +311,20 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.Contr for i := 0; i < len(buf); { if i+linux.SizeOfControlMessageHeader > len(buf) { - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } var h linux.ControlMessageHeader binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) if h.Length < uint64(linux.SizeOfControlMessageHeader) { - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } if h.Length > uint64(len(buf)-i) { - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } if h.Level != linux.SOL_SOCKET { - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } i += linux.SizeOfControlMessageHeader @@ -340,7 +340,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.Contr numRights := rightsSize / linux.SizeOfControlMessageRight if len(fds)+numRights > linux.SCM_MAX_FD { - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { @@ -351,7 +351,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.Contr case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) @@ -360,7 +360,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.Contr default: // Unknown message type. - return unix.ControlMessages{}, syserror.EINVAL + return transport.ControlMessages{}, syserror.EINVAL } } @@ -368,7 +368,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.Contr if haveCreds { var err error if credentials, err = NewSCMCredentials(t, creds); err != nil { - return unix.ControlMessages{}, err + return transport.ControlMessages{}, err } } else { credentials = makeCreds(t, socketOrEndpoint) @@ -378,22 +378,22 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.Contr if len(fds) > 0 { var err error if rights, err = NewSCMRights(t, fds); err != nil { - return unix.ControlMessages{}, err + return transport.ControlMessages{}, err } } if credentials == nil && rights == nil { - return unix.ControlMessages{}, nil + return transport.ControlMessages{}, nil } - return unix.ControlMessages{Credentials: credentials, Rights: rights}, nil + return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil } func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { if t == nil || socketOrEndpoint == nil { return nil } - if cr, ok := socketOrEndpoint.(unix.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) { + if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) { tcred := t.Credentials() return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID} } @@ -401,8 +401,8 @@ func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { } // New creates default control messages if needed. -func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) unix.ControlMessages { - return unix.ControlMessages{ +func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages { + return transport.ControlMessages{ Credentials: makeCreds(t, socketOrEndpoint), Rights: rights, } diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 7f9ea9edc..dbabc931c 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", @@ -42,7 +43,6 @@ go_library( "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index a44679f0b..47c575e7b 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -44,13 +44,13 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -108,26 +108,26 @@ func htons(v uint16) uint16 { } // commonEndpoint represents the intersection of a tcpip.Endpoint and a -// unix.Endpoint. +// transport.Endpoint. type commonEndpoint interface { // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and - // unix.Endpoint.GetLocalAddress. + // transport.Endpoint.GetLocalAddress. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and - // unix.Endpoint.GetRemoteAddress. + // transport.Endpoint.GetRemoteAddress. GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) // Readiness implements tcpip.Endpoint.Readiness and - // unix.Endpoint.Readiness. + // transport.Endpoint.Readiness. Readiness(mask waiter.EventMask) waiter.EventMask // SetSockOpt implements tcpip.Endpoint.SetSockOpt and - // unix.Endpoint.SetSockOpt. + // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and - // unix.Endpoint.GetSockOpt. + // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error } @@ -146,7 +146,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint - skType unix.SockType + skType transport.SockType // readMu protects access to readView, control, and sender. readMu sync.Mutex `state:"nosave"` @@ -156,7 +156,7 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType unix.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) *fs.File { +func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) *fs.File { dirent := socket.NewDirent(t, epsocketDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true}, &SocketOperations{ @@ -502,7 +502,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType unix.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index 6c1e3b6b9..dbc232d26 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -21,6 +21,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" @@ -28,7 +29,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -40,7 +40,7 @@ type provider struct { // GetTransportProtocol figures out transport protocol. Currently only TCP, // UDP, and ICMP are supported. -func GetTransportProtocol(stype unix.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { @@ -62,7 +62,7 @@ func GetTransportProtocol(stype unix.SockType, protocol int) (tcpip.TransportPro } // Socket creates a new socket object for the AF_INET or AF_INET6 family. -func (p *provider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() if stack == nil { @@ -92,7 +92,7 @@ func (p *provider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*f } // Pair just returns nil sockets (not supported). -func (*provider) Pair(*kernel.Task, unix.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index d623718b3..c30220a46 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -29,10 +29,10 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", - "//pkg/tcpip/transport/unix", "//pkg/waiter", "//pkg/waiter/fdnotifier", ], diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index d0f3054dc..e82624b44 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -27,10 +27,10 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier" ) @@ -511,7 +511,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags unix.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the host network stack. stack := t.NetworkContext() if stack == nil { @@ -553,7 +553,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags unix.SockType, protoc } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index b852165f7..cff922cb8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -25,11 +25,11 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index e874216f4..5d0a04a07 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -22,8 +22,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) // Protocol is the implementation of a netlink socket protocol. @@ -66,10 +66,10 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (*socketProvider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { // Netlink sockets must be specified as datagram or raw, but they // behave the same regardless of type. - if stype != unix.SockDgram && stype != unix.SockRaw { + if stype != transport.SockDgram && stype != transport.SockRaw { return nil, syserr.ErrSocketNotSupported } @@ -94,7 +94,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype unix.SockType, protocol int) } // Pair implements socket.Provider.Pair by returning an error. -func (*socketProvider) Pair(*kernel.Task, unix.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Netlink sockets never supports creating socket pairs. return nil, nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index f3b2c7256..0c03997f2 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -31,12 +31,12 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/port" - sunix "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -80,11 +80,11 @@ type Socket struct { // ep is a datagram unix endpoint used to buffer messages sent from the // kernel to userspace. RecvMsg reads messages from this endpoint. - ep unix.Endpoint + ep transport.Endpoint // connection is the kernel's connection to ep, used to write messages // sent to userspace. - connection unix.ConnectedEndpoint + connection transport.ConnectedEndpoint // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -105,7 +105,7 @@ var _ socket.Socket = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { // Datagram endpoint used to buffer kernel -> user messages. - ep := unix.NewConnectionless() + ep := transport.NewConnectionless() // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. @@ -115,7 +115,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { } // Create a connection from which the kernel can write messages. - connection, terr := ep.(unix.BoundEndpoint).UnidirectionalConnect() + connection, terr := ep.(transport.BoundEndpoint).UnidirectionalConnect() if terr != nil { ep.Close() return nil, syserr.TranslateNetstackError(terr) @@ -368,7 +368,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have trunc := flags&linux.MSG_TRUNC != 0 - r := sunix.EndpointReader{ + r := unix.EndpointReader{ Endpoint: s.ep, Peek: flags&linux.MSG_PEEK != 0, } @@ -408,7 +408,7 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ if dst.NumBytes() == 0 { return 0, nil } - return dst.CopyOutFrom(ctx, &sunix.EndpointReader{ + return dst.CopyOutFrom(ctx, &unix.EndpointReader{ Endpoint: s.ep, }) } @@ -424,7 +424,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, terr := s.connection.Send(bufs, unix.ControlMessages{}, tcpip.FullAddress{}) + _, notify, terr := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if terr != nil && terr != tcpip.ErrWouldBlock { @@ -448,7 +448,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error PortID: uint32(ms.PortID), }) - _, notify, terr := s.connection.Send([][]byte{m.Finalize()}, unix.ControlMessages{}, tcpip.FullAddress{}) + _, notify, terr := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) if terr != nil && terr != tcpip.ErrWouldBlock { return syserr.TranslateNetstackError(terr) } diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 288199779..3ea433360 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -31,12 +31,12 @@ go_library( "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/rpcinet/conn", "//pkg/sentry/socket/rpcinet/notifier", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", - "//pkg/tcpip/transport/unix", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 72fa1ca8f..c7e761d54 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -31,12 +31,12 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier" pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -763,7 +763,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags unix.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the RPC network stack. stack := t.NetworkContext() if stack == nil { @@ -803,7 +803,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags unix.SockType, protoc } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 54fe64595..31f8d42d7 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -29,16 +29,16 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs" ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) // ControlMessages represents the union of unix control messages and tcpip // control messages. type ControlMessages struct { - Unix unix.ControlMessages + Unix transport.ControlMessages IP tcpip.ControlMessages } @@ -109,12 +109,12 @@ type Provider interface { // If a nil Socket _and_ a nil error is returned, it means that the // protocol is not supported. A non-nil error should only be returned // if the protocol is supported, but an error occurs during creation. - Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) + Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) // Pair creates a pair of connected sockets. // // See Socket for error information. - Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) + Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) } // families holds a map of all known address families and their providers. @@ -128,7 +128,7 @@ func RegisterProvider(family int, provider Provider) { } // New creates a new socket with the given family, type and protocol. -func New(t *kernel.Task, family int, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { for _, p := range families[family] { s, err := p.Socket(t, stype, protocol) if err != nil { @@ -144,7 +144,7 @@ func New(t *kernel.Task, family int, stype unix.SockType, protocol int) (*fs.Fil // Pair creates a new connected socket pair with the given family, type and // protocol. -func Pair(t *kernel.Task, family int, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { providers, ok := families[family] if !ok { return nil, nil, syserr.ErrAddressFamilyNotSupported diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 9fe681e9a..a12fa93db 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -26,11 +26,11 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/epsocket", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 0ca2e35d0..06333e14b 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -16,23 +16,23 @@ package unix import ( "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) -// EndpointWriter implements safemem.Writer that writes to a unix.Endpoint. +// EndpointWriter implements safemem.Writer that writes to a transport.Endpoint. // // EndpointWriter is not thread-safe. type EndpointWriter struct { - // Endpoint is the unix.Endpoint to write to. - Endpoint unix.Endpoint + // Endpoint is the transport.Endpoint to write to. + Endpoint transport.Endpoint // Control is the control messages to send. - Control unix.ControlMessages + Control transport.ControlMessages // To is the endpoint to send to. May be nil. - To unix.BoundEndpoint + To transport.BoundEndpoint } // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. @@ -46,12 +46,13 @@ func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) }}.WriteFromBlocks(srcs) } -// EndpointReader implements safemem.Reader that reads from a unix.Endpoint. +// EndpointReader implements safemem.Reader that reads from a +// transport.Endpoint. // // EndpointReader is not thread-safe. type EndpointReader struct { - // Endpoint is the unix.Endpoint to read from. - Endpoint unix.Endpoint + // Endpoint is the transport.Endpoint to read from. + Endpoint transport.Endpoint // Creds indicates if credential control messages are requested. Creds bool @@ -71,7 +72,7 @@ type EndpointReader struct { From *tcpip.FullAddress // Control contains the received control messages. - Control unix.ControlMessages + Control transport.ControlMessages } // ReadToBlocks implements safemem.Reader.ReadToBlocks. diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD new file mode 100644 index 000000000..04ef0d438 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -0,0 +1,22 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("//tools/go_stateify:defs.bzl", "go_library") + +go_library( + name = "transport", + srcs = [ + "connectioned.go", + "connectioned_state.go", + "connectionless.go", + "unix.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport", + visibility = ["//:sandbox"], + deps = [ + "//pkg/ilist", + "//pkg/sentry/socket/unix/transport/queue", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go new file mode 100644 index 000000000..f09935765 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -0,0 +1,454 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport/queue" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// UniqueIDProvider generates a sequence of unique identifiers useful for, +// among other things, lock ordering. +type UniqueIDProvider interface { + // UniqueID returns a new unique identifier. + UniqueID() uint64 +} + +// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to +// establish a bidirectional connection with a BoundEndpoint. +type ConnectingEndpoint interface { + // ID returns the endpoint's globally unique identifier. This identifier + // must be used to determine locking order if more than one endpoint is + // to be locked in the same codepath. The endpoint with the smaller + // identifier must be locked before endpoints with larger identifiers. + ID() uint64 + + // Passcred implements socket.Credentialer.Passcred. + Passcred() bool + + // Type returns the socket type, typically either SockStream or + // SockSeqpacket. The connection attempt must be aborted if this + // value doesn't match the ConnectableEndpoint's type. + Type() SockType + + // GetLocalAddress returns the bound path. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // Locker protects the following methods. While locked, only the holder of + // the lock can change the return value of the protected methods. + sync.Locker + + // Connected returns true iff the ConnectingEndpoint is in the connected + // state. ConnectingEndpoints can only be connected to a single endpoint, + // so the connection attempt must be aborted if this returns true. + Connected() bool + + // Listening returns true iff the ConnectingEndpoint is in the listening + // state. ConnectingEndpoints cannot make connections while listening, so + // the connection attempt must be aborted if this returns true. + Listening() bool + + // WaiterQueue returns a pointer to the endpoint's waiter queue. + WaiterQueue() *waiter.Queue +} + +// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements +// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint. +// +// connectionedEndpoints must be in connected state in order to transfer data. +// +// This implementation includes STREAM and SEQPACKET Unix sockets created with +// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with +// socketpair(2). See unix_connectionless.go for the implementation of DGRAM +// Unix sockets created with socket(2). +// +// The state is much simpler than a TCP endpoint, so it is not encoded +// explicitly. Instead we enforce the following invariants: +// +// receiver != nil, connected != nil => connected. +// path != "" && acceptedChan == nil => bound, not listening. +// path != "" && acceptedChan != nil => bound and listening. +// +// Only one of these will be true at any moment. +// +// +stateify savable +type connectionedEndpoint struct { + baseEndpoint + + // id is the unique endpoint identifier. This is used exclusively for + // lock ordering within connect. + id uint64 + + // idGenerator is used to generate new unique endpoint identifiers. + idGenerator UniqueIDProvider + + // stype is used by connecting sockets to ensure that they are the + // same type. The value is typically either tcpip.SockSeqpacket or + // tcpip.SockStream. + stype SockType + + // acceptedChan is per the TCP endpoint implementation. Note that the + // sockets in this channel are _already in the connected state_, and + // have another associated connectionedEndpoint. + // + // If nil, then no listen call has been made. + acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"` +} + +// NewConnectioned creates a new unbound connectionedEndpoint. +func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { + return &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } +} + +// NewPair allocates a new pair of connected unix-domain connectionedEndpoints. +func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { + a := &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } + b := &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } + + q1 := queue.New(a.Queue, b.Queue, initialLimit) + q2 := queue.New(b.Queue, a.Queue, initialLimit) + + if stype == SockStream { + a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} + b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} + } else { + a.receiver = &queueReceiver{q1} + b.receiver = &queueReceiver{q2} + } + + a.connected = &connectedEndpoint{ + endpoint: b, + writeQueue: q2, + } + b.connected = &connectedEndpoint{ + endpoint: a, + writeQueue: q1, + } + + return a, b +} + +// NewExternal creates a new externally backed Endpoint. It behaves like a +// socketpair. +func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { + return &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } +} + +// ID implements ConnectingEndpoint.ID. +func (e *connectionedEndpoint) ID() uint64 { + return e.id +} + +// Type implements ConnectingEndpoint.Type and Endpoint.Type. +func (e *connectionedEndpoint) Type() SockType { + return e.stype +} + +// WaiterQueue implements ConnectingEndpoint.WaiterQueue. +func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue { + return e.Queue +} + +// isBound returns true iff the connectionedEndpoint is bound (but not +// listening). +func (e *connectionedEndpoint) isBound() bool { + return e.path != "" && e.acceptedChan == nil +} + +// Listening implements ConnectingEndpoint.Listening. +func (e *connectionedEndpoint) Listening() bool { + return e.acceptedChan != nil +} + +// Close puts the connectionedEndpoint in a closed state and frees all +// resources associated with it. +// +// The socket will be a fresh state after a call to close and may be reused. +// That is, close may be used to "unbind" or "disconnect" the socket in error +// paths. +func (e *connectionedEndpoint) Close() { + e.Lock() + var c ConnectedEndpoint + var r Receiver + switch { + case e.Connected(): + e.connected.CloseSend() + e.receiver.CloseRecv() + c = e.connected + r = e.receiver + e.connected = nil + e.receiver = nil + case e.isBound(): + e.path = "" + case e.Listening(): + close(e.acceptedChan) + for n := range e.acceptedChan { + n.Close() + } + e.acceptedChan = nil + e.path = "" + } + e.Unlock() + if c != nil { + c.CloseNotify() + c.Release() + } + if r != nil { + r.CloseNotify() + r.Release() + } +} + +// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. +func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error { + if ce.Type() != e.stype { + return tcpip.ErrConnectionRefused + } + + // Check if ce is e to avoid a deadlock. + if ce, ok := ce.(*connectionedEndpoint); ok && ce == e { + return tcpip.ErrInvalidEndpointState + } + + // Do a dance to safely acquire locks on both endpoints. + if e.id < ce.ID() { + e.Lock() + ce.Lock() + } else { + ce.Lock() + e.Lock() + } + + // Check connecting state. + if ce.Connected() { + e.Unlock() + ce.Unlock() + return tcpip.ErrAlreadyConnected + } + if ce.Listening() { + e.Unlock() + ce.Unlock() + return tcpip.ErrInvalidEndpointState + } + + // Check bound state. + if !e.Listening() { + e.Unlock() + ce.Unlock() + return tcpip.ErrConnectionRefused + } + + // Create a newly bound connectionedEndpoint. + ne := &connectionedEndpoint{ + baseEndpoint: baseEndpoint{ + path: e.path, + Queue: &waiter.Queue{}, + }, + id: e.idGenerator.UniqueID(), + idGenerator: e.idGenerator, + stype: e.stype, + } + readQueue := queue.New(ce.WaiterQueue(), ne.Queue, initialLimit) + writeQueue := queue.New(ne.Queue, ce.WaiterQueue(), initialLimit) + ne.connected = &connectedEndpoint{ + endpoint: ce, + writeQueue: readQueue, + } + if e.stype == SockStream { + ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} + } else { + ne.receiver = &queueReceiver{readQueue: writeQueue} + } + + select { + case e.acceptedChan <- ne: + // Commit state. + connected := &connectedEndpoint{ + endpoint: ne, + writeQueue: writeQueue, + } + if e.stype == SockStream { + returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) + } else { + returnConnect(&queueReceiver{readQueue: readQueue}, connected) + } + + // Notify can deadlock if we are holding these locks. + e.Unlock() + ce.Unlock() + + // Notify on both ends. + e.Notify(waiter.EventIn) + ce.WaiterQueue().Notify(waiter.EventOut) + + return nil + default: + // Busy; return ECONNREFUSED per spec. + ne.Close() + e.Unlock() + ce.Unlock() + return tcpip.ErrConnectionRefused + } +} + +// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. +func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) { + return nil, tcpip.ErrConnectionRefused +} + +// Connect attempts to directly connect to another Endpoint. +// Implements Endpoint.Connect. +func (e *connectionedEndpoint) Connect(server BoundEndpoint) *tcpip.Error { + returnConnect := func(r Receiver, ce ConnectedEndpoint) { + e.receiver = r + e.connected = ce + } + + return server.BidirectionalConnect(e, returnConnect) +} + +// Listen starts listening on the connection. +func (e *connectionedEndpoint) Listen(backlog int) *tcpip.Error { + e.Lock() + defer e.Unlock() + if e.Listening() { + // Adjust the size of the channel iff we can fix existing + // pending connections into the new one. + if len(e.acceptedChan) > backlog { + return tcpip.ErrInvalidEndpointState + } + origChan := e.acceptedChan + e.acceptedChan = make(chan *connectionedEndpoint, backlog) + close(origChan) + for ep := range origChan { + e.acceptedChan <- ep + } + return nil + } + if !e.isBound() { + return tcpip.ErrInvalidEndpointState + } + + // Normal case. + e.acceptedChan = make(chan *connectionedEndpoint, backlog) + return nil +} + +// Accept accepts a new connection. +func (e *connectionedEndpoint) Accept() (Endpoint, *tcpip.Error) { + e.Lock() + defer e.Unlock() + + if !e.Listening() { + return nil, tcpip.ErrInvalidEndpointState + } + + select { + case ne := <-e.acceptedChan: + return ne, nil + + default: + // Nothing left. + return nil, tcpip.ErrWouldBlock + } +} + +// Bind binds the connection. +// +// For Unix connectionedEndpoints, this _only sets the address associated with +// the socket_. Work associated with sockets in the filesystem or finding those +// sockets must be done by a higher level. +// +// Bind will fail only if the socket is connected, bound or the passed address +// is invalid (the empty string). +func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.Lock() + defer e.Unlock() + if e.isBound() || e.Listening() { + return tcpip.ErrAlreadyBound + } + if addr.Addr == "" { + // The empty string is not permitted. + return tcpip.ErrBadLocalAddress + } + if commit != nil { + if err := commit(); err != nil { + return err + } + } + + // Save the bound address. + e.path = string(addr.Addr) + return nil +} + +// SendMsg writes data and a control message to the endpoint's peer. +// This method does not block if the data cannot be written. +func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) { + // Stream sockets do not support specifying the endpoint. Seqpacket + // sockets ignore the passed endpoint. + if e.stype == SockStream && to != nil { + return 0, tcpip.ErrNotSupported + } + return e.baseEndpoint.SendMsg(data, c, to) +} + +// Readiness returns the current readiness of the connectionedEndpoint. For +// example, if waiter.EventIn is set, the connectionedEndpoint is immediately +// readable. +func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + e.Lock() + defer e.Unlock() + + ready := waiter.EventMask(0) + switch { + case e.Connected(): + if mask&waiter.EventIn != 0 && e.receiver.Readable() { + ready |= waiter.EventIn + } + if mask&waiter.EventOut != 0 && e.connected.Writable() { + ready |= waiter.EventOut + } + case e.Listening(): + if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 { + ready |= waiter.EventIn + } + } + + return ready +} diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go new file mode 100644 index 000000000..7e6c73dcc --- /dev/null +++ b/pkg/sentry/socket/unix/transport/connectioned_state.go @@ -0,0 +1,53 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +// saveAcceptedChan is invoked by stateify. +func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint { + // If acceptedChan is nil (i.e. we are not listening) then we will save nil. + // Otherwise we create a (possibly empty) slice of the values in acceptedChan and + // save that. + var acceptedSlice []*connectionedEndpoint + if e.acceptedChan != nil { + // Swap out acceptedChan with a new empty channel of the same capacity. + saveChan := e.acceptedChan + e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan)) + + // Create a new slice with the same len and capacity as the channel. + acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan)) + // Drain acceptedChan into saveSlice, and fill up the new acceptChan at the + // same time. + for i := range acceptedSlice { + ep := <-saveChan + acceptedSlice[i] = ep + e.acceptedChan <- ep + } + close(saveChan) + } + return acceptedSlice +} + +// loadAcceptedChan is invoked by stateify. +func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) { + // If acceptedSlice is nil, then acceptedChan should also be nil. + if acceptedSlice != nil { + // Otherwise, create a new channel with the same capacity as acceptedSlice. + e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice)) + // Seed the channel with values from acceptedSlice. + for _, ep := range acceptedSlice { + e.acceptedChan <- ep + } + } +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go new file mode 100644 index 000000000..fb2728010 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -0,0 +1,192 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport/queue" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in +// a conectionless fashon. +// +// Specifically, this means datagram unix sockets not created with +// socketpair(2). +// +// +stateify savable +type connectionlessEndpoint struct { + baseEndpoint +} + +// NewConnectionless creates a new unbound dgram endpoint. +func NewConnectionless() Endpoint { + ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} + ep.receiver = &queueReceiver{readQueue: queue.New(&waiter.Queue{}, ep.Queue, initialLimit)} + return ep +} + +// isBound returns true iff the endpoint is bound. +func (e *connectionlessEndpoint) isBound() bool { + return e.path != "" +} + +// Close puts the endpoint in a closed state and frees all resources associated +// with it. +// +// The socket will be a fresh state after a call to close and may be reused. +// That is, close may be used to "unbind" or "disconnect" the socket in error +// paths. +func (e *connectionlessEndpoint) Close() { + e.Lock() + var r Receiver + if e.Connected() { + e.receiver.CloseRecv() + r = e.receiver + e.receiver = nil + + e.connected.Release() + e.connected = nil + } + if e.isBound() { + e.path = "" + } + e.Unlock() + if r != nil { + r.CloseNotify() + r.Release() + } +} + +// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. +func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error { + return tcpip.ErrConnectionRefused +} + +// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. +func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) { + e.Lock() + r := e.receiver + e.Unlock() + if r == nil { + return nil, tcpip.ErrConnectionRefused + } + return &connectedEndpoint{ + endpoint: e, + writeQueue: r.(*queueReceiver).readQueue, + }, nil +} + +// SendMsg writes data and a control message to the specified endpoint. +// This method does not block if the data cannot be written. +func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) { + if to == nil { + return e.baseEndpoint.SendMsg(data, c, nil) + } + + connected, err := to.UnidirectionalConnect() + if err != nil { + return 0, tcpip.ErrInvalidEndpointState + } + defer connected.Release() + + e.Lock() + n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + e.Unlock() + + if notify { + connected.SendNotify() + } + + return n, err +} + +// Type implements Endpoint.Type. +func (e *connectionlessEndpoint) Type() SockType { + return SockDgram +} + +// Connect attempts to connect directly to server. +func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *tcpip.Error { + connected, err := server.UnidirectionalConnect() + if err != nil { + return err + } + + e.Lock() + e.connected = connected + e.Unlock() + + return nil +} + +// Listen starts listening on the connection. +func (e *connectionlessEndpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept accepts a new connection. +func (e *connectionlessEndpoint) Accept() (Endpoint, *tcpip.Error) { + return nil, tcpip.ErrNotSupported +} + +// Bind binds the connection. +// +// For Unix endpoints, this _only sets the address associated with the socket_. +// Work associated with sockets in the filesystem or finding those sockets must +// be done by a higher level. +// +// Bind will fail only if the socket is connected, bound or the passed address +// is invalid (the empty string). +func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.Lock() + defer e.Unlock() + if e.isBound() { + return tcpip.ErrAlreadyBound + } + if addr.Addr == "" { + // The empty string is not permitted. + return tcpip.ErrBadLocalAddress + } + if commit != nil { + if err := commit(); err != nil { + return err + } + } + + // Save the bound address. + e.path = string(addr.Addr) + return nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + e.Lock() + defer e.Unlock() + + ready := waiter.EventMask(0) + if mask&waiter.EventIn != 0 && e.receiver.Readable() { + ready |= waiter.EventIn + } + + if e.Connected() { + if mask&waiter.EventOut != 0 && e.connected.Writable() { + ready |= waiter.EventOut + } + } + + return ready +} diff --git a/pkg/sentry/socket/unix/transport/queue/BUILD b/pkg/sentry/socket/unix/transport/queue/BUILD new file mode 100644 index 000000000..d914ecc23 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/queue/BUILD @@ -0,0 +1,15 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("//tools/go_stateify:defs.bzl", "go_library") + +go_library( + name = "queue", + srcs = ["queue.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport/queue", + visibility = ["//:sandbox"], + deps = [ + "//pkg/ilist", + "//pkg/tcpip", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/socket/unix/transport/queue/queue.go b/pkg/sentry/socket/unix/transport/queue/queue.go new file mode 100644 index 000000000..b3d2ea68b --- /dev/null +++ b/pkg/sentry/socket/unix/transport/queue/queue.go @@ -0,0 +1,227 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package queue provides the implementation of buffer queue +// and interface of queue entry with Length method. +package queue + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/ilist" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Entry implements Linker interface and has additional required methods. +type Entry interface { + ilist.Linker + + // Length returns the number of bytes stored in the entry. + Length() int64 + + // Release releases any resources held by the entry. + Release() + + // Peek returns a copy of the entry. It must be Released separately. + Peek() Entry + + // Truncate reduces the number of bytes stored in the entry to n bytes. + // + // Preconditions: n <= Length(). + Truncate(n int64) +} + +// Queue is a buffer queue. +// +// +stateify savable +type Queue struct { + ReaderQueue *waiter.Queue + WriterQueue *waiter.Queue + + mu sync.Mutex `state:"nosave"` + closed bool + used int64 + limit int64 + dataList ilist.List +} + +// New allocates and initializes a new queue. +func New(ReaderQueue *waiter.Queue, WriterQueue *waiter.Queue, limit int64) *Queue { + return &Queue{ReaderQueue: ReaderQueue, WriterQueue: WriterQueue, limit: limit} +} + +// Close closes q for reading and writing. It is immediately not writable and +// will become unreadable when no more data is pending. +// +// Both the read and write queues must be notified after closing: +// q.ReaderQueue.Notify(waiter.EventIn) +// q.WriterQueue.Notify(waiter.EventOut) +func (q *Queue) Close() { + q.mu.Lock() + q.closed = true + q.mu.Unlock() +} + +// Reset empties the queue and Releases all of the Entries. +// +// Both the read and write queues must be notified after resetting: +// q.ReaderQueue.Notify(waiter.EventIn) +// q.WriterQueue.Notify(waiter.EventOut) +func (q *Queue) Reset() { + q.mu.Lock() + for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { + cur.(Entry).Release() + } + q.dataList.Reset() + q.used = 0 + q.mu.Unlock() +} + +// IsReadable determines if q is currently readable. +func (q *Queue) IsReadable() bool { + q.mu.Lock() + defer q.mu.Unlock() + + return q.closed || q.dataList.Front() != nil +} + +// bufWritable returns true if there is space for writing. +// +// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is +// free. +// +// See net/unix/af_unix.c:unix_writeable. +func (q *Queue) bufWritable() bool { + return 4*q.used < q.limit +} + +// IsWritable determines if q is currently writable. +func (q *Queue) IsWritable() bool { + q.mu.Lock() + defer q.mu.Unlock() + + return q.closed || q.bufWritable() +} + +// Enqueue adds an entry to the data queue if room is available. +// +// If truncate is true, Enqueue may truncate the message beforing enqueuing it. +// Otherwise, the entire message must fit. If n < e.Length(), err indicates why. +// +// If notify is true, ReaderQueue.Notify must be called: +// q.ReaderQueue.Notify(waiter.EventIn) +func (q *Queue) Enqueue(e Entry, truncate bool) (l int64, notify bool, err *tcpip.Error) { + q.mu.Lock() + + if q.closed { + q.mu.Unlock() + return 0, false, tcpip.ErrClosedForSend + } + + free := q.limit - q.used + + l = e.Length() + + if l > free && truncate { + if free == 0 { + // Message can't fit right now. + q.mu.Unlock() + return 0, false, tcpip.ErrWouldBlock + } + + e.Truncate(free) + l = e.Length() + err = tcpip.ErrWouldBlock + } + + if l > q.limit { + // Message is too big to ever fit. + q.mu.Unlock() + return 0, false, tcpip.ErrMessageTooLong + } + + if l > free { + // Message can't fit right now. + q.mu.Unlock() + return 0, false, tcpip.ErrWouldBlock + } + + notify = q.dataList.Front() == nil + q.used += l + q.dataList.PushBack(e) + + q.mu.Unlock() + + return l, notify, err +} + +// Dequeue removes the first entry in the data queue, if one exists. +// +// If notify is true, WriterQueue.Notify must be called: +// q.WriterQueue.Notify(waiter.EventOut) +func (q *Queue) Dequeue() (e Entry, notify bool, err *tcpip.Error) { + q.mu.Lock() + + if q.dataList.Front() == nil { + err := tcpip.ErrWouldBlock + if q.closed { + err = tcpip.ErrClosedForReceive + } + q.mu.Unlock() + + return nil, false, err + } + + notify = !q.bufWritable() + + e = q.dataList.Front().(Entry) + q.dataList.Remove(e) + q.used -= e.Length() + + notify = notify && q.bufWritable() + + q.mu.Unlock() + + return e, notify, nil +} + +// Peek returns the first entry in the data queue, if one exists. +func (q *Queue) Peek() (Entry, *tcpip.Error) { + q.mu.Lock() + defer q.mu.Unlock() + + if q.dataList.Front() == nil { + err := tcpip.ErrWouldBlock + if q.closed { + err = tcpip.ErrClosedForReceive + } + return nil, err + } + + return q.dataList.Front().(Entry).Peek(), nil +} + +// QueuedSize returns the number of bytes currently in the queue, that is, the +// number of readable bytes. +func (q *Queue) QueuedSize() int64 { + q.mu.Lock() + defer q.mu.Unlock() + return q.used +} + +// MaxQueueSize returns the maximum number of bytes storable in the queue. +func (q *Queue) MaxQueueSize() int64 { + return q.limit +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go new file mode 100644 index 000000000..577aa87d5 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -0,0 +1,953 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package transport contains the implementation of Unix endpoints. +package transport + +import ( + "sync" + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/ilist" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport/queue" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// initialLimit is the starting limit for the socket buffers. +const initialLimit = 16 * 1024 + +// A SockType is a type (as opposed to family) of sockets. These are enumerated +// in the syscall package as syscall.SOCK_* constants. +type SockType int + +const ( + // SockStream corresponds to syscall.SOCK_STREAM. + SockStream SockType = 1 + // SockDgram corresponds to syscall.SOCK_DGRAM. + SockDgram SockType = 2 + // SockRaw corresponds to syscall.SOCK_RAW. + SockRaw SockType = 3 + // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET. + SockSeqpacket SockType = 5 +) + +// A RightsControlMessage is a control message containing FDs. +type RightsControlMessage interface { + // Clone returns a copy of the RightsControlMessage. + Clone() RightsControlMessage + + // Release releases any resources owned by the RightsControlMessage. + Release() +} + +// A CredentialsControlMessage is a control message containing Unix credentials. +type CredentialsControlMessage interface { + // Equals returns true iff the two messages are equal. + Equals(CredentialsControlMessage) bool +} + +// A ControlMessages represents a collection of socket control messages. +// +// +stateify savable +type ControlMessages struct { + // Rights is a control message containing FDs. + Rights RightsControlMessage + + // Credentials is a control message containing Unix credentials. + Credentials CredentialsControlMessage +} + +// Empty returns true iff the ControlMessages does not contain either +// credentials or rights. +func (c *ControlMessages) Empty() bool { + return c.Rights == nil && c.Credentials == nil +} + +// Clone clones both the credentials and the rights. +func (c *ControlMessages) Clone() ControlMessages { + cm := ControlMessages{} + if c.Rights != nil { + cm.Rights = c.Rights.Clone() + } + cm.Credentials = c.Credentials + return cm +} + +// Release releases both the credentials and the rights. +func (c *ControlMessages) Release() { + if c.Rights != nil { + c.Rights.Release() + } + *c = ControlMessages{} +} + +// Endpoint is the interface implemented by Unix transport protocol +// implementations that expose functionality like sendmsg, recvmsg, connect, +// etc. to Unix socket implementations. +type Endpoint interface { + Credentialer + waiter.Waitable + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. + Close() + + // RecvMsg reads data and a control message from the endpoint. This method + // does not block if there is no data pending. + // + // creds indicates if credential control messages are requested by the + // caller. This is useful for determining if control messages can be + // coalesced. creds is a hint and can be safely ignored by the + // implementation if no coalescing is possible. It is fine to return + // credential control messages when none were requested or to not return + // credential control messages when they were requested. + // + // numRights is the number of SCM_RIGHTS FDs requested by the caller. This + // is useful if one must allocate a buffer to receive a SCM_RIGHTS message + // or determine if control messages can be coalesced. numRights is a hint + // and can be safely ignored by the implementation if the number of + // available SCM_RIGHTS FDs is known and no coalescing is possible. It is + // fine for the returned number of SCM_RIGHTS FDs to be either higher or + // lower than the requested number. + // + // If peek is true, no data should be consumed from the Endpoint. Any and + // all data returned from a peek should be available in the next call to + // RecvMsg. + // + // recvLen is the number of bytes copied into data. + // + // msgLen is the length of the read message consumed for datagram Endpoints. + // msgLen is always the same as recvLen for stream Endpoints. + RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *tcpip.Error) + + // SendMsg writes data and a control message to the endpoint's peer. + // This method does not block if the data cannot be written. + // + // SendMsg does not take ownership of any of its arguments on error. + SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *tcpip.Error) + + // Connect connects this endpoint directly to another. + // + // This should be called on the client endpoint, and the (bound) + // endpoint passed in as a parameter. + // + // The error codes are the same as Connect. + Connect(server BoundEndpoint) *tcpip.Error + + // Shutdown closes the read and/or write end of the endpoint connection + // to its peer. + Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error + + // Listen puts the endpoint in "listen" mode, which allows it to accept + // new connections. + Listen(backlog int) *tcpip.Error + + // Accept returns a new endpoint if a peer has established a connection + // to an endpoint previously set to listen mode. This method does not + // block if no new connections are available. + // + // The returned Queue is the wait queue for the newly created endpoint. + Accept() (Endpoint, *tcpip.Error) + + // Bind binds the endpoint to a specific local address and port. + // Specifying a NIC is optional. + // + // An optional commit function will be executed atomically with respect + // to binding the endpoint. If this returns an error, the bind will not + // occur and the error will be propagated back to the caller. + Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error + + // Type return the socket type, typically either SockStream, SockDgram + // or SockSeqpacket. + Type() SockType + + // GetLocalAddress returns the address to which the endpoint is bound. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // GetRemoteAddress returns the address to which the endpoint is + // connected. + GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) + + // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option + // types. + SetSockOpt(opt interface{}) *tcpip.Error + + // GetSockOpt gets a socket option. opt should be a pointer to one of the + // tcpip.*Option types. + GetSockOpt(opt interface{}) *tcpip.Error +} + +// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket +// option. +type Credentialer interface { + // Passcred returns whether or not the SO_PASSCRED socket option is + // enabled on this end. + Passcred() bool + + // ConnectedPasscred returns whether or not the SO_PASSCRED socket option + // is enabled on the connected end. + ConnectedPasscred() bool +} + +// A BoundEndpoint is a unix endpoint that can be connected to. +type BoundEndpoint interface { + // BidirectionalConnect establishes a bi-directional connection between two + // unix endpoints in an all-or-nothing manner. If an error occurs during + // connecting, the state of neither endpoint should be modified. + // + // In order for an endpoint to establish such a bidirectional connection + // with a BoundEndpoint, the endpoint calls the BidirectionalConnect method + // on the BoundEndpoint and sends a representation of itself (the + // ConnectingEndpoint) and a callback (returnConnect) to receive the + // connection information (Receiver and ConnectedEndpoint) upon a + // successful connect. The callback should only be called on a successful + // connect. + // + // For a connection attempt to be successful, the ConnectingEndpoint must + // be unconnected and not listening and the BoundEndpoint whose + // BidirectionalConnect method is being called must be listening. + // + // This method will return tcpip.ErrConnectionRefused on endpoints with a + // type that isn't SockStream or SockSeqpacket. + BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error + + // UnidirectionalConnect establishes a write-only connection to a unix + // endpoint. + // + // An endpoint which calls UnidirectionalConnect and supports it itself must + // not hold its own lock when calling UnidirectionalConnect. + // + // This method will return tcpip.ErrConnectionRefused on a non-SockDgram + // endpoint. + UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) + + // Release releases any resources held by the BoundEndpoint. It must be + // called before dropping all references to a BoundEndpoint returned by a + // function. + Release() +} + +// message represents a message passed over a Unix domain socket. +// +// +stateify savable +type message struct { + ilist.Entry + + // Data is the Message payload. + Data buffer.View + + // Control is auxiliary control message data that goes along with the + // data. + Control ControlMessages + + // Address is the bound address of the endpoint that sent the message. + // + // If the endpoint that sent the message is not bound, the Address is + // the empty string. + Address tcpip.FullAddress +} + +// Length returns number of bytes stored in the message. +func (m *message) Length() int64 { + return int64(len(m.Data)) +} + +// Release releases any resources held by the message. +func (m *message) Release() { + m.Control.Release() +} + +// Peek returns a copy of the message. +func (m *message) Peek() queue.Entry { + return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address} +} + +// Truncate reduces the length of the message payload to n bytes. +// +// Preconditions: n <= m.Length(). +func (m *message) Truncate(n int64) { + m.Data.CapLength(int(n)) +} + +// A Receiver can be used to receive Messages. +type Receiver interface { + // Recv receives a single message. This method does not block. + // + // See Endpoint.RecvMsg for documentation on shared arguments. + // + // notify indicates if RecvNotify should be called. + Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *tcpip.Error) + + // RecvNotify notifies the Receiver of a successful Recv. This must not be + // called while holding any endpoint locks. + RecvNotify() + + // CloseRecv prevents the receiving of additional Messages. + // + // After CloseRecv is called, CloseNotify must also be called. + CloseRecv() + + // CloseNotify notifies the Receiver of recv being closed. This must not be + // called while holding any endpoint locks. + CloseNotify() + + // Readable returns if messages should be attempted to be received. This + // includes when read has been shutdown. + Readable() bool + + // RecvQueuedSize returns the total amount of data currently receivable. + // RecvQueuedSize should return -1 if the operation isn't supported. + RecvQueuedSize() int64 + + // RecvMaxQueueSize returns maximum value for RecvQueuedSize. + // RecvMaxQueueSize should return -1 if the operation isn't supported. + RecvMaxQueueSize() int64 + + // Release releases any resources owned by the Receiver. It should be + // called before droping all references to a Receiver. + Release() +} + +// queueReceiver implements Receiver for datagram sockets. +// +// +stateify savable +type queueReceiver struct { + readQueue *queue.Queue +} + +// Recv implements Receiver.Recv. +func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { + var m queue.Entry + var notify bool + var err *tcpip.Error + if peek { + m, err = q.readQueue.Peek() + } else { + m, notify, err = q.readQueue.Dequeue() + } + if err != nil { + return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + } + msg := m.(*message) + src := []byte(msg.Data) + var copied uintptr + for i := 0; i < len(data) && len(src) > 0; i++ { + n := copy(data[i], src) + copied += uintptr(n) + src = src[n:] + } + return copied, uintptr(len(msg.Data)), msg.Control, msg.Address, notify, nil +} + +// RecvNotify implements Receiver.RecvNotify. +func (q *queueReceiver) RecvNotify() { + q.readQueue.WriterQueue.Notify(waiter.EventOut) +} + +// CloseNotify implements Receiver.CloseNotify. +func (q *queueReceiver) CloseNotify() { + q.readQueue.ReaderQueue.Notify(waiter.EventIn) + q.readQueue.WriterQueue.Notify(waiter.EventOut) +} + +// CloseRecv implements Receiver.CloseRecv. +func (q *queueReceiver) CloseRecv() { + q.readQueue.Close() +} + +// Readable implements Receiver.Readable. +func (q *queueReceiver) Readable() bool { + return q.readQueue.IsReadable() +} + +// RecvQueuedSize implements Receiver.RecvQueuedSize. +func (q *queueReceiver) RecvQueuedSize() int64 { + return q.readQueue.QueuedSize() +} + +// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize. +func (q *queueReceiver) RecvMaxQueueSize() int64 { + return q.readQueue.MaxQueueSize() +} + +// Release implements Receiver.Release. +func (*queueReceiver) Release() {} + +// streamQueueReceiver implements Receiver for stream sockets. +// +// +stateify savable +type streamQueueReceiver struct { + queueReceiver + + mu sync.Mutex `state:"nosave"` + buffer []byte + control ControlMessages + addr tcpip.FullAddress +} + +func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) { + var copied uintptr + for len(data) > 0 && len(buf) > 0 { + n := copy(data[0], buf) + copied += uintptr(n) + buf = buf[n:] + data[0] = data[0][n:] + if len(data[0]) == 0 { + data = data[1:] + } + } + return copied, data, buf +} + +// Readable implements Receiver.Readable. +func (q *streamQueueReceiver) Readable() bool { + q.mu.Lock() + bl := len(q.buffer) + r := q.readQueue.IsReadable() + q.mu.Unlock() + // We're readable if we have data in our buffer or if the queue receiver is + // readable. + return bl > 0 || r +} + +// RecvQueuedSize implements Receiver.RecvQueuedSize. +func (q *streamQueueReceiver) RecvQueuedSize() int64 { + q.mu.Lock() + bl := len(q.buffer) + qs := q.readQueue.QueuedSize() + q.mu.Unlock() + return int64(bl) + qs +} + +// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize. +func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { + // The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the largest + // message we can buffer which is also the largest message we can receive. + return 2 * q.readQueue.MaxQueueSize() +} + +// Recv implements Receiver.Recv. +func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { + q.mu.Lock() + defer q.mu.Unlock() + + var notify bool + + // If we have no data in the endpoint, we need to get some. + if len(q.buffer) == 0 { + // Load the next message into a buffer, even if we are peeking. Peeking + // won't consume the message, so it will be still available to be read + // the next time Recv() is called. + m, n, err := q.readQueue.Dequeue() + if err != nil { + return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + } + notify = n + msg := m.(*message) + q.buffer = []byte(msg.Data) + q.control = msg.Control + q.addr = msg.Address + } + + var copied uintptr + if peek { + // Don't consume control message if we are peeking. + c := q.control.Clone() + + // Don't consume data since we are peeking. + copied, data, _ = vecCopy(data, q.buffer) + + return copied, copied, c, q.addr, notify, nil + } + + // Consume data and control message since we are not peeking. + copied, data, q.buffer = vecCopy(data, q.buffer) + + // Save the original state of q.control. + c := q.control + + // Remove rights from q.control and leave behind just the creds. + q.control.Rights = nil + if !wantCreds { + c.Credentials = nil + } + + if c.Rights != nil && numRights == 0 { + c.Rights.Release() + c.Rights = nil + } + + haveRights := c.Rights != nil + + // If we have more capacity for data and haven't received any usable + // rights. + // + // Linux never coalesces rights control messages. + for !haveRights && len(data) > 0 { + // Get a message from the readQueue. + m, n, err := q.readQueue.Dequeue() + if err != nil { + // We already got some data, so ignore this error. This will + // manifest as a short read to the user, which is what Linux + // does. + break + } + notify = notify || n + msg := m.(*message) + q.buffer = []byte(msg.Data) + q.control = msg.Control + q.addr = msg.Address + + if wantCreds { + if (q.control.Credentials == nil) != (c.Credentials == nil) { + // One message has credentials, the other does not. + break + } + + if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) { + // Both messages have credentials, but they don't match. + break + } + } + + if numRights != 0 && c.Rights != nil && q.control.Rights != nil { + // Both messages have rights. + break + } + + var cpd uintptr + cpd, data, q.buffer = vecCopy(data, q.buffer) + copied += cpd + + if cpd == 0 { + // data was actually full. + break + } + + if q.control.Rights != nil { + // Consume rights. + if numRights == 0 { + q.control.Rights.Release() + } else { + c.Rights = q.control.Rights + haveRights = true + } + q.control.Rights = nil + } + } + return copied, copied, c, q.addr, notify, nil +} + +// A ConnectedEndpoint is an Endpoint that can be used to send Messages. +type ConnectedEndpoint interface { + // Passcred implements Endpoint.Passcred. + Passcred() bool + + // GetLocalAddress implements Endpoint.GetLocalAddress. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // Send sends a single message. This method does not block. + // + // notify indicates if SendNotify should be called. + // + // tcpip.ErrWouldBlock can be returned along with a partial write if + // the caller should block to send the rest of the data. + Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *tcpip.Error) + + // SendNotify notifies the ConnectedEndpoint of a successful Send. This + // must not be called while holding any endpoint locks. + SendNotify() + + // CloseSend prevents the sending of additional Messages. + // + // After CloseSend is call, CloseNotify must also be called. + CloseSend() + + // CloseNotify notifies the ConnectedEndpoint of send being closed. This + // must not be called while holding any endpoint locks. + CloseNotify() + + // Writable returns if messages should be attempted to be sent. This + // includes when write has been shutdown. + Writable() bool + + // EventUpdate lets the ConnectedEndpoint know that event registrations + // have changed. + EventUpdate() + + // SendQueuedSize returns the total amount of data currently queued for + // sending. SendQueuedSize should return -1 if the operation isn't + // supported. + SendQueuedSize() int64 + + // SendMaxQueueSize returns maximum value for SendQueuedSize. + // SendMaxQueueSize should return -1 if the operation isn't supported. + SendMaxQueueSize() int64 + + // Release releases any resources owned by the ConnectedEndpoint. It should + // be called before droping all references to a ConnectedEndpoint. + Release() +} + +// +stateify savable +type connectedEndpoint struct { + // endpoint represents the subset of the Endpoint functionality needed by + // the connectedEndpoint. It is implemented by both connectionedEndpoint + // and connectionlessEndpoint and allows the use of types which don't + // fully implement Endpoint. + endpoint interface { + // Passcred implements Endpoint.Passcred. + Passcred() bool + + // GetLocalAddress implements Endpoint.GetLocalAddress. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // Type implements Endpoint.Type. + Type() SockType + } + + writeQueue *queue.Queue +} + +// Passcred implements ConnectedEndpoint.Passcred. +func (e *connectedEndpoint) Passcred() bool { + return e.endpoint.Passcred() +} + +// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress. +func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return e.endpoint.GetLocalAddress() +} + +// Send implements ConnectedEndpoint.Send. +func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { + var l int64 + for _, d := range data { + l += int64(len(d)) + } + + truncate := false + if e.endpoint.Type() == SockStream { + // Since stream sockets don't preserve message boundaries, we + // can write only as much of the message as fits in the queue. + truncate = true + + // Discard empty stream packets. Since stream sockets don't + // preserve message boundaries, sending zero bytes is a no-op. + // In Linux, the receiver actually uses a zero-length receive + // as an indication that the stream was closed. + if l == 0 { + controlMessages.Release() + return 0, false, nil + } + } + + v := make([]byte, 0, l) + for _, d := range data { + v = append(v, d...) + } + + l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate) + return uintptr(l), notify, err +} + +// SendNotify implements ConnectedEndpoint.SendNotify. +func (e *connectedEndpoint) SendNotify() { + e.writeQueue.ReaderQueue.Notify(waiter.EventIn) +} + +// CloseNotify implements ConnectedEndpoint.CloseNotify. +func (e *connectedEndpoint) CloseNotify() { + e.writeQueue.ReaderQueue.Notify(waiter.EventIn) + e.writeQueue.WriterQueue.Notify(waiter.EventOut) +} + +// CloseSend implements ConnectedEndpoint.CloseSend. +func (e *connectedEndpoint) CloseSend() { + e.writeQueue.Close() +} + +// Writable implements ConnectedEndpoint.Writable. +func (e *connectedEndpoint) Writable() bool { + return e.writeQueue.IsWritable() +} + +// EventUpdate implements ConnectedEndpoint.EventUpdate. +func (*connectedEndpoint) EventUpdate() {} + +// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize. +func (e *connectedEndpoint) SendQueuedSize() int64 { + return e.writeQueue.QueuedSize() +} + +// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize. +func (e *connectedEndpoint) SendMaxQueueSize() int64 { + return e.writeQueue.MaxQueueSize() +} + +// Release implements ConnectedEndpoint.Release. +func (*connectedEndpoint) Release() {} + +// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless +// unix domain socket Endpoint implementations. +// +// Not to be used on its own. +// +// +stateify savable +type baseEndpoint struct { + *waiter.Queue + + // passcred specifies whether SCM_CREDENTIALS socket control messages are + // enabled on this endpoint. Must be accessed atomically. + passcred int32 + + // Mutex protects the below fields. + sync.Mutex `state:"nosave"` + + // receiver allows Messages to be received. + receiver Receiver + + // connected allows messages to be sent and state information about the + // connected endpoint to be read. + connected ConnectedEndpoint + + // path is not empty if the endpoint has been bound, + // or may be used if the endpoint is connected. + path string +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { + e.Queue.EventRegister(we, mask) + e.Lock() + if e.connected != nil { + e.connected.EventUpdate() + } + e.Unlock() +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { + e.Queue.EventUnregister(we) + e.Lock() + if e.connected != nil { + e.connected.EventUpdate() + } + e.Unlock() +} + +// Passcred implements Credentialer.Passcred. +func (e *baseEndpoint) Passcred() bool { + return atomic.LoadInt32(&e.passcred) != 0 +} + +// ConnectedPasscred implements Credentialer.ConnectedPasscred. +func (e *baseEndpoint) ConnectedPasscred() bool { + e.Lock() + defer e.Unlock() + return e.connected != nil && e.connected.Passcred() +} + +func (e *baseEndpoint) setPasscred(pc bool) { + if pc { + atomic.StoreInt32(&e.passcred, 1) + } else { + atomic.StoreInt32(&e.passcred, 0) + } +} + +// Connected implements ConnectingEndpoint.Connected. +func (e *baseEndpoint) Connected() bool { + return e.receiver != nil && e.connected != nil +} + +// RecvMsg reads data and a control message from the endpoint. +func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *tcpip.Error) { + e.Lock() + + if e.receiver == nil { + e.Unlock() + return 0, 0, ControlMessages{}, tcpip.ErrNotConnected + } + + recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + e.Unlock() + if err != nil { + return 0, 0, ControlMessages{}, err + } + + if notify { + e.receiver.RecvNotify() + } + + if addr != nil { + *addr = a + } + return recvLen, msgLen, cms, nil +} + +// SendMsg writes data and a control message to the endpoint's peer. +// This method does not block if the data cannot be written. +func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) { + e.Lock() + if !e.Connected() { + e.Unlock() + return 0, tcpip.ErrNotConnected + } + if to != nil { + e.Unlock() + return 0, tcpip.ErrAlreadyConnected + } + + n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + e.Unlock() + + if notify { + e.connected.SendNotify() + } + + return n, err +} + +// SetSockOpt sets a socket option. Currently not supported. +func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.PasscredOption: + e.setPasscred(v != 0) + return nil + } + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + case *tcpip.SendQueueSizeOption: + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + case *tcpip.ReceiveQueueSizeOption: + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) + } + return nil + case *tcpip.SendBufferSizeOption: + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + case *tcpip.ReceiveBufferSizeOption: + e.Lock() + if e.receiver == nil { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + } + return tcpip.ErrUnknownProtocolOption +} + +// Shutdown closes the read and/or write end of the endpoint connection to its +// peer. +func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + + if flags&tcpip.ShutdownRead != 0 { + e.receiver.CloseRecv() + } + + if flags&tcpip.ShutdownWrite != 0 { + e.connected.CloseSend() + } + + e.Unlock() + + if flags&tcpip.ShutdownRead != 0 { + e.receiver.CloseNotify() + } + + if flags&tcpip.ShutdownWrite != 0 { + e.connected.CloseNotify() + } + + return nil +} + +// GetLocalAddress returns the bound path. +func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.Lock() + defer e.Unlock() + return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil +} + +// GetRemoteAddress returns the local address of the connected endpoint (if +// available). +func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.Lock() + c := e.connected + e.Unlock() + if c != nil { + return c.GetLocalAddress() + } + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Release implements BoundEndpoint.Release. +func (*baseEndpoint) Release() {} diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index e30378e60..668363864 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -32,16 +32,16 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" "gvisor.googlesource.com/gvisor/pkg/tcpip" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" ) -// SocketOperations is a Unix socket. It is similar to an epsocket, except it is backed -// by a unix.Endpoint instead of a tcpip.Endpoint. +// SocketOperations is a Unix socket. It is similar to an epsocket, except it +// is backed by a transport.Endpoint instead of a tcpip.Endpoint. // // +stateify savable type SocketOperations struct { @@ -52,18 +52,18 @@ type SocketOperations struct { fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` - ep unix.Endpoint + ep transport.Endpoint } // New creates a new unix socket. -func New(ctx context.Context, endpoint unix.Endpoint) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() return NewWithDirent(ctx, dirent, endpoint, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep unix.Endpoint, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ ep: ep, }) @@ -83,8 +83,8 @@ func (s *SocketOperations) Release() { s.DecRef() } -// Endpoint extracts the unix.Endpoint. -func (s *SocketOperations) Endpoint() unix.Endpoint { +// Endpoint extracts the transport.Endpoint. +func (s *SocketOperations) Endpoint() transport.Endpoint { return s.ep } @@ -110,7 +110,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { } // GetPeerName implements the linux syscall getpeername(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { @@ -122,7 +122,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy } // GetSockName implements the linux syscall getsockname(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { @@ -139,20 +139,20 @@ func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.S } // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) { return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.ep.Listen(backlog)) } // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (unix.Endpoint, *syserr.Error) { +func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -172,7 +172,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (unix.Endpoint, *syser } // Accept implements the linux syscall accept(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, err := s.ep.Accept() @@ -226,7 +226,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { return e } - bep, ok := s.ep.(unix.BoundEndpoint) + bep, ok := s.ep.(transport.BoundEndpoint) if !ok { // This socket can't be bound. return syserr.ErrInvalidArgument @@ -287,10 +287,10 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { })) } -// extractEndpoint retrieves the unix.BoundEndpoint associated with a Unix -// socket path. The Release must be called on the unix.BoundEndpoint when the -// caller is done with it. -func extractEndpoint(t *kernel.Task, sockaddr []byte) (unix.BoundEndpoint, *syserr.Error) { +// extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix +// socket path. The Release must be called on the transport.BoundEndpoint when +// the caller is done with it. +func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) { path, err := extractPath(sockaddr) if err != nil { return nil, err @@ -362,7 +362,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Endpoint: s.ep, @@ -408,12 +408,12 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return int(total), syserr.FromError(err) } -// Passcred implements unix.Credentialer.Passcred. +// Passcred implements transport.Credentialer.Passcred. func (s *SocketOperations) Passcred() bool { return s.ep.Passcred() } -// ConnectedPasscred implements unix.Credentialer.ConnectedPasscred. +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. func (s *SocketOperations) ConnectedPasscred() bool { return s.ep.ConnectedPasscred() } @@ -434,13 +434,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) { } // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal) } // Shutdown implements the linux syscall shutdown(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := epsocket.ConvertShutdown(how) if err != nil { @@ -465,7 +465,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by -// a unix.Endpoint. +// a transport.Endpoint. func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 @@ -539,19 +539,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags type provider struct{} // Socket returns a new unix domain socket. -func (*provider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { // Check arguments. if protocol != 0 { return nil, syserr.ErrInvalidArgument } // Create the endpoint and socket. - var ep unix.Endpoint + var ep transport.Endpoint switch stype { case linux.SOCK_DGRAM: - ep = unix.NewConnectionless() + ep = transport.NewConnectionless() case linux.SOCK_STREAM, linux.SOCK_SEQPACKET: - ep = unix.NewConnectioned(stype, t.Kernel()) + ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } @@ -560,7 +560,7 @@ func (*provider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs. } // Pair creates a new pair of AF_UNIX connected sockets. -func (*provider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Check arguments. if protocol != 0 { return nil, nil, syserr.ErrInvalidArgument @@ -573,7 +573,7 @@ func (*provider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.Fi } // Create the endpoints and sockets. - ep1, ep2 := unix.NewPair(stype, t.Kernel()) + ep1, ep2 := transport.NewPair(stype, t.Kernel()) s1 := New(t, ep1) s2 := New(t, ep2) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index bbdfad9da..7621bfdbd 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -79,11 +79,11 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls", "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", - "//pkg/tcpip/transport/unix", "//pkg/waiter", ], ) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 867fec468..5fa5ddce6 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -27,9 +27,9 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -180,7 +180,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Create the new socket. - s, e := socket.New(t, domain, unix.SockType(stype&0xf), protocol) + s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } @@ -219,7 +219,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Create the socket pair. - s1, s2, e := socket.Pair(t, domain, unix.SockType(stype&0xf), protocol) + s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } @@ -750,7 +750,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i controlData := make([]byte, 0, msg.ControlLen) - if cr, ok := s.(unix.Credentialer); ok && cr.Passcred() { + if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) controlData = control.PackCredentials(t, creds, controlData) } diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD index ff50b9925..68b82af47 100644 --- a/pkg/sentry/uniqueid/BUILD +++ b/pkg/sentry/uniqueid/BUILD @@ -9,6 +9,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", - "//pkg/tcpip/transport/unix", + "//pkg/sentry/socket/unix/transport", ], ) diff --git a/pkg/sentry/uniqueid/context.go b/pkg/sentry/uniqueid/context.go index 541e0611d..e48fabc2d 100644 --- a/pkg/sentry/uniqueid/context.go +++ b/pkg/sentry/uniqueid/context.go @@ -18,7 +18,7 @@ package uniqueid import ( "gvisor.googlesource.com/gvisor/pkg/sentry/context" - "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) // contextID is the kernel package's type for context.Context.Value keys. @@ -44,8 +44,8 @@ func GlobalFromContext(ctx context.Context) uint64 { } // GlobalProviderFromContext returns a system-wide unique identifier from ctx. -func GlobalProviderFromContext(ctx context.Context) unix.UniqueIDProvider { - return ctx.Value(CtxGlobalUniqueIDProvider).(unix.UniqueIDProvider) +func GlobalProviderFromContext(ctx context.Context) transport.UniqueIDProvider { + return ctx.Value(CtxGlobalUniqueIDProvider).(transport.UniqueIDProvider) } // InotifyCookie generates a unique inotify event cookie from ctx. |