diff options
Diffstat (limited to 'pkg/sentry/socket/unix/unix.go')
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 215 |
1 files changed, 158 insertions, 57 deletions
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..b7e8e4325 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -22,9 +22,9 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -33,11 +33,13 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -52,17 +54,14 @@ type SocketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - refs.AtomicRefCount - socket.SendReceiveTimeout - ep transport.Endpoint - stype linux.SockType + socketOpsCommon } // New creates a new unix socket. func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true}) } @@ -75,29 +74,51 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty } s := SocketOperations{ - ep: ep, - stype: stype, + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, } - s.EnableLeakCheck("unix.SocketOperations") + s.EnableLeakCheck() return fs.NewFile(ctx, d, flags, &s) } +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + socketOpsCommonRefs + socket.SendReceiveTimeout + + ep transport.Endpoint + stype linux.SockType + + // abstractName and abstractNamespace indicate the name and namespace of the + // socket if it is bound to an abstract socket namespace. Once the socket is + // bound, they cannot be modified. + abstractName string + abstractNamespace *kernel.AbstractSocketNamespace +} + // DecRef implements RefCounter.DecRef. -func (s *SocketOperations) DecRef() { - s.DecRefWithDestructor(func() { - s.ep.Close() +func (s *socketOpsCommon) DecRef(ctx context.Context) { + s.socketOpsCommonRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } }) } // Release implemements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release(ctx context.Context) { // Release only decrements a reference on s because s may be referenced in // the abstract socket namespace. - s.DecRef() + s.DecRef(ctx) } -func (s *SocketOperations) isPacket() bool { +func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: return true @@ -110,16 +131,22 @@ func (s *SocketOperations) isPacket() bool { } // Endpoint extracts the transport.Endpoint. -func (s *SocketOperations) Endpoint() transport.Endpoint { +func (s *socketOpsCommon) Endpoint() transport.Endpoint { return s.ep } // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, family, err := netstack.AddressAndFamily(sockaddr) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } + if family != linux.AF_UNIX { + return "", syserr.ErrInvalidArgument + } // The address is trimmed by GetAddress. p := string(addr.Addr) @@ -137,7 +164,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -149,7 +176,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -166,13 +193,13 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return s.ep.Listen(backlog) } @@ -215,7 +242,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ns := New(t, ep, s.stype) - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -265,17 +292,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } + s.abstractName = name + s.abstractNamespace = asn } else { // The parent and name. var d *fs.Dirent var name string cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) // Is there no slash at all? if !strings.Contains(p, "/") { @@ -283,7 +314,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { name = p } else { root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Find the last path component, we know that something follows // that final slash, otherwise extractPath() would have failed. lastSlash := strings.LastIndex(p, "/") @@ -299,16 +330,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // No path available. return syserr.ErrNoSuchFile } - defer d.DecRef() + defer d.DecRef(t) name = p[lastSlash+1:] } // Create the socket. + // + // Note that the file permissions here are not set correctly (see + // gvisor.dev/issue/2324). There is no convenient way to get permissions + // on the socket referred to by s, so we will leave this discrepancy + // unresolved until VFS2 replaces this code. childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}) if err != nil { return syserr.ErrPortInUse } - childDir.DecRef() + childDir.DecRef(t) } return nil @@ -339,41 +375,76 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, return ep, nil } + if kernel.VFS2Enabled { + p := fspath.Parse(path) + root := t.FSContext().RootDirectoryVFS2() + start := root + relPath := !p.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: p, + FollowFinalSymlink: true, + } + ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) + root.DecRef(t) + if relPath { + start.DecRef(t) + } + if e != nil { + return nil, syserr.FromError(e) + } + return ep, nil + } + // Find the node in the filesystem. root := t.FSContext().RootDirectory() cwd := t.FSContext().WorkingDirectory() remainingTraversals := uint(fs.DefaultTraversalLimit) d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals) - cwd.DecRef() - root.DecRef() + cwd.DecRef(t) + root.DecRef(t) if e != nil { return nil, syserr.FromError(e) } // Extract the endpoint if one is there. ep := d.Inode.BoundEndpoint(path) - d.DecRef() + d.DecRef(t) if ep == nil { // No socket! return nil, syserr.ErrConnectionRefused } - return ep, nil } // Connect implements the linux syscall connect(2) for unix sockets. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { ep, err := extractEndpoint(t, sockaddr) if err != nil { return err } - defer ep.Release() + defer ep.Release(t) // Connect the server endpoint. - return s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep) + + if err == syserr.ErrWrongProtocolForSocket { + // Linux for abstract sockets returns ErrConnectionRefused + // instead of ErrWrongProtocolForSocket. + path, _ := extractPath(sockaddr) + if len(path) > 0 && path[0] == 0 { + err = syserr.ErrConnectionRefused + } + } + + return err } -// Writev implements fs.FileOperations.Write. +// Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { t := kernel.TaskFromContext(ctx) ctrl := control.New(t, s.ep, nil) @@ -393,7 +464,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by // a transport.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Ctx: t, Endpoint: s.ep, @@ -401,15 +472,25 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] To: nil, } if len(to) > 0 { - ep, err := extractEndpoint(t, to) - if err != nil { - return 0, err - } - defer ep.Release() - w.To = ep + switch s.stype { + case linux.SOCK_SEQPACKET: + to = nil + case linux.SOCK_STREAM: + if s.State() == linux.SS_CONNECTED { + return 0, syserr.ErrAlreadyConnected + } + return 0, syserr.ErrNotSupported + default: + ep, err := extractEndpoint(t, to) + if err != nil { + return 0, err + } + defer ep.Release(t) + w.To = ep - if ep.Passcred() && w.Control.Credentials == nil { - w.Control.Credentials = control.MakeCreds(t) + if ep.Passcred() && w.Control.Credentials == nil { + w.Control.Credentials = control.MakeCreds(t) + } } } @@ -447,27 +528,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } // Passcred implements transport.Credentialer.Passcred. -func (s *SocketOperations) Passcred() bool { +func (s *socketOpsCommon) Passcred() bool { return s.ep.Passcred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. -func (s *SocketOperations) ConnectedPasscred() bool { +func (s *socketOpsCommon) ConnectedPasscred() bool { return s.ep.ConnectedPasscred() } // Readiness implements waiter.Waitable.Readiness. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.ep.Readiness(mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *SocketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } @@ -479,7 +560,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := netstack.ConvertShutdown(how) if err != nil { return err @@ -505,7 +586,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -541,8 +622,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if senderRequested { r.From = &tcpip.FullAddress{} } + + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. + if trunc && dst.Addrs.NumBytes() == 0 { + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } + + } + var total int64 - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -577,7 +677,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { @@ -623,12 +723,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // State implements socket.Socket.State. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { // Unix domain sockets always have a protocol of 0. return linux.AF_UNIX, s.stype, 0 } @@ -681,4 +781,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F func init() { socket.RegisterProvider(linux.AF_UNIX, &provider{}) + socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{}) } |