summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorIan Gudger <igudger@google.com>2019-04-19 16:15:37 -0700
committerShentubot <shentubot@google.com>2019-04-19 16:17:01 -0700
commit358eb52a76ebd41baf52972f901af0ff398e131b (patch)
tree90812de0d36a1fde5b8a5ddb8e39a44d206be8e7
parentcec2cdc12f30e87e5b0f6750fe1c132d89fcfb6d (diff)
Add support for the MSG_TRUNC msghdr flag.
The MSG_TRUNC flag is set in the msghdr when a message is truncated. Fixes google/gvisor#200 PiperOrigin-RevId: 244440486 Change-Id: I03c7d5e7f5935c0c6b8d69b012db1780ac5b8456
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go45
-rw-r--r--pkg/sentry/socket/hostinet/socket.go12
-rw-r--r--pkg/sentry/socket/netlink/socket.go18
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go16
-rw-r--r--pkg/sentry/socket/socket.go2
-rw-r--r--pkg/sentry/socket/unix/unix.go31
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go16
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc80
-rw-r--r--test/syscalls/linux/socket_non_stream.cc55
-rw-r--r--test/syscalls/linux/socket_stream.cc27
10 files changed, 245 insertions, 57 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index f370b803b..23138d874 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -376,7 +376,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
if dst.NumBytes() == 0 {
return 0, nil
}
- n, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
return int64(n), syserror.ErrWouldBlock
}
@@ -1696,7 +1696,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
// nonBlockingRead issues a non-blocking read.
//
// TODO: Support timestamps for stream sockets.
-func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
@@ -1712,14 +1712,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
s.readMu.Unlock()
- return n, nil, 0, socket.ControlMessages{}, err
+ return n, 0, nil, 0, socket.ControlMessages{}, err
}
s.readMu.Lock()
defer s.readMu.Unlock()
if err := s.fetchReadView(); err != nil {
- return 0, nil, 0, socket.ControlMessages{}, err
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
}
if !isPacket && peek && trunc {
@@ -1727,14 +1727,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// amount that could be read.
var rql tcpip.ReceiveQueueSizeOption
if err := s.Endpoint.GetSockOpt(&rql); err != nil {
- return 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := len(s.readView) + int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
- return available, nil, 0, socket.ControlMessages{}, nil
+ return available, 0, nil, 0, socket.ControlMessages{}, nil
}
- return bufLen, nil, 0, socket.ControlMessages{}, nil
+ return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
n, err := dst.CopyOut(ctx, s.readView)
@@ -1751,11 +1751,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if peek {
if l := len(s.readView); trunc && l > n {
// isPacket must be true.
- return l, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
}
if isPacket || err != nil {
- return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
}
// We need to peek beyond the first message.
@@ -1773,7 +1773,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// We got some data, so no need to return an error.
err = nil
}
- return int(n), nil, 0, s.controlMessages(), syserr.FromError(err)
+ return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
}
var msgLen int
@@ -1785,11 +1785,16 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ var flags int
+ if msgLen > int(n) {
+ flags |= linux.MSG_TRUNC
+ }
+
if trunc {
- return msgLen, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ n = msgLen
}
- return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err)
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
@@ -1810,7 +1815,7 @@ func (s *SocketOperations) updateTimestamp() {
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -1819,16 +1824,16 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// Stream sockets ignore the sender address.
senderRequested = false
}
- n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
if s.isPacketBased() && err == syserr.ErrClosedForReceive && flags&linux.MSG_DONTWAIT != 0 {
// In this situation we should return EAGAIN.
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
// Read failed and we should not retry.
- return 0, nil, 0, socket.ControlMessages{}, err
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
}
if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
@@ -1847,7 +1852,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
for {
var rn int
- rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
n += rn
if err != nil && err != syserr.ErrWouldBlock {
// Always stop on errors other than would block as we generally
@@ -1866,12 +1871,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if n > 0 {
- return n, senderAddr, senderAddrLen, controlMessages, nil
+ return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil
}
if err == syserror.ETIMEDOUT {
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index be63823d8..c4848b313 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -345,14 +345,14 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME: We can't support MSG_ERRQUEUE because it uses ancillary
// messages that netstack/tcpip/transport/unix doesn't understand. Kill the
// Socket interface's dependence on netstack.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}
var senderAddr []byte
@@ -360,6 +360,8 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
senderAddr = make([]byte, sizeofSockaddr)
}
+ var msgFlags int
+
recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of dst.Addrs was unusable.
if uint64(dst.NumBytes()) != dsts.NumBytes() {
@@ -391,6 +393,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, err
}
senderAddr = senderAddr[:msg.Namelen]
+ msgFlags = int(msg.Flags)
return n, nil
})
@@ -417,7 +420,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
- return int(n), senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
+ // We don't allow control messages.
+ msgFlags &^= linux.MSG_CTRUNC
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 2503a67c5..0fe9b39b6 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -397,7 +397,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
from := linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: 0,
@@ -412,10 +412,14 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
}
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
if trunc {
n = int64(r.MsgSize)
}
- return int(n), from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
}
// We'll have to block. Register for notification and keep trying to
@@ -426,17 +430,21 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
for {
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
if trunc {
n = int64(r.MsgSize)
}
- return int(n), from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
}
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if err == syserror.ETIMEDOUT {
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 896b5b7ce..3418a6d75 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -673,7 +673,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
Fd: s.fd,
Length: uint32(dst.NumBytes()),
@@ -694,10 +694,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return 0, nil, 0, socket.ControlMessages{}, err
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
}
// We'll have to block. Register for notifications and keep trying to
@@ -718,23 +718,23 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- return 0, nil, 0, socket.ControlMessages{}, err
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
}
if s.isShutRdSet() {
// Blocking would have caused us to block indefinitely so we return 0,
// this is the same behavior as Linux.
- return 0, nil, 0, socket.ControlMessages{}, nil
+ return 0, 0, nil, 0, socket.ControlMessages{}, nil
}
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if err == syserror.ETIMEDOUT {
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 5ab423f3c..62ba13782 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -88,7 +88,7 @@ type Socket interface {
// not necessarily the actual length of the address.
//
// If err != nil, the recv was not successful.
- RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
// SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
// ownership of the ControlMessage on error.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 92411c901..01efd24d3 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -477,7 +477,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -515,11 +515,17 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if r.From != nil {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
- if trunc {
- n = int64(r.MsgSize)
- }
+
if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+
+ return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
}
// Don't overwrite any data we received.
@@ -541,14 +547,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if trunc {
- n = int64(r.MsgSize)
+ // n and r.MsgSize are the same for streams.
+ total += int64(r.MsgSize)
+ } else {
+ total += n
}
- total += n
if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
}
// Don't overwrite any data we received.
@@ -560,9 +571,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
err = nil
}
if err == syserror.ETIMEDOUT {
- return int(total), nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- return int(total), nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 49e6f4aeb..30ccc3f66 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -742,17 +742,15 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
// Fast path when no control message nor name buffers are provided.
if msg.ControlLen == 0 && msg.NameLen == 0 {
- n, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
if err != nil {
return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
}
cms.Unix.Release()
- if msg.Flags != 0 {
+ if int(msg.Flags) != mflags {
// Copy out the flags to the caller.
- //
- // TODO: Plumb through actual flags.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(0)); err != nil {
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
}
@@ -763,7 +761,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if msg.ControlLen > maxControlLen {
return 0, syscall.ENOBUFS
}
- n, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
@@ -802,9 +800,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
// Copy out the flags to the caller.
- //
- // TODO: Plumb through actual flags.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(0)); err != nil {
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
@@ -856,7 +852,7 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f
flags |= linux.MSG_DONTWAIT
}
- n, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
+ n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
cm.Unix.Release()
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index 5f83836df..fa895d841 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -220,6 +220,86 @@ TEST(NetlinkRouteTest, GetLinkDump) {
EXPECT_TRUE(loopbackFound);
}
+TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // Small enough to ensure that the response doesn't fit.
+ constexpr size_t kBufferSize = 10;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0),
+ SyscallSucceedsWithValue(kBufferSize));
+ EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC);
+}
+
+TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // Small enough to ensure that the response doesn't fit.
+ constexpr size_t kBufferSize = 10;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ int res = 0;
+ ASSERT_THAT(res = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC),
+ SyscallSucceeds());
+ EXPECT_GT(res, kBufferSize);
+ EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC);
+}
+
TEST(NetlinkRouteTest, ControlMessageIgnored) {
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc
index d49aab363..d170008a4 100644
--- a/test/syscalls/linux/socket_non_stream.cc
+++ b/test/syscalls/linux/socket_non_stream.cc
@@ -15,6 +15,7 @@
#include "test/syscalls/linux/socket_non_stream.h"
#include <stdio.h>
+#include <sys/socket.h>
#include <sys/un.h>
#include "gtest/gtest.h"
@@ -89,6 +90,33 @@ TEST_P(NonStreamSocketPairTest, SingleRecv) {
EXPECT_EQ(0, memcmp(sent_data1, received_data, sizeof(sent_data1)));
}
+TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
+
+ // Check that msghdr flags were updated.
+ EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+}
+
// Stream sockets allow data sent with multiple sends to be peeked at in a
// single recv. Datagram sockets (except for unix sockets) do not.
//
@@ -142,6 +170,33 @@ TEST_P(NonStreamSocketPairTest, MsgTruncTruncation) {
sizeof(sent_data) / 2));
}
+TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
+
+ // Check that msghdr flags were updated.
+ EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+}
+
TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[512];
diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc
index 32e9d958b..c8a8ad0f6 100644
--- a/test/syscalls/linux/socket_stream.cc
+++ b/test/syscalls/linux/socket_stream.cc
@@ -81,6 +81,33 @@ TEST_P(StreamSocketPairTest, WriteOneSideClosed) {
SyscallFailsWithErrno(EPIPE));
}
+TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[sizeof(sent_data) / 2] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0),
+ SyscallSucceedsWithValue(sizeof(received_data)));
+ EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ EXPECT_EQ(msg.msg_flags, 0);
+}
+
TEST_P(StreamSocketPairTest, MsgTrunc) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[512];