diff options
-rw-r--r-- | pkg/sentry/socket/socket.go | 87 |
1 files changed, 85 insertions, 2 deletions
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 50d9744e6..b5ba4a56b 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" @@ -48,11 +49,25 @@ func (c *ControlMessages) Release() { c.Unix.Release() } -// Socket is the interface containing socket syscalls used by the syscall layer -// to redirect them to the appropriate implementation. +// Socket is an interface combining fs.FileOperations and SocketOps, +// representing a VFS1 socket file. type Socket interface { fs.FileOperations + SocketOps +} + +// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps, +// representing a VFS2 socket file. +type SocketVFS2 interface { + vfs.FileDescriptionImpl + SocketOps +} +// SocketOps is the interface containing socket syscalls used by the syscall +// layer to redirect them to the appropriate implementation. +// +// It is implemented by both Socket and SocketVFS2. +type SocketOps interface { // Connect implements the connect(2) linux syscall. Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error @@ -153,6 +168,8 @@ var families = make(map[int][]Provider) // RegisterProvider registers the provider of a given address family so that // sockets of that type can be created via socket() and/or socketpair() // syscalls. +// +// This should only be called during the initialization of the address family. func RegisterProvider(family int, provider Provider) { families[family] = append(families[family], provider) } @@ -216,6 +233,72 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent { return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino)) } +// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for +// specific address families (e.g., AF_INET). +type ProviderVFS2 interface { + // Socket creates a new socket. + // + // If a nil Socket _and_ a nil error is returned, it means that the + // protocol is not supported. A non-nil error should only be returned + // if the protocol is supported, but an error occurs during creation. + Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) + + // Pair creates a pair of connected sockets. + // + // See Socket for error information. + Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) +} + +// familiesVFS2 holds a map of all known address families and their providers. +var familiesVFS2 = make(map[int][]ProviderVFS2) + +// RegisterProviderVFS2 registers the provider of a given address family so that +// sockets of that type can be created via socket() and/or socketpair() +// syscalls. +// +// This should only be called during the initialization of the address family. +func RegisterProviderVFS2(family int, provider ProviderVFS2) { + familiesVFS2[family] = append(familiesVFS2[family], provider) +} + +// NewVFS2 creates a new socket with the given family, type and protocol. +func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + for _, p := range familiesVFS2[family] { + s, err := p.Socket(t, stype, protocol) + if err != nil { + return nil, err + } + if s != nil { + // TODO: Add vfs2 sockets to global socket table. + return s, nil + } + } + + return nil, syserr.ErrAddressFamilyNotSupported +} + +// PairVFS2 creates a new connected socket pair with the given family, type and +// protocol. +func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + providers, ok := familiesVFS2[family] + if !ok { + return nil, nil, syserr.ErrAddressFamilyNotSupported + } + + for _, p := range providers { + s1, s2, err := p.Pair(t, stype, protocol) + if err != nil { + return nil, nil, err + } + if s1 != nil && s2 != nil { + // TODO: Add vfs2 sockets to global socket table. + return s1, s2, nil + } + } + + return nil, nil, syserr.ErrSocketNotSupported +} + // SendReceiveTimeout stores timeouts for send and receive calls. // // It is meant to be embedded into Socket implementations to help satisfy the |