diff options
-rw-r--r-- | pkg/sentry/kernel/msgqueue/msgqueue.go | 278 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/linux64.go | 8 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_msgqueue.go | 85 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 2 | ||||
-rw-r--r-- | test/syscalls/linux/msgqueue.cc | 572 |
6 files changed, 936 insertions, 10 deletions
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go index 3ce926950..c111297d7 100644 --- a/pkg/sentry/kernel/msgqueue/msgqueue.go +++ b/pkg/sentry/kernel/msgqueue/msgqueue.go @@ -119,14 +119,21 @@ type Queue struct { type Message struct { msgEntry - // mType is an integer representing the type of the sent message. - mType int64 + // Type is an integer representing the type of the sent message. + Type int64 - // mText is an untyped block of memory. - mText []byte + // Text is an untyped block of memory. + Text []byte - // mSize is the size of mText. - mSize uint64 + // Size is the size of Text. + Size uint64 +} + +// Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves +// as an abstracted version of kernel.Task. kernel.Task is not directly used to +// prevent circular dependencies. +type Blocker interface { + Block(C <-chan struct{}) error } // FindOrCreate creates a new message queue or returns an existing one. See @@ -186,6 +193,265 @@ func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error { return nil } +// FindByID returns the queue with the specified ID and an error if the ID +// doesn't exist. +func (r *Registry) FindByID(id ipc.ID) (*Queue, error) { + r.mu.Lock() + defer r.mu.Unlock() + + mech := r.reg.FindByID(id) + if mech == nil { + return nil, linuxerr.EINVAL + } + return mech.(*Queue), nil +} + +// Send appends a message to the message queue, and returns an error if sending +// fails. See msgsnd(2). +func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) { + // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK + // is returned, start the blocking procedure. Otherwise, return normally. + creds := auth.CredentialsFromContext(ctx) + if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + return err + } + + if !wait { + return linuxerr.EAGAIN + } + + e, ch := waiter.NewChannelEntry(nil) + q.senders.EventRegister(&e, waiter.EventOut) + + for { + if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + break + } + b.Block(ch) + } + + q.senders.EventUnregister(&e) + return err +} + +// append appends a message to the queue's message list and notifies waiting +// receivers that a message has been inserted. It returns an error if adding +// the message would cause the queue to exceed its maximum capacity, which can +// be used as a signal to block the task. Other errors should be returned as is. +func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error { + if m.Type <= 0 { + return linuxerr.EINVAL + } + + q.mu.Lock() + defer q.mu.Unlock() + + if !q.obj.CheckPermissions(creds, fs.PermMask{Write: true}) { + // The calling process does not have write permission on the message + // queue, and does not have the CAP_IPC_OWNER capability in the user + // namespace that governs its IPC namespace. + return linuxerr.EACCES + } + + // Queue was removed while the process was waiting. + if q.dead { + return linuxerr.EIDRM + } + + // Check if sufficient space is available (the queue isn't full.) From + // the man pages: + // + // "A message queue is considered to be full if either of the following + // conditions is true: + // + // • Adding a new message to the queue would cause the total number + // of bytes in the queue to exceed the queue's maximum size (the + // msg_qbytes field). + // + // • Adding another message to the queue would cause the total + // number of messages in the queue to exceed the queue's maximum + // size (the msg_qbytes field). This check is necessary to + // prevent an unlimited number of zero-length messages being + // placed on the queue. Although such messages contain no data, + // they nevertheless consume (locked) kernel memory." + // + // The msg_qbytes field in our implementation is q.maxBytes. + if m.Size+q.byteCount > q.maxBytes || q.messageCount+1 > q.maxBytes { + return linuxerr.EWOULDBLOCK + } + + // Copy the message into the queue. + q.messages.PushBack(&m) + + q.byteCount += m.Size + q.messageCount++ + q.sendPID = pid + q.sendTime = ktime.NowFromContext(ctx) + + // Notify receivers about the new message. + q.receivers.Notify(waiter.EventIn) + + return nil +} + +// Receive removes a message from the queue and returns it. See msgrcv(2). +func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) { + if maxSize < 0 || maxSize > maxMessageBytes { + return nil, linuxerr.EINVAL + } + max := uint64(maxSize) + + // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK + // is returned, start the blocking procedure. Otherwise, return normally. + creds := auth.CredentialsFromContext(ctx) + if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + return msg, err + } + + if !wait { + return nil, linuxerr.ENOMSG + } + + e, ch := waiter.NewChannelEntry(nil) + q.receivers.EventRegister(&e, waiter.EventIn) + + for { + if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + break + } + b.Block(ch) + } + q.receivers.EventUnregister(&e) + return msg, err +} + +// pop pops the first message from the queue that matches the given type. It +// returns an error for all the cases specified in msgrcv(2). If the queue is +// empty or no message of the specified type is available, a EWOULDBLOCK error +// is returned, which can then be used as a signal to block the process or fail. +func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) { + q.mu.Lock() + defer q.mu.Unlock() + + if !q.obj.CheckPermissions(creds, fs.PermMask{Read: true}) { + // The calling process does not have read permission on the message + // queue, and does not have the CAP_IPC_OWNER capability in the user + // namespace that governs its IPC namespace. + return nil, linuxerr.EACCES + } + + // Queue was removed while the process was waiting. + if q.dead { + return nil, linuxerr.EIDRM + } + + if q.messages.Empty() { + return nil, linuxerr.EWOULDBLOCK + } + + // Get a message from the queue. + switch { + case mType == 0: + msg = q.messages.Front() + case mType > 0: + msg = q.msgOfType(mType, except) + case mType < 0: + msg = q.msgOfTypeLessThan(-1 * mType) + } + + // If no message exists, return a blocking singal. + if msg == nil { + return nil, linuxerr.EWOULDBLOCK + } + + // Check message's size is acceptable. + if maxSize < msg.Size { + if !truncate { + return nil, linuxerr.E2BIG + } + msg.Size = maxSize + msg.Text = msg.Text[:maxSize+1] + } + + q.messages.Remove(msg) + + q.byteCount -= msg.Size + q.messageCount-- + q.receivePID = pid + q.receiveTime = ktime.NowFromContext(ctx) + + // Notify senders about available space. + q.senders.Notify(waiter.EventOut) + + return msg, nil +} + +// Copy copies a message from the queue without deleting it. If no message +// exists, an error is returned. See msgrcv(MSG_COPY). +func (q *Queue) Copy(mType int64) (*Message, error) { + q.mu.Lock() + defer q.mu.Unlock() + + if mType < 0 || q.messages.Empty() { + return nil, linuxerr.ENOMSG + } + + msg := q.msgAtIndex(mType) + if msg == nil { + return nil, linuxerr.ENOMSG + } + return msg, nil +} + +// msgOfType returns the first message with the specified type, nil if no +// message is found. If except is true, the first message of a type not equal +// to mType will be returned. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgOfType(mType int64, except bool) *Message { + if except { + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type != mType { + return msg + } + } + return nil + } + + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type == mType { + return msg + } + } + return nil +} + +// msgOfTypeLessThan return the the first message with the lowest type less +// than or equal to mType, nil if no such message exists. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgOfTypeLessThan(mType int64) (m *Message) { + min := mType + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type <= mType && msg.Type < min { + m = msg + min = msg.Type + } + } + return m +} + +// msgAtIndex returns a pointer to a message at given index, nil if non exits. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgAtIndex(mType int64) *Message { + msg := q.messages.Front() + for ; mType != 0 && msg != nil; mType-- { + msg = msg.Next() + } + return msg +} + // Lock implements ipc.Mechanism.Lock. func (q *Queue) Lock() { q.mu.Lock() diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index ccccce6a9..b5a371d9a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -86,6 +86,7 @@ go_library( "//pkg/sentry/kernel/eventfd", "//pkg/sentry/kernel/fasync", "//pkg/sentry/kernel/ipc", + "//pkg/sentry/kernel/msgqueue", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 6f44d767b..1ead3c7e8 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -122,8 +122,8 @@ var AMD64 = &kernel.SyscallTable{ 66: syscalls.Supported("semctl", Semctl), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.Supported("msgget", Msgget), - 69: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 70: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 69: syscalls.Supported("msgsnd", Msgsnd), + 70: syscalls.Supported("msgrcv", Msgrcv), 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), @@ -618,8 +618,8 @@ var ARM64 = &kernel.SyscallTable{ 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) 186: syscalls.Supported("msgget", Msgget), 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), - 188: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 189: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 188: syscalls.Supported("msgrcv", Msgrcv), + 189: syscalls.Supported("msgsnd", Msgsnd), 190: syscalls.Supported("semget", Semget), 191: syscalls.Supported("semctl", Semctl), 192: syscalls.Supported("semtimedop", Semtimedop), diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go index 3476e218d..5259ade90 100644 --- a/pkg/sentry/syscalls/linux/sys_msgqueue.go +++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go @@ -17,10 +17,12 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" + "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue" ) // Msgget implements msgget(2). @@ -41,6 +43,89 @@ func Msgget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return uintptr(queue.ID()), nil, nil } +// Msgsnd implements msgsnd(2). +func Msgsnd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := ipc.ID(args[0].Int()) + msgAddr := args[1].Pointer() + size := args[2].Int64() + flag := args[3].Int() + + if size < 0 || size > linux.MSGMAX { + return 0, nil, linuxerr.EINVAL + } + + wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT + pid := int32(t.ThreadGroup().ID()) + + buf := linux.MsgBuf{ + Text: make([]byte, size), + } + if _, err := buf.CopyIn(t, msgAddr); err != nil { + return 0, nil, err + } + + queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id) + if err != nil { + return 0, nil, err + } + + msg := msgqueue.Message{ + Type: int64(buf.Type), + Text: buf.Text, + Size: uint64(size), + } + return 0, nil, queue.Send(t, msg, t, wait, pid) +} + +// Msgrcv implements msgrcv(2). +func Msgrcv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := ipc.ID(args[0].Int()) + msgAddr := args[1].Pointer() + size := args[2].Int64() + mType := args[3].Int64() + flag := args[4].Int() + + wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT + except := flag&linux.MSG_EXCEPT == linux.MSG_EXCEPT + truncate := flag&linux.MSG_NOERROR == linux.MSG_NOERROR + + msgCopy := flag&linux.MSG_COPY == linux.MSG_COPY + + msg, err := receive(t, id, mType, size, msgCopy, wait, truncate, except) + if err != nil { + return 0, nil, err + } + + buf := linux.MsgBuf{ + Type: primitive.Int64(msg.Type), + Text: msg.Text, + } + if _, err := buf.CopyOut(t, msgAddr); err != nil { + return 0, nil, err + } + return uintptr(msg.Size), nil, nil +} + +// receive returns a message from the queue with the given ID. If msgCopy is +// true, a message is copied from the queue without being removed. Otherwise, +// a message is removed from the queue and returned. +func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wait, truncate, except bool) (*msgqueue.Message, error) { + pid := int32(t.ThreadGroup().ID()) + + queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id) + if err != nil { + return nil, err + } + + if msgCopy { + if wait || except { + return nil, linuxerr.EINVAL + } + return queue.Copy(mType) + } + return queue.Receive(t, t, mType, maxSize, wait, truncate, except, pid) +} + // Msgctl implements msgctl(2). func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := ipc.ID(args[0].Int()) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 3383495d0..7129a797b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -4172,9 +4172,11 @@ cc_binary( srcs = ["msgqueue.cc"], linkstatic = 1, deps = [ + "//test/util:capability_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/time", ], ) diff --git a/test/syscalls/linux/msgqueue.cc b/test/syscalls/linux/msgqueue.cc index 2409de7e8..837e913d9 100644 --- a/test/syscalls/linux/msgqueue.cc +++ b/test/syscalls/linux/msgqueue.cc @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <errno.h> #include <sys/ipc.h> #include <sys/msg.h> #include <sys/types.h> +#include "absl/time/clock.h" +#include "test/util/capability_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -23,6 +26,10 @@ namespace gvisor { namespace testing { namespace { +constexpr int msgMax = 8192; // Max size for message in bytes. +constexpr int msgMni = 32000; // Max number of identifiers. +constexpr int msgMnb = 16384; // Default max size of message queue in bytes. + // Queue is a RAII class used to automatically clean message queues. class Queue { public: @@ -46,6 +53,25 @@ class Queue { int id_ = -1; }; +// Default size for messages. +constexpr size_t msgSize = 50; + +// msgbuf is a simple buffer using to send and receive text messages for +// testing purposes. +struct msgbuf { + int64_t mtype; + char mtext[msgSize]; +}; + +bool operator==(msgbuf& a, msgbuf& b) { + for (size_t i = 0; i < msgSize; i++) { + if (a.mtext[i] != b.mtext[i]) { + return false; + } + } + return a.mtype == b.mtype; +} + // Test simple creation and retrieval for msgget(2). TEST(MsgqueueTest, MsgGet) { const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -82,6 +108,552 @@ TEST(MsgqueueTest, MsgGetIpcPrivate) { EXPECT_NE(queue1.get(), queue2.get()); } +// Test simple msgsnd and msgrcv. +TEST(MsgqueueTest, MsgOpSimple) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, "A message."}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); +} + +// Test msgsnd and msgrcv of an empty message. +TEST(MsgqueueTest, MsgOpEmpty) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, 0, 0), SyscallSucceeds()); + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0), + SyscallSucceedsWithValue(0)); +} + +// Test truncation of message with MSG_NOERROR flag. +TEST(MsgqueueTest, MsgOpTruncate) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) - 1, 0, MSG_NOERROR), + SyscallSucceedsWithValue(sizeof(buf.mtext) - 1)); +} + +// Test msgsnd and msgrcv using invalid arguments. +TEST(MsgqueueTest, MsgOpInvalidArgs) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + + EXPECT_THAT(msgsnd(-1, &buf, 0, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(msgsnd(queue.get(), &buf, -1, 0), SyscallFailsWithErrno(EINVAL)); + + buf.mtype = -1; + EXPECT_THAT(msgsnd(queue.get(), &buf, 1, 0), SyscallFailsWithErrno(EINVAL)); + + EXPECT_THAT(msgrcv(-1, &buf, 1, 0, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(msgrcv(queue.get(), &buf, -1, 0, 0), + SyscallFailsWithErrno(EINVAL)); +} + +// Test non-blocking msgrcv with an empty queue. +TEST(MsgqueueTest, MsgOpNoMsg) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(rcv.mtext) + 1, 0, IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); +} + +// Test non-blocking msgrcv with a non-empty queue, but no messages of wanted +// type. +TEST(MsgqueueTest, MsgOpNoMsgType) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext) + 1, 2, IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); +} + +// Test msgrcv with a larger size message than wanted, and truncation disabled. +TEST(MsgqueueTest, MsgOpTooBig) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext) - 1, 0, 0), + SyscallFailsWithErrno(E2BIG)); +} + +// Test receiving messages based on type. +TEST(MsgqueueTest, MsgRcvType) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Send messages in an order and receive them in reverse, based on type, + // which shouldn't block. + std::map<int64_t, msgbuf> typeToBuf = { + {1, msgbuf{1, "Message 1."}}, {2, msgbuf{2, "Message 2."}}, + {3, msgbuf{3, "Message 3."}}, {4, msgbuf{4, "Message 4."}}, + {5, msgbuf{5, "Message 5."}}, {6, msgbuf{6, "Message 6."}}, + {7, msgbuf{7, "Message 7."}}, {8, msgbuf{8, "Message 8."}}, + {9, msgbuf{9, "Message 9."}}}; + + for (auto const& [type, buf] : typeToBuf) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + for (int64_t i = typeToBuf.size(); i > 0; i--) { + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(typeToBuf[i].mtext) + 1, i, 0), + SyscallSucceedsWithValue(sizeof(typeToBuf[i].mtext))); + EXPECT_TRUE(typeToBuf[i] == rcv); + } +} + +// Test using MSG_EXCEPT to receive a different-type message. +TEST(MsgqueueTest, MsgExcept) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + std::map<int64_t, msgbuf> typeToBuf = { + {1, msgbuf{1, "Message 1."}}, + {2, msgbuf{2, "Message 2."}}, + }; + + for (auto const& [type, buf] : typeToBuf) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + for (int64_t i = typeToBuf.size(); i > 0; i--) { + msgbuf actual = typeToBuf[i == 1 ? 2 : 1]; + msgbuf rcv; + + EXPECT_THAT( + msgrcv(queue.get(), &rcv, sizeof(actual.mtext) + 1, i, MSG_EXCEPT), + SyscallSucceedsWithValue(sizeof(actual.mtext))); + EXPECT_TRUE(actual == rcv); + } +} + +// Test msgrcv with a negative type. +TEST(MsgqueueTest, MsgRcvTypeNegative) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // When msgtyp is negative, msgrcv returns the first message with mtype less + // than or equal to the absolute value. + msgbuf buf{2, "A message."}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + // Nothing is less than or equal to 1. + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, -1, IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, -3, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); +} + +// Test permission-related failure scenarios. +TEST(MsgqueueTest, MsgOpPermissions) { + AutoCapability cap(CAP_IPC_OWNER, false); + + Queue queue(msgget(IPC_PRIVATE, 0000)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext), 0, 0), + SyscallFailsWithErrno(EACCES)); +} + +// Test limits for messages and queues. +TEST(MsgqueueTest, MsgOpLimits) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, "A message."}; + + // Limit for one message. + EXPECT_THAT(msgsnd(queue.get(), &buf, msgMax + 1, 0), + SyscallFailsWithErrno(EINVAL)); + + // Limit for queue. + // Use a buffer with the maximum mount of bytes that can be transformed to + // make it easier to exhaust the queue limit. + struct msgmax { + int64_t mtype; + char mtext[msgMax]; + }; + + msgmax limit{1, ""}; + for (size_t i = 0, msgCount = msgMnb / msgMax; i < msgCount; i++) { + EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), 0), + SyscallSucceeds()); + } + EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), IPC_NOWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// MsgCopySupported returns true if MSG_COPY is supported. +bool MsgCopySupported() { + // msgrcv(2) man page states that MSG_COPY flag is available only if the + // kernel was built with the CONFIG_CHECKPOINT_RESTORE option. If MSG_COPY + // is used when the kernel was configured without the option, msgrcv produces + // a ENOSYS error. + // To avoid test failure, we perform a small test using msgrcv, and skip the + // test if errno == ENOSYS. This means that the test will always run on + // gVisor, but may be skipped on native linux. + + Queue queue(msgget(IPC_PRIVATE, 0600)); + + msgbuf buf{1, "Test message."}; + msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0); + + return !(msgrcv(queue.get(), &buf, sizeof(buf.mtext) + 1, 0, + MSG_COPY | IPC_NOWAIT) == -1 && + errno == ENOSYS); +} + +// Test msgrcv using MSG_COPY. +TEST(MsgqueueTest, MsgCopy) { + SKIP_IF(!MsgCopySupported()); + + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf bufs[5] = { + msgbuf{1, "Message 1."}, msgbuf{2, "Message 2."}, msgbuf{3, "Message 3."}, + msgbuf{4, "Message 4."}, msgbuf{5, "Message 5."}, + }; + + for (auto& buf : bufs) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + // Receive a copy of the messages. + for (size_t i = 0, size = sizeof(bufs) / sizeof(bufs[0]); i < size; i++) { + msgbuf buf = bufs[i]; + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, i, + MSG_COPY | IPC_NOWAIT), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); + } + + // Re-receive the messages normally. + for (auto& buf : bufs) { + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); + } +} + +// Test msgrcv using MSG_COPY with invalid arguments. +TEST(MsgqueueTest, MsgCopyInvalidArgs) { + SKIP_IF(!MsgCopySupported()); + + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, 1, MSG_COPY), + SyscallFailsWithErrno(EINVAL)); + + EXPECT_THAT( + msgrcv(queue.get(), &rcv, msgSize, 5, MSG_COPY | MSG_EXCEPT | IPC_NOWAIT), + SyscallFailsWithErrno(EINVAL)); +} + +// Test msgrcv using MSG_COPY with invalid indices. +TEST(MsgqueueTest, MsgCopyInvalidIndex) { + SKIP_IF(!MsgCopySupported()); + + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, -3, MSG_COPY | IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, 5, MSG_COPY | IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); +} + +// Test msgrcv (most probably) blocking on an empty queue. +TEST(MsgqueueTest, MsgRcvBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, "A message."}; + + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, + 0) == sizeof(buf.mtext) && + buf == rcv); + _exit(0); + } + + // Sleep to try and make msgrcv block before sending a message. + absl::SleepFor(absl::Milliseconds(150)); + + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test msgrcv (most probably) waiting for a specific-type message. +TEST(MsgqueueTest, MsgRcvTypeBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf bufs[5] = {{1, "A message."}, + {1, "A message."}, + {1, "A message."}, + {1, "A message."}, + {2, "A different message."}}; + + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf buf = bufs[4]; // Buffer that should be received. + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, 2, + 0) == sizeof(buf.mtext) && + buf == rcv); + _exit(0); + } + + // Sleep to try and make msgrcv block before sending messages. + absl::SleepFor(absl::Milliseconds(150)); + + // Send all buffers in order, only last one should be received. + for (auto& buf : bufs) { + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test msgsnd (most probably) blocking on a full queue. +TEST(MsgqueueTest, MsgSndBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Use a buffer with the maximum mount of bytes that can be transformed to + // make it easier to exhaust the queue limit. + struct msgmax { + int64_t mtype; + char mtext[msgMax]; + }; + + msgmax buf{1, ""}; // Has max amount of bytes. + + const size_t msgCount = msgMnb / msgMax; // Number of messages that can be + // sent without blocking. + + const pid_t child_pid = fork(); + if (child_pid == 0) { + // Fill the queue. + for (size_t i = 0; i < msgCount; i++) { + TEST_PCHECK(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0) == 0); + } + + // Next msgsnd should block. + TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) == + 0); + _exit(0); + } + + // To increase the chance of the last msgsnd blocking before doing a msgrcv, + // we use MSG_COPY option to copy the last index in the queue. As long as + // MSG_COPY fails, the queue hasn't yet been filled. When MSG_COPY succeeds, + // the queue is filled, and most probably, a blocking msgsnd has been made. + msgmax rcv; + while (msgrcv(queue.get(), &rcv, msgMax, msgCount - 1, + MSG_COPY | IPC_NOWAIT) == -1 && + errno == ENOMSG) { + } + + // Delay a bit more for the blocking msgsnd. + absl::SleepFor(absl::Milliseconds(100)); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext), 0, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test removing a queue while a blocking msgsnd is executing. +TEST(MsgqueueTest, MsgSndRmWhileBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Use a buffer with the maximum mount of bytes that can be transformed to + // make it easier to exhaust the queue limit. + struct msgmax { + int64_t mtype; + char mtext[msgMax]; + }; + + const size_t msgCount = msgMnb / msgMax; // Number of messages that can be + // sent without blocking. + const pid_t child_pid = fork(); + if (child_pid == 0) { + // Fill the queue. + msgmax buf{1, ""}; + for (size_t i = 0; i < msgCount; i++) { + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + // Next msgsnd should block. Because we're repeating on EINTR, msgsnd may + // race with msgctl(IPC_RMID) and return EINVAL. + TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) == + -1 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + // Similar to MsgSndBlocking, we do this to increase the chance of msgsnd + // blocking before removing the queue. + msgmax rcv; + while (msgrcv(queue.get(), &rcv, msgMax, msgCount - 1, + MSG_COPY | IPC_NOWAIT) == -1 && + errno == ENOMSG) { + } + absl::SleepFor(absl::Milliseconds(100)); + + EXPECT_THAT(msgctl(queue.release(), IPC_RMID, nullptr), SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test removing a queue while a blocking msgrcv is executing. +TEST(MsgqueueTest, MsgRcvRmWhileBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + const pid_t child_pid = fork(); + if (child_pid == 0) { + // Because we're repeating on EINTR, msgsnd may race with msgctl(IPC_RMID) + // and return EINVAL. + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, 1, 2, 0) == -1 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + // Sleep to try and make msgrcv block before sending messages. + absl::SleepFor(absl::Milliseconds(150)); + + EXPECT_THAT(msgctl(queue.release(), IPC_RMID, nullptr), SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test a collection of msgsnd/msgrcv operations in different processes. +TEST(MsgqueueTest, MsgOpGeneral) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Create 50 sending, and 50 receiving processes. There are only 5 messages to + // be sent and received, each with a different type. All messages will be sent + // and received equally (10 of each.) By the end of the test all processes + // should unblock and return normally. + const size_t msgCount = 5; + std::map<int64_t, msgbuf> typeToBuf = {{1, msgbuf{1, "Message 1."}}, + {2, msgbuf{2, "Message 2."}}, + {3, msgbuf{3, "Message 3."}}, + {4, msgbuf{4, "Message 4."}}, + {5, msgbuf{5, "Message 5."}}}; + + std::vector<pid_t> children; + + const size_t pCount = 50; + for (size_t i = 1; i <= pCount; i++) { + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf buf = typeToBuf[(i % msgCount) + 1]; + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, + (i % msgCount) + 1, + 0) == sizeof(buf.mtext) && + buf == rcv); + _exit(0); + } + children.push_back(child_pid); + } + + for (size_t i = 1; i <= pCount; i++) { + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf buf = typeToBuf[(i % msgCount) + 1]; + TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) == + 0); + _exit(0); + } + children.push_back(child_pid); + } + + for (auto const& pid : children) { + int status; + ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), + SyscallSucceedsWithValue(pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } +} + } // namespace } // namespace testing } // namespace gvisor |