diff options
author | Ian Gudger <igudger@google.com> | 2019-04-19 16:15:37 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-19 16:17:01 -0700 |
commit | 358eb52a76ebd41baf52972f901af0ff398e131b (patch) | |
tree | 90812de0d36a1fde5b8a5ddb8e39a44d206be8e7 | |
parent | cec2cdc12f30e87e5b0f6750fe1c132d89fcfb6d (diff) |
Add support for the MSG_TRUNC msghdr flag.
The MSG_TRUNC flag is set in the msghdr when a message is truncated.
Fixes google/gvisor#200
PiperOrigin-RevId: 244440486
Change-Id: I03c7d5e7f5935c0c6b8d69b012db1780ac5b8456
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 45 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 18 | ||||
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 16 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 31 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_socket.go | 16 | ||||
-rw-r--r-- | test/syscalls/linux/socket_netlink_route.cc | 80 | ||||
-rw-r--r-- | test/syscalls/linux/socket_non_stream.cc | 55 | ||||
-rw-r--r-- | test/syscalls/linux/socket_stream.cc | 27 |
10 files changed, 245 insertions, 57 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f370b803b..23138d874 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -376,7 +376,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS if dst.NumBytes() == 0 { return 0, nil } - n, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) + n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) if err == syserr.ErrWouldBlock { return int64(n), syserror.ErrWouldBlock } @@ -1696,7 +1696,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq // nonBlockingRead issues a non-blocking read. // // TODO: Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1712,14 +1712,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) s.readMu.Unlock() - return n, nil, 0, socket.ControlMessages{}, err + return n, 0, nil, 0, socket.ControlMessages{}, err } s.readMu.Lock() defer s.readMu.Unlock() if err := s.fetchReadView(); err != nil { - return 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, err } if !isPacket && peek && trunc { @@ -1727,14 +1727,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // amount that could be read. var rql tcpip.ReceiveQueueSizeOption if err := s.Endpoint.GetSockOpt(&rql); err != nil { - return 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } available := len(s.readView) + int(rql) bufLen := int(dst.NumBytes()) if available < bufLen { - return available, nil, 0, socket.ControlMessages{}, nil + return available, 0, nil, 0, socket.ControlMessages{}, nil } - return bufLen, nil, 0, socket.ControlMessages{}, nil + return bufLen, 0, nil, 0, socket.ControlMessages{}, nil } n, err := dst.CopyOut(ctx, s.readView) @@ -1751,11 +1751,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if peek { if l := len(s.readView); trunc && l > n { // isPacket must be true. - return l, addr, addrLen, s.controlMessages(), syserr.FromError(err) + return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err) } if isPacket || err != nil { - return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err) + return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err) } // We need to peek beyond the first message. @@ -1773,7 +1773,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // We got some data, so no need to return an error. err = nil } - return int(n), nil, 0, s.controlMessages(), syserr.FromError(err) + return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err) } var msgLen int @@ -1785,11 +1785,16 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + var flags int + if msgLen > int(n) { + flags |= linux.MSG_TRUNC + } + if trunc { - return msgLen, addr, addrLen, s.controlMessages(), syserr.FromError(err) + n = msgLen } - return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err) + return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err) } func (s *SocketOperations) controlMessages() socket.ControlMessages { @@ -1810,7 +1815,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -1819,16 +1824,16 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // Stream sockets ignore the sender address. senderRequested = false } - n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) + n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) if s.isPacketBased() && err == syserr.ErrClosedForReceive && flags&linux.MSG_DONTWAIT != 0 { // In this situation we should return EAGAIN. - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } if err != nil && (err != syserr.ErrWouldBlock || dontWait) { // Read failed and we should not retry. - return 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, err } if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) { @@ -1847,7 +1852,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags for { var rn int - rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) + rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) n += rn if err != nil && err != syserr.ErrWouldBlock { // Always stop on errors other than would block as we generally @@ -1866,12 +1871,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if n > 0 { - return n, senderAddr, senderAddrLen, controlMessages, nil + return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil } if err == syserror.ETIMEDOUT { - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index be63823d8..c4848b313 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -345,14 +345,14 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME: We can't support MSG_ERRQUEUE because it uses ancillary // messages that netstack/tcpip/transport/unix doesn't understand. Kill the // Socket interface's dependence on netstack. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } var senderAddr []byte @@ -360,6 +360,8 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags senderAddr = make([]byte, sizeofSockaddr) } + var msgFlags int + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of dst.Addrs was unusable. if uint64(dst.NumBytes()) != dsts.NumBytes() { @@ -391,6 +393,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, err } senderAddr = senderAddr[:msg.Namelen] + msgFlags = int(msg.Flags) return n, nil }) @@ -417,7 +420,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } - return int(n), senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err) + // We don't allow control messages. + msgFlags &^= linux.MSG_CTRUNC + + return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err) } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 2503a67c5..0fe9b39b6 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -397,7 +397,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // RecvMsg implements socket.Socket.RecvMsg. -func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { from := linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: 0, @@ -412,10 +412,14 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have } if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + var mflags int + if n < int64(r.MsgSize) { + mflags |= linux.MSG_TRUNC + } if trunc { n = int64(r.MsgSize) } - return int(n), from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) } // We'll have to block. Register for notification and keep trying to @@ -426,17 +430,21 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have for { if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + var mflags int + if n < int64(r.MsgSize) { + mflags |= linux.MSG_TRUNC + } if trunc { n = int64(r.MsgSize) } - return int(n), from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if err == syserror.ETIMEDOUT { - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 896b5b7ce..3418a6d75 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -673,7 +673,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ Fd: s.fd, Length: uint32(dst.NumBytes()), @@ -694,10 +694,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { - return 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, err } // We'll have to block. Register for notifications and keep trying to @@ -718,23 +718,23 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { - return 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, err } if s.isShutRdSet() { // Blocking would have caused us to block indefinitely so we return 0, // this is the same behavior as Linux. - return 0, nil, 0, socket.ControlMessages{}, nil + return 0, 0, nil, 0, socket.ControlMessages{}, nil } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if err == syserror.ETIMEDOUT { - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 5ab423f3c..62ba13782 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -88,7 +88,7 @@ type Socket interface { // not necessarily the actual length of the address. // // If err != nil, the recv was not successful. - RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) + RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take // ownership of the ControlMessage on error. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 92411c901..01efd24d3 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -477,7 +477,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -515,11 +515,17 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } - if trunc { - n = int64(r.MsgSize) - } + if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + if s.isPacket && n < int64(r.MsgSize) { + msgFlags |= linux.MSG_TRUNC + } + + if trunc { + n = int64(r.MsgSize) + } + + return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) } // Don't overwrite any data we received. @@ -541,14 +547,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } if trunc { - n = int64(r.MsgSize) + // n and r.MsgSize are the same for streams. + total += int64(r.MsgSize) + } else { + total += n } - total += n if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + if s.isPacket && n < int64(r.MsgSize) { + msgFlags |= linux.MSG_TRUNC + } + return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) } // Don't overwrite any data we received. @@ -560,9 +571,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags err = nil } if err == syserror.ETIMEDOUT { - return int(total), nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return int(total), nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 49e6f4aeb..30ccc3f66 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -742,17 +742,15 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { - n, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) + n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) } cms.Unix.Release() - if msg.Flags != 0 { + if int(msg.Flags) != mflags { // Copy out the flags to the caller. - // - // TODO: Plumb through actual flags. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(0)); err != nil { + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -763,7 +761,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if msg.ControlLen > maxControlLen { return 0, syscall.ENOBUFS } - n, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) + n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -802,9 +800,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } // Copy out the flags to the caller. - // - // TODO: Plumb through actual flags. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(0)); err != nil { + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -856,7 +852,7 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f flags |= linux.MSG_DONTWAIT } - n, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) + n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) cm.Unix.Release() if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 5f83836df..fa895d841 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -220,6 +220,86 @@ TEST(NetlinkRouteTest, GetLinkDump) { EXPECT_TRUE(loopbackFound); } +TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + // No destination required; it defaults to pid 0, the kernel. + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + // Small enough to ensure that the response doesn't fit. + constexpr size_t kBufferSize = 10; + std::vector<char> buf(kBufferSize); + iov.iov_base = buf.data(); + iov.iov_len = buf.size(); + + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), + SyscallSucceedsWithValue(kBufferSize)); + EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC); +} + +TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + // No destination required; it defaults to pid 0, the kernel. + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + // Small enough to ensure that the response doesn't fit. + constexpr size_t kBufferSize = 10; + std::vector<char> buf(kBufferSize); + iov.iov_base = buf.data(); + iov.iov_len = buf.size(); + + int res = 0; + ASSERT_THAT(res = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC), + SyscallSucceeds()); + EXPECT_GT(res, kBufferSize); + EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC); +} + TEST(NetlinkRouteTest, ControlMessageIgnored) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc index d49aab363..d170008a4 100644 --- a/test/syscalls/linux/socket_non_stream.cc +++ b/test/syscalls/linux/socket_non_stream.cc @@ -15,6 +15,7 @@ #include "test/syscalls/linux/socket_non_stream.h" #include <stdio.h> +#include <sys/socket.h> #include <sys/un.h> #include "gtest/gtest.h" @@ -89,6 +90,33 @@ TEST_P(NonStreamSocketPairTest, SingleRecv) { EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1))); } +TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[sizeof(sent_data) / 2] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); + + // Check that msghdr flags were updated. + EXPECT_EQ(msg.msg_flags, MSG_TRUNC); +} + // Stream sockets allow data sent with multiple sends to be peeked at in a // single recv. Datagram sockets (except for unix sockets) do not. // @@ -142,6 +170,33 @@ TEST_P(NonStreamSocketPairTest, MsgTruncTruncation) { sizeof(sent_data) / 2)); } +TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[sizeof(sent_data) / 2] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC), + SyscallSucceedsWithValue(sizeof(sent_data))); + EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); + + // Check that msghdr flags were updated. + EXPECT_EQ(msg.msg_flags, MSG_TRUNC); +} + TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[512]; diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc index 32e9d958b..c8a8ad0f6 100644 --- a/test/syscalls/linux/socket_stream.cc +++ b/test/syscalls/linux/socket_stream.cc @@ -81,6 +81,33 @@ TEST_P(StreamSocketPairTest, WriteOneSideClosed) { SyscallFailsWithErrno(EPIPE)); } +TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[sizeof(sent_data) / 2] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); + + // Check that msghdr flags were cleared (MSG_TRUNC was not set). + EXPECT_EQ(msg.msg_flags, 0); +} + TEST_P(StreamSocketPairTest, MsgTrunc) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[512]; |