diff options
-rw-r--r-- | pkg/sentry/kernel/msgqueue/msgqueue.go | 53 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 3 | ||||
-rw-r--r-- | test/syscalls/linux/msgqueue.cc | 138 |
3 files changed, 151 insertions, 43 deletions
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go index c111297d7..fab396d7c 100644 --- a/pkg/sentry/kernel/msgqueue/msgqueue.go +++ b/pkg/sentry/kernel/msgqueue/msgqueue.go @@ -208,11 +208,13 @@ func (r *Registry) FindByID(id ipc.ID) (*Queue, error) { // Send appends a message to the message queue, and returns an error if sending // fails. See msgsnd(2). -func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) { +func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) error { // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK // is returned, start the blocking procedure. Otherwise, return normally. creds := auth.CredentialsFromContext(ctx) - if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + + // Fast path: first attempt a non-blocking push. + if err := q.push(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { return err } @@ -220,25 +222,30 @@ func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid i return linuxerr.EAGAIN } + // Slow path: at this point, the queue was found to be full, and we were + // asked to block. + e, ch := waiter.NewChannelEntry(nil) q.senders.EventRegister(&e, waiter.EventOut) + defer q.senders.EventUnregister(&e) + // Note: we need to check again before blocking the first time since space + // may have become available. for { - if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { - break + if err := q.push(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + return err + } + if err := b.Block(ch); err != nil { + return err } - b.Block(ch) } - - q.senders.EventUnregister(&e) - return err } -// append appends a message to the queue's message list and notifies waiting +// push appends a message to the queue's message list and notifies waiting // receivers that a message has been inserted. It returns an error if adding // the message would cause the queue to exceed its maximum capacity, which can // be used as a signal to block the task. Other errors should be returned as is. -func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error { +func (q *Queue) push(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error { if m.Type <= 0 { return linuxerr.EINVAL } @@ -295,15 +302,14 @@ func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, } // Receive removes a message from the queue and returns it. See msgrcv(2). -func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) { +func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (*Message, error) { if maxSize < 0 || maxSize > maxMessageBytes { return nil, linuxerr.EINVAL } max := uint64(maxSize) - - // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK - // is returned, start the blocking procedure. Otherwise, return normally. creds := auth.CredentialsFromContext(ctx) + + // Fast path: first attempt a non-blocking pop. if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { return msg, err } @@ -312,24 +318,30 @@ func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int return nil, linuxerr.ENOMSG } + // Slow path: at this point, the queue was found to be empty, and we were + // asked to block. + e, ch := waiter.NewChannelEntry(nil) q.receivers.EventRegister(&e, waiter.EventIn) + defer q.receivers.EventUnregister(&e) + // Note: we need to check again before blocking the first time since a + // message may have become available. for { - if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { - break + if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + return msg, err + } + if err := b.Block(ch); err != nil { + return nil, err } - b.Block(ch) } - q.receivers.EventUnregister(&e) - return msg, err } // pop pops the first message from the queue that matches the given type. It // returns an error for all the cases specified in msgrcv(2). If the queue is // empty or no message of the specified type is available, a EWOULDBLOCK error // is returned, which can then be used as a signal to block the process or fail. -func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) { +func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (*Message, error) { q.mu.Lock() defer q.mu.Unlock() @@ -350,6 +362,7 @@ func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, m } // Get a message from the queue. + var msg *Message switch { case mType == 0: msg = q.messages.Front() diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index b7e83b5de..71cc9b51d 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -4173,10 +4173,11 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "//test/util:signal_util", "//test/util:temp_path", - "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], ) diff --git a/test/syscalls/linux/msgqueue.cc b/test/syscalls/linux/msgqueue.cc index cc4022077..8994b739a 100644 --- a/test/syscalls/linux/msgqueue.cc +++ b/test/syscalls/linux/msgqueue.cc @@ -13,12 +13,15 @@ // limitations under the License. #include <errno.h> +#include <signal.h> #include <sys/ipc.h> #include <sys/msg.h> #include <sys/types.h> +#include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "test/util/capability_util.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -31,6 +34,8 @@ constexpr int msgMax = 8192; // Max size for message in bytes. constexpr int msgMni = 32000; // Max number of identifiers. constexpr int msgMnb = 16384; // Default max size of message queue in bytes. +constexpr int kInterruptSignal = SIGALRM; + // Queue is a RAII class used to automatically clean message queues. class Queue { public: @@ -73,6 +78,12 @@ bool operator==(msgbuf& a, msgbuf& b) { return a.mtype == b.mtype; } +// msgmax represents a buffer for the largest possible single message. +struct msgmax { + int64_t mtype; + char mtext[msgMax]; +}; + // Test simple creation and retrieval for msgget(2). TEST(MsgqueueTest, MsgGet) { const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -310,13 +321,6 @@ TEST(MsgqueueTest, MsgOpLimits) { SyscallFailsWithErrno(EINVAL)); // Limit for queue. - // Use a buffer with the maximum mount of bytes that can be transformed to - // make it easier to exhaust the queue limit. - struct msgmax { - int64_t mtype; - char mtext[msgMax]; - }; - msgmax limit{1, ""}; for (size_t i = 0, msgCount = msgMnb / msgMax; i < msgCount; i++) { EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), 0), @@ -470,13 +474,6 @@ TEST(MsgqueueTest, MsgSndBlocking) { Queue queue(msgget(IPC_PRIVATE, 0600)); ASSERT_THAT(queue.get(), SyscallSucceeds()); - // Use a buffer with the maximum mount of bytes that can be transformed to - // make it easier to exhaust the queue limit. - struct msgmax { - int64_t mtype; - char mtext[msgMax]; - }; - msgmax buf{1, ""}; // Has max amount of bytes. const size_t msgCount = msgMnb / msgMax; // Number of messages that can be @@ -494,6 +491,8 @@ TEST(MsgqueueTest, MsgSndBlocking) { SyscallSucceeds()); }); + const DisableSave ds; // Too many syscalls. + // To increase the chance of the last msgsnd blocking before doing a msgrcv, // we use MSG_COPY option to copy the last index in the queue. As long as // MSG_COPY fails, the queue hasn't yet been filled. When MSG_COPY succeeds, @@ -516,15 +515,9 @@ TEST(MsgqueueTest, MsgSndRmWhileBlocking) { Queue queue(msgget(IPC_PRIVATE, 0600)); ASSERT_THAT(queue.get(), SyscallSucceeds()); - // Use a buffer with the maximum mount of bytes that can be transformed to - // make it easier to exhaust the queue limit. - struct msgmax { - int64_t mtype; - char mtext[msgMax]; - }; + // Number of messages that can be sent without blocking. + const size_t msgCount = msgMnb / msgMax; - const size_t msgCount = msgMnb / msgMax; // Number of messages that can be - // sent without blocking. ScopedThread t([&] { // Fill the queue. msgmax buf{1, ""}; @@ -540,6 +533,8 @@ TEST(MsgqueueTest, MsgSndRmWhileBlocking) { EXPECT_TRUE((errno == EIDRM || errno == EINVAL)); }); + const DisableSave ds; // Too many syscalls. + // Similar to MsgSndBlocking, we do this to increase the chance of msgsnd // blocking before removing the queue. msgmax rcv; @@ -627,6 +622,105 @@ TEST(MsgqueueTest, MsgOpGeneral) { ScopedThread s10(sender(4)); } +void empty_sighandler(int sig, siginfo_t* info, void* context) {} + +TEST(MsgqueueTest, InterruptRecv) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + char buf[64]; + + absl::Notification done, exit; + + // Thread calling msgrcv with no corresponding send. It would block forever, + // but we'll interrupt with a signal below. + ScopedThread t([&] { + struct sigaction sa = {}; + sa.sa_sigaction = empty_sighandler; + sigfillset(&sa.sa_mask); + sa.sa_flags = SA_SIGINFO; + auto cleanup_sigaction = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kInterruptSignal, sa)); + auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE( + ScopedSignalMask(SIG_UNBLOCK, kInterruptSignal)); + + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf), 0, 0), + SyscallFailsWithErrno(EINTR)); + + done.Notify(); + exit.WaitForNotification(); + }); + + const DisableSave ds; // Too many syscalls. + + // We want the signal to arrive while msgrcv is blocking, but not after the + // thread has exited. Signals that arrive before msgrcv are no-ops. + do { + EXPECT_THAT(kill(getpid(), kInterruptSignal), SyscallSucceeds()); + absl::SleepFor(absl::Milliseconds(100)); // Rate limit. + } while (!done.HasBeenNotified()); + + exit.Notify(); + t.Join(); +} + +TEST(MsgqueueTest, InterruptSend) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + msgmax buf{1, ""}; + // Number of messages that can be sent without blocking. + const size_t msgCount = msgMnb / msgMax; + + // Fill the queue. + for (size_t i = 0; i < msgCount; i++) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + absl::Notification done, exit; + + // Thread calling msgsnd on a full queue. It would block forever, but we'll + // interrupt with a signal below. + ScopedThread t([&] { + struct sigaction sa = {}; + sa.sa_sigaction = empty_sighandler; + sigfillset(&sa.sa_mask); + sa.sa_flags = SA_SIGINFO; + auto cleanup_sigaction = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kInterruptSignal, sa)); + auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE( + ScopedSignalMask(SIG_UNBLOCK, kInterruptSignal)); + + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallFailsWithErrno(EINTR)); + + done.Notify(); + exit.WaitForNotification(); + }); + + const DisableSave ds; // Too many syscalls. + + // We want the signal to arrive while msgsnd is blocking, but not after the + // thread has exited. Signals that arrive before msgsnd are no-ops. + do { + EXPECT_THAT(kill(getpid(), kInterruptSignal), SyscallSucceeds()); + absl::SleepFor(absl::Milliseconds(100)); // Rate limit. + } while (!done.HasBeenNotified()); + + exit.Notify(); + t.Join(); +} + } // namespace } // namespace testing } // namespace gvisor + +int main(int argc, char** argv) { + // Some tests depend on delivering a signal to the main thread. Block the + // target signal so that any other threads created by TestInit will also have + // the signal blocked. + sigset_t set; + sigemptyset(&set); + sigaddset(&set, gvisor::testing::kInterruptSignal); + TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); + + gvisor::testing::TestInit(&argc, &argv); + return gvisor::testing::RunAllTests(); +} |