From 74788b1b6194ef62f8355f7e4721c00f615d16ad Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Thu, 17 Dec 2020 08:45:38 -0800 Subject: [netstack] Implement MSG_ERRQUEUE flag for recvmsg(2). Introduces the per-socket error queue and the necessary cmsg mechanisms. PiperOrigin-RevId: 348028508 --- pkg/sentry/socket/control/control.go | 39 ++++++++++++++++++++++++ pkg/sentry/socket/hostinet/socket.go | 19 ++++++++---- pkg/sentry/socket/netstack/netstack.go | 43 ++++++++++++++++++++++++++ pkg/sentry/socket/socket.go | 55 ++++++++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+), 6 deletions(-) (limited to 'pkg/sentry/socket') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index b88cdca48..ff6b71802 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -371,6 +371,17 @@ func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, b buf, level, optType, t.Arch().Width(), originalDstAddress) } +// PackSockExtendedErr packs an IP*_RECVERR socket control message. +func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte { + return putCmsgStruct( + buf, + sockErr.CMsgLevel(), + sockErr.CMsgType(), + t.Arch().Width(), + sockErr, + ) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. @@ -403,6 +414,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } + if cmsgs.IP.SockErr != nil { + buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf) + } + return buf } @@ -440,6 +455,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) } + if cmsgs.IP.SockErr != nil { + space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes()) + } + return space } @@ -546,6 +565,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) + case linux.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } @@ -568,6 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) + case linux.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 1f220c343..2b34ef190 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -450,11 +450,7 @@ func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, sende // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Only allow known and safe flags. - // - // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the - // Socket interface's dependence on netstack. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { + if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } @@ -488,7 +484,8 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var ch chan struct{} n, err := copyToDst() - if flags&syscall.MSG_DONTWAIT == 0 { + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. @@ -551,6 +548,11 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s var addr linux.SockAddrInet binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } case linux.SOL_IPV6: @@ -563,6 +565,11 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s var addr linux.SockAddrInet6 binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } case linux.SOL_TCP: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 23d5cab9c..a8ab6b385 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2772,6 +2772,8 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages { IP: socket.IPControlMessages{ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp, + HasInq: s.readCM.HasInq, + Inq: s.readCM.Inq, HasTOS: s.readCM.HasTOS, TOS: s.readCM.TOS, HasTClass: s.readCM.HasTClass, @@ -2779,6 +2781,7 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages { HasIPPacketInfo: s.readCM.HasIPPacketInfo, PacketInfo: s.readCM.PacketInfo, OriginalDstAddress: s.readCM.OriginalDstAddress, + SockErr: s.readCM.SockErr, }, } } @@ -2795,9 +2798,49 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// addrFamilyFromNetProto returns the address family identifier for the given +// network protocol. +func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { + switch net { + case header.IPv4ProtocolNumber: + return linux.AF_INET + case header.IPv6ProtocolNumber: + return linux.AF_INET6 + default: + panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net)) + } +} + +// recvErr handles MSG_ERRQUEUE for recvmsg(2). +// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). +func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + sockErr := s.Endpoint.SocketOptions().DequeueErr() + if sockErr == nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + } + + // The payload of the original packet that caused the error is passed as + // normal data via msg_iovec. -- recvmsg(2) + msgFlags := linux.MSG_ERRQUEUE + if int(dst.NumBytes()) < len(sockErr.Payload) { + msgFlags |= linux.MSG_TRUNC + } + n, err := dst.CopyOut(t, sockErr.Payload) + + // The original destination address of the datagram that caused the error is + // supplied via msg_name. -- recvmsg(2) + dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst) + cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})} + return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err) +} + // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { + if flags&linux.MSG_ERRQUEUE != 0 { + return s.recvErr(t, dst) + } + trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index bcc426e33..97729dacc 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -56,6 +56,57 @@ func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPack return p } +// errOriginToLinux maps tcpip socket origin to Linux socket origin constants. +func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 { + switch origin { + case tcpip.SockExtErrorOriginNone: + return linux.SO_EE_ORIGIN_NONE + case tcpip.SockExtErrorOriginLocal: + return linux.SO_EE_ORIGIN_LOCAL + case tcpip.SockExtErrorOriginICMP: + return linux.SO_EE_ORIGIN_ICMP + case tcpip.SockExtErrorOriginICMP6: + return linux.SO_EE_ORIGIN_ICMP6 + default: + panic(fmt.Sprintf("unknown socket origin: %d", origin)) + } +} + +// sockErrCmsgToLinux converts SockError control message from tcpip format to +// Linux format. +func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { + if sockErr == nil { + return nil + } + + ee := linux.SockExtendedErr{ + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Origin: errOriginToLinux(sockErr.ErrOrigin), + Type: sockErr.ErrType, + Code: sockErr.ErrCode, + Info: sockErr.ErrInfo, + } + + switch sockErr.NetProto { + case header.IPv4ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet) + } + return errMsg + case header.IPv6ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet6) + } + return errMsg + default: + panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto)) + } +} + // NewIPControlMessages converts the tcpip ControlMessgaes (which does not // have Linux specific format) to Linux format. func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { @@ -75,6 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa HasIPPacketInfo: cmgs.HasIPPacketInfo, PacketInfo: packetInfoToLinux(cmgs.PacketInfo), OriginalDstAddress: orgDstAddr, + SockErr: sockErrCmsgToLinux(cmgs.SockErr), } } @@ -117,6 +169,9 @@ type IPControlMessages struct { // OriginalDestinationAddress holds the original destination address // and port of the incoming packet. OriginalDstAddress linux.SockAddr + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr linux.SockErrCMsg } // Release releases Unix domain socket credentials and rights. -- cgit v1.2.3