From 358eb52a76ebd41baf52972f901af0ff398e131b Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 19 Apr 2019 16:15:37 -0700 Subject: Add support for the MSG_TRUNC msghdr flag. The MSG_TRUNC flag is set in the msghdr when a message is truncated. Fixes google/gvisor#200 PiperOrigin-RevId: 244440486 Change-Id: I03c7d5e7f5935c0c6b8d69b012db1780ac5b8456 --- pkg/sentry/socket/netlink/socket.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'pkg/sentry/socket/netlink') diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 2503a67c5..0fe9b39b6 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -397,7 +397,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // RecvMsg implements socket.Socket.RecvMsg. -func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { from := linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: 0, @@ -412,10 +412,14 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have } if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + var mflags int + if n < int64(r.MsgSize) { + mflags |= linux.MSG_TRUNC + } if trunc { n = int64(r.MsgSize) } - return int(n), from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) } // We'll have to block. Register for notification and keep trying to @@ -426,17 +430,21 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have for { if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + var mflags int + if n < int64(r.MsgSize) { + mflags |= linux.MSG_TRUNC + } if trunc { n = int64(r.MsgSize) } - return int(n), from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if err == syserror.ETIMEDOUT { - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } -- cgit v1.2.3