summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/hostinet/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/hostinet/socket.go')
-rw-r--r--pkg/sentry/socket/hostinet/socket.go105
1 files changed, 74 insertions, 31 deletions
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 22f78d2e2..b49433326 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -49,6 +50,8 @@ const (
maxControlLen = 1024
)
+// LINT.IfChange
+
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
@@ -59,23 +62,37 @@ type socketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
socket.SendReceiveTimeout
family int // Read-only.
stype linux.SockType // Read-only.
protocol int // Read-only.
- fd int // must be O_NONBLOCK
queue waiter.Queue
+
+ // fd is the host socket fd. It must have O_NONBLOCK, so that operations
+ // will return EWOULDBLOCK instead of blocking on the host. This allows us to
+ // handle blocking behavior independently in the sentry.
+ fd int
}
var _ = socket.Socket(&socketOperations{})
func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
s := &socketOperations{
- family: family,
- stype: stype,
- protocol: protocol,
- fd: fd,
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
}
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
@@ -86,28 +103,33 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc
}
// Release implements fs.FileOperations.Release.
-func (s *socketOperations) Release() {
+func (s *socketOpsCommon) Release() {
fdnotifier.RemoveFD(int32(s.fd))
syscall.Close(s.fd)
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.queue.EventRegister(e, mask)
fdnotifier.UpdateFD(int32(s.fd))
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *socketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.queue.EventUnregister(e)
fdnotifier.UpdateFD(int32(s.fd))
}
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, io, args)
+}
+
// Read implements fs.FileOperations.Read.
func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
@@ -155,7 +177,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
// Connect implements socket.Socket.Connect.
-func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -195,7 +217,7 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
var peerAddr linux.SockAddr
var peerAddrBuf []byte
var peerAddrlen uint32
@@ -209,7 +231,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
// Conservatively ignore all flags specified by the application and add
- // SOCK_NONBLOCK since socketOperations requires it.
+ // SOCK_NONBLOCK since socketOpsCommon requires it.
fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
if blocking {
var ch chan struct{}
@@ -235,23 +257,41 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
}
- f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
- if err != nil {
- syscall.Close(fd)
- return 0, nil, 0, err
- }
- defer f.DecRef()
+ var (
+ kfd int32
+ kerr error
+ )
+ if kernel.VFS2Enabled {
+ f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK))
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
- kfd, kerr := t.NewFDFrom(0, f, kernel.FDFlags{
- CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
- })
- t.Kernel().RecordSocket(f)
+ kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocketVFS2(f)
+ } else {
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
+
+ kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocket(f)
+ }
return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}
// Bind implements socket.Socket.Bind.
-func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -264,12 +304,12 @@ func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Listen implements socket.Socket.Listen.
-func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return syserr.FromError(syscall.Listen(s.fd, backlog))
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
switch how {
case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR:
return syserr.FromError(syscall.Shutdown(s.fd, how))
@@ -279,7 +319,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -328,7 +368,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
}
// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Whitelist options and constrain option length.
optlen := setSockOptLen(t, level, name)
switch level {
@@ -374,7 +414,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
@@ -496,7 +536,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
@@ -585,7 +625,7 @@ func translateIOSyscallError(err error) error {
}
// State implements socket.Socket.State.
-func (s *socketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
info := linux.TCPInfo{}
buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo)
if err != nil {
@@ -607,7 +647,7 @@ func (s *socketOperations) State() uint32 {
}
// Type implements socket.Socket.Type.
-func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
return s.family, s.stype, s.protocol
}
@@ -663,8 +703,11 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int
return nil, nil, nil
}
+// LINT.ThenChange(./socket_vfs2.go)
+
func init() {
for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
socket.RegisterProvider(family, &socketProvider{family})
+ socket.RegisterProviderVFS2(family, &socketProviderVFS2{})
}
}