diff options
Diffstat (limited to 'pkg/sentry/socket/hostinet')
-rw-r--r-- | pkg/sentry/socket/hostinet/BUILD | 20 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 273 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket_unsafe.go | 14 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket_vfs2.go | 203 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/sockopt_impl.go | 27 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/stack.go | 37 |
6 files changed, 472 insertions, 102 deletions
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index c1b20eaf8..8448ea401 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -10,31 +10,41 @@ go_library( "save_restore.go", "socket.go", "socket_unsafe.go", + "socket_vfs2.go", + "sockopt_impl.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/fdnotifier", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", + "//pkg/sentry/fsimpl/sockfs", + "//pkg/sentry/hostfd", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/usermem", + "//pkg/sentry/socket/control", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 92beb1bcf..242e6bf76 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -18,21 +18,26 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -41,8 +46,14 @@ const ( // sizeofSockaddr is the size in bytes of the largest sockaddr type // supported by this package. sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) + + // maxControlLen is the maximum size of a control message buffer used in a + // recvmsg or sendmsg syscall. + maxControlLen = 1024 ) +// LINT.IfChange + // socketOperations implements fs.FileOperations and socket.Socket for a socket // implemented using a host socket. type socketOperations struct { @@ -53,55 +64,74 @@ type socketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout family int // Read-only. stype linux.SockType // Read-only. protocol int // Read-only. - fd int // must be O_NONBLOCK queue waiter.Queue + + // fd is the host socket fd. It must have O_NONBLOCK, so that operations + // will return EWOULDBLOCK instead of blocking on the host. This allows us to + // handle blocking behavior independently in the sentry. + fd int } var _ = socket.Socket(&socketOperations{}) func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { s := &socketOperations{ - family: family, - stype: stype, - protocol: protocol, - fd: fd, + socketOpsCommon: socketOpsCommon{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + }, } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } dirent := socket.NewDirent(ctx, socketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } // Release implements fs.FileOperations.Release. -func (s *socketOperations) Release() { +func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) syscall.Close(s.fd) } // Readiness implements waiter.Waitable.Readiness. -func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(s.fd), mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.queue.EventRegister(e, mask) fdnotifier.UpdateFD(int32(s.fd)) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *socketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.queue.EventUnregister(e) fdnotifier.UpdateFD(int32(s.fd)) } +// Ioctl implements fs.FileOperations.Ioctl. +func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return ioctl(ctx, s.fd, io, args) +} + // Read implements fs.FileOperations.Read. func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -120,7 +150,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } return uint64(n), nil } - return readv(s.fd, iovecsFromBlockSeq(dsts)) + return readv(s.fd, safemem.IovecsFromBlockSeq(dsts)) })) return int64(n), err } @@ -143,13 +173,13 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } return uint64(n), nil } - return writev(s.fd, iovecsFromBlockSeq(srcs)) + return writev(s.fd, safemem.IovecsFromBlockSeq(srcs)) })) return int64(n), err } // Connect implements socket.Socket.Connect. -func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -189,7 +219,7 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { var peerAddr linux.SockAddr var peerAddrBuf []byte var peerAddrlen uint32 @@ -203,7 +233,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } // Conservatively ignore all flags specified by the application and add - // SOCK_NONBLOCK since socketOperations requires it. + // SOCK_NONBLOCK since socketOpsCommon requires it. fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) if blocking { var ch chan struct{} @@ -229,23 +259,41 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) - if err != nil { - syscall.Close(fd) - return 0, nil, 0, err - } - defer f.DecRef() + var ( + kfd int32 + kerr error + ) + if kernel.VFS2Enabled { + f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK)) + if err != nil { + syscall.Close(fd) + return 0, nil, 0, err + } + defer f.DecRef(t) - kfd, kerr := t.NewFDFrom(0, f, kernel.FDFlags{ - CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, - }) - t.Kernel().RecordSocket(f) + kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{ + CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, + }) + t.Kernel().RecordSocketVFS2(f) + } else { + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) + if err != nil { + syscall.Close(fd) + return 0, nil, 0, err + } + defer f.DecRef(t) + + kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{ + CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, + }) + t.Kernel().RecordSocket(f) + } return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } // Bind implements socket.Socket.Bind. -func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -258,12 +306,12 @@ func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Listen implements socket.Socket.Listen. -func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.FromError(syscall.Listen(s.fd, backlog)) } // Shutdown implements socket.Socket.Shutdown. -func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { switch how { case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR: return syserr.FromError(syscall.Shutdown(s.fd, how)) @@ -273,34 +321,40 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } - // Whitelist options and constrain option length. - var optlen int + // Only allow known and safe options. + optlen := getSockOptLen(t, level, name) switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 - case syscall.SO_LINGER: + case linux.SO_LINGER: optlen = syscall.SizeofLinger } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 - case syscall.TCP_INFO: + case linux.TCP_INFO: optlen = int(linux.SizeOfTCPInfo) } } + if optlen == 0 { return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT } @@ -312,30 +366,39 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. -func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { - // Whitelist options and constrain option length. - var optlen int +func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { + // Only allow known and safe options. + optlen := setSockOptLen(t, level, name) switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: switch name { - case syscall.IPV6_V6ONLY: + case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 + case linux.IP_PKTINFO: + optlen = linux.SizeOfControlMessageIPPacketInfo } - case syscall.SOL_SOCKET: + case linux.SOL_IPV6: switch name { - case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_TCP: + case linux.SOL_SOCKET: switch name { - case syscall.TCP_NODELAY: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + optlen = sizeofInt32 + } + case linux.SOL_TCP: + switch name { + case linux.TCP_NODELAY: optlen = sizeofInt32 } } + if optlen == 0 { // Pretend to accept socket options we don't understand. This seems // dangerous, but it's what netstack does... @@ -354,11 +417,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Whitelist flags. +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + // Only allow known and safe flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that netstack/tcpip/transport/unix doesn't understand. Kill the + // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the // Socket interface's dependence on netstack. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument @@ -370,6 +433,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags senderAddrBuf = make([]byte, sizeofSockaddr) } + var controlBuf []byte var msgFlags int recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -384,12 +448,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We always do a non-blocking recv*(). sysflags := flags | syscall.MSG_DONTWAIT - if dsts.NumBlocks() == 1 { - // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) - } - - iovs := iovecsFromBlockSeq(dsts) + iovs := safemem.IovecsFromBlockSeq(dsts) msg := syscall.Msghdr{ Iov: &iovs[0], Iovlen: uint64(len(iovs)), @@ -398,12 +457,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Name = &senderAddrBuf[0] msg.Namelen = uint32(len(senderAddrBuf)) } + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) + controlLen = uint64(msg.Controllen) return n, nil }) @@ -429,36 +497,75 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err = dst.CopyOutFrom(t, recvmsgToBlocks) } } - - // We don't allow control messages. - msgFlags &^= linux.MSG_CTRUNC + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) + + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } + + controlMessages := socket.ControlMessages{} + for _, unixCmsg := range unixControlMessages { + switch unixCmsg.Header.Level { + case syscall.SOL_IP: + switch unixCmsg.Header.Type { + case syscall.IP_TOS: + controlMessages.IP.HasTOS = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + + case syscall.IP_PKTINFO: + controlMessages.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + } + + case syscall.SOL_IPV6: + switch unixCmsg.Header.Type { + case syscall.IPV6_TCLASS: + controlMessages.IP.HasTClass = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + } + } + } + + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil } // SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { - // Whitelist flags. +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { + // Only allow known and safe flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument } + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. if uint64(src.NumBytes()) != srcs.NumBytes() { return 0, nil } - if srcs.IsEmpty() { + if srcs.IsEmpty() && len(controlBuf) == 0 { return 0, nil } // We always do a non-blocking send*(). sysflags := flags | syscall.MSG_DONTWAIT - if srcs.NumBlocks() == 1 { + if srcs.NumBlocks() == 1 && len(controlBuf) == 0 { // Skip allocating []syscall.Iovec. src := srcs.Head() n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to))) @@ -468,7 +575,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return uint64(n), nil } - iovs := iovecsFromBlockSeq(srcs) + iovs := safemem.IovecsFromBlockSeq(srcs) msg := syscall.Msghdr{ Iov: &iovs[0], Iovlen: uint64(len(iovs)), @@ -477,6 +584,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] msg.Name = &to[0] msg.Namelen = uint32(len(to)) } + if len(controlBuf) != 0 { + msg.Control = &controlBuf[0] + msg.Controllen = uint64(len(controlBuf)) + } return sendmsg(s.fd, &msg, sysflags) }) @@ -509,21 +620,6 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return int(n), syserr.FromError(err) } -func iovecsFromBlockSeq(bs safemem.BlockSeq) []syscall.Iovec { - iovs := make([]syscall.Iovec, 0, bs.NumBlocks()) - for ; !bs.IsEmpty(); bs = bs.Tail() { - b := bs.Head() - iovs = append(iovs, syscall.Iovec{ - Base: &b.ToSlice()[0], - Len: uint64(b.Len()), - }) - // We don't need to care about b.NeedSafecopy(), because the host - // kernel will handle such address ranges just fine (by returning - // EFAULT). - } - return iovs -} - func translateIOSyscallError(err error) error { if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK { return syserror.ErrWouldBlock @@ -532,7 +628,7 @@ func translateIOSyscallError(err error) error { } // State implements socket.Socket.State. -func (s *socketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { info := linux.TCPInfo{} buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) if err != nil { @@ -554,7 +650,7 @@ func (s *socketOperations) State() uint32 { } // Type implements socket.Socket.Type. -func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return s.family, s.stype, s.protocol } @@ -610,8 +706,11 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int return nil, nil, nil } +// LINT.ThenChange(./socket_vfs2.go) + func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{family}) } } diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index e69ec38c2..3f420c2ec 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -19,14 +19,13 @@ import ( "unsafe" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) func firstBytePtr(bs []byte) unsafe.Pointer { @@ -54,12 +53,11 @@ func writev(fd int, srcs []syscall.Iovec) (uint64, error) { return uint64(n), nil } -// Ioctl implements fs.FileOperations.Ioctl. -func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := uintptr(args[1].Int()); cmd { case syscall.TIOCINQ, syscall.TIOCOUTQ: var val int32 - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(s.fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 { + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 { return 0, translateIOSyscallError(errno) } var buf [4]byte @@ -93,7 +91,7 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) { } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) @@ -104,7 +102,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go new file mode 100644 index 000000000..8a1d52ebf --- /dev/null +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -0,0 +1,203 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostinet + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +type socketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + // We store metadata for hostinet sockets internally. Technically, we should + // access metadata (e.g. through stat, chmod) on the host for correctness, + // but this is not very useful for inet socket fds, which do not belong to a + // concrete file anyway. + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +var _ = socket.SocketVFS2(&socketVFS2{}) + +func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { + mnt := t.Kernel().SocketMount() + d := sockfs.NewDentry(t.Credentials(), mnt) + + s := &socketVFS2{ + socketOpsCommon: socketOpsCommon{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + }, + } + s.LockFD.Init(&vfs.FileLocks{}) + if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { + return nil, syserr.FromError(err) + } + vfsfd := &s.vfsfd + if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *socketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *socketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return ioctl(ctx, s.fd, uio, args) +} + +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags) + n, err := dst.CopyOutFrom(ctx, reader) + hostfd.PutReadWriterAt(reader) + return int64(n), err +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *socketVFS2) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags) + n, err := src.CopyInTo(ctx, writer) + hostfd.PutReadWriterAt(writer) + return int64(n), err +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + +type socketProviderVFS2 struct { + family int +} + +// Socket implements socket.ProviderVFS2.Socket. +func (p *socketProviderVFS2) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check that we are using the host network stack. + stack := t.NetworkContext() + if stack == nil { + return nil, nil + } + if _, ok := stack.(*Stack); !ok { + return nil, nil + } + + // Only accept TCP and UDP. + stype := stypeflags & linux.SOCK_TYPE_MASK + switch stype { + case syscall.SOCK_STREAM: + switch protocol { + case 0, syscall.IPPROTO_TCP: + // ok + default: + return nil, nil + } + case syscall.SOCK_DGRAM: + switch protocol { + case 0, syscall.IPPROTO_UDP: + // ok + default: + return nil, nil + } + default: + return nil, nil + } + + // Conservatively ignore all flags specified by the application and add + // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 + // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + if err != nil { + return nil, syserr.FromError(err) + } + return newVFS2Socket(t, p.family, stype, protocol, fd, uint32(stypeflags&syscall.SOCK_NONBLOCK)) +} + +// Pair implements socket.Provider.Pair. +func (p *socketProviderVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Not supported by AF_INET/AF_INET6. + return nil, nil, nil +} diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go new file mode 100644 index 000000000..8a783712e --- /dev/null +++ b/pkg/sentry/socket/hostinet/sockopt_impl.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostinet + +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" +) + +func getSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} + +func setSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 4b460d30e..3d3fabb30 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -25,15 +25,16 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -55,6 +56,7 @@ type Stack struct { interfaceAddrs map[int32][]inet.InterfaceAddr routes []inet.Route supportsIPv6 bool + tcpRecovery inet.TCPLossRecovery tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool @@ -128,6 +130,13 @@ func (s *Stack) Configure() error { log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") } + s.ipv4Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil { + s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + } else { + log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") + } + return nil } @@ -321,6 +330,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return addrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + return syserror.EACCES +} + // SupportsIPv6 implements inet.Stack.SupportsIPv6. func (s *Stack) SupportsIPv6() bool { return s.supportsIPv6 @@ -356,6 +370,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + return s.tcpRecovery, nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserror.EACCES +} + // getLine reads one line from proc file, with specified prefix. // The last argument, withHeader, specifies if it contains line header. func getLine(f *os.File, prefix string, withHeader bool) string { @@ -455,6 +479,15 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + // Forwarding implements inet.Stack.Forwarding. func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { switch protocol { |