From cc28d36845cd3b2267ececbdf81b2c265267cdec Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Tue, 15 Dec 2020 13:46:38 -0800 Subject: [netstack] Make recvmsg(2) call to host in hostinet even if dst is empty. We want to make the recvmsg syscall to the host regardless of if the dst is empty or not so that: - Host can populate the control messages if necessary. - Host can return sender address. - Host can return appropriate errors. Earlier because we were using the IOSequence.CopyOutFrom() API, the usermem package does not even call the Reader function if the destination is empty (as an optimization). PiperOrigin-RevId: 347684566 --- pkg/sentry/socket/hostinet/socket.go | 110 ++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 48 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index be418df2e..1f220c343 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -416,6 +416,37 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). + sysflags := flags | syscall.MSG_DONTWAIT + + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } + var senderAddrBuf []byte + if senderRequested { + senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) + } + var controlBuf []byte + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} + // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Only allow known and safe flags. @@ -427,56 +458,36 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr linux.SockAddr var senderAddrBuf []byte - if senderRequested { - senderAddrBuf = make([]byte, sizeofSockaddr) - } - var controlBuf []byte var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. - if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil - } - - // We always do a non-blocking recv*(). - sysflags := flags | syscall.MSG_DONTWAIT - - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. + if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil + } + if dsts.IsEmpty() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) + n, err := copyToDst() if flags&syscall.MSG_DONTWAIT == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which @@ -494,22 +505,26 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { @@ -558,8 +573,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. -- cgit v1.2.3