summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/kernel/msgqueue/msgqueue.go53
-rw-r--r--test/syscalls/linux/BUILD3
-rw-r--r--test/syscalls/linux/msgqueue.cc138
3 files changed, 151 insertions, 43 deletions
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index c111297d7..fab396d7c 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -208,11 +208,13 @@ func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {
// Send appends a message to the message queue, and returns an error if sending
// fails. See msgsnd(2).
-func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) {
+func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) error {
// Try to perform a non-blocking send using queue.append. If EWOULDBLOCK
// is returned, start the blocking procedure. Otherwise, return normally.
creds := auth.CredentialsFromContext(ctx)
- if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+
+ // Fast path: first attempt a non-blocking push.
+ if err := q.push(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
return err
}
@@ -220,25 +222,30 @@ func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid i
return linuxerr.EAGAIN
}
+ // Slow path: at this point, the queue was found to be full, and we were
+ // asked to block.
+
e, ch := waiter.NewChannelEntry(nil)
q.senders.EventRegister(&e, waiter.EventOut)
+ defer q.senders.EventUnregister(&e)
+ // Note: we need to check again before blocking the first time since space
+ // may have become available.
for {
- if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
- break
+ if err := q.push(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+ return err
+ }
+ if err := b.Block(ch); err != nil {
+ return err
}
- b.Block(ch)
}
-
- q.senders.EventUnregister(&e)
- return err
}
-// append appends a message to the queue's message list and notifies waiting
+// push appends a message to the queue's message list and notifies waiting
// receivers that a message has been inserted. It returns an error if adding
// the message would cause the queue to exceed its maximum capacity, which can
// be used as a signal to block the task. Other errors should be returned as is.
-func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
+func (q *Queue) push(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
if m.Type <= 0 {
return linuxerr.EINVAL
}
@@ -295,15 +302,14 @@ func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials,
}
// Receive removes a message from the queue and returns it. See msgrcv(2).
-func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) {
+func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (*Message, error) {
if maxSize < 0 || maxSize > maxMessageBytes {
return nil, linuxerr.EINVAL
}
max := uint64(maxSize)
-
- // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK
- // is returned, start the blocking procedure. Otherwise, return normally.
creds := auth.CredentialsFromContext(ctx)
+
+ // Fast path: first attempt a non-blocking pop.
if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
return msg, err
}
@@ -312,24 +318,30 @@ func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int
return nil, linuxerr.ENOMSG
}
+ // Slow path: at this point, the queue was found to be empty, and we were
+ // asked to block.
+
e, ch := waiter.NewChannelEntry(nil)
q.receivers.EventRegister(&e, waiter.EventIn)
+ defer q.receivers.EventUnregister(&e)
+ // Note: we need to check again before blocking the first time since a
+ // message may have become available.
for {
- if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
- break
+ if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+ return msg, err
+ }
+ if err := b.Block(ch); err != nil {
+ return nil, err
}
- b.Block(ch)
}
- q.receivers.EventUnregister(&e)
- return msg, err
}
// pop pops the first message from the queue that matches the given type. It
// returns an error for all the cases specified in msgrcv(2). If the queue is
// empty or no message of the specified type is available, a EWOULDBLOCK error
// is returned, which can then be used as a signal to block the process or fail.
-func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) {
+func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (*Message, error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -350,6 +362,7 @@ func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, m
}
// Get a message from the queue.
+ var msg *Message
switch {
case mType == 0:
msg = q.messages.Front()
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index b7e83b5de..71cc9b51d 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -4173,10 +4173,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "//test/util:signal_util",
"//test/util:temp_path",
- "//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)
diff --git a/test/syscalls/linux/msgqueue.cc b/test/syscalls/linux/msgqueue.cc
index cc4022077..8994b739a 100644
--- a/test/syscalls/linux/msgqueue.cc
+++ b/test/syscalls/linux/msgqueue.cc
@@ -13,12 +13,15 @@
// limitations under the License.
#include <errno.h>
+#include <signal.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
+#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "test/util/capability_util.h"
+#include "test/util/signal_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -31,6 +34,8 @@ constexpr int msgMax = 8192; // Max size for message in bytes.
constexpr int msgMni = 32000; // Max number of identifiers.
constexpr int msgMnb = 16384; // Default max size of message queue in bytes.
+constexpr int kInterruptSignal = SIGALRM;
+
// Queue is a RAII class used to automatically clean message queues.
class Queue {
public:
@@ -73,6 +78,12 @@ bool operator==(msgbuf& a, msgbuf& b) {
return a.mtype == b.mtype;
}
+// msgmax represents a buffer for the largest possible single message.
+struct msgmax {
+ int64_t mtype;
+ char mtext[msgMax];
+};
+
// Test simple creation and retrieval for msgget(2).
TEST(MsgqueueTest, MsgGet) {
const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
@@ -310,13 +321,6 @@ TEST(MsgqueueTest, MsgOpLimits) {
SyscallFailsWithErrno(EINVAL));
// Limit for queue.
- // Use a buffer with the maximum mount of bytes that can be transformed to
- // make it easier to exhaust the queue limit.
- struct msgmax {
- int64_t mtype;
- char mtext[msgMax];
- };
-
msgmax limit{1, ""};
for (size_t i = 0, msgCount = msgMnb / msgMax; i < msgCount; i++) {
EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), 0),
@@ -470,13 +474,6 @@ TEST(MsgqueueTest, MsgSndBlocking) {
Queue queue(msgget(IPC_PRIVATE, 0600));
ASSERT_THAT(queue.get(), SyscallSucceeds());
- // Use a buffer with the maximum mount of bytes that can be transformed to
- // make it easier to exhaust the queue limit.
- struct msgmax {
- int64_t mtype;
- char mtext[msgMax];
- };
-
msgmax buf{1, ""}; // Has max amount of bytes.
const size_t msgCount = msgMnb / msgMax; // Number of messages that can be
@@ -494,6 +491,8 @@ TEST(MsgqueueTest, MsgSndBlocking) {
SyscallSucceeds());
});
+ const DisableSave ds; // Too many syscalls.
+
// To increase the chance of the last msgsnd blocking before doing a msgrcv,
// we use MSG_COPY option to copy the last index in the queue. As long as
// MSG_COPY fails, the queue hasn't yet been filled. When MSG_COPY succeeds,
@@ -516,15 +515,9 @@ TEST(MsgqueueTest, MsgSndRmWhileBlocking) {
Queue queue(msgget(IPC_PRIVATE, 0600));
ASSERT_THAT(queue.get(), SyscallSucceeds());
- // Use a buffer with the maximum mount of bytes that can be transformed to
- // make it easier to exhaust the queue limit.
- struct msgmax {
- int64_t mtype;
- char mtext[msgMax];
- };
+ // Number of messages that can be sent without blocking.
+ const size_t msgCount = msgMnb / msgMax;
- const size_t msgCount = msgMnb / msgMax; // Number of messages that can be
- // sent without blocking.
ScopedThread t([&] {
// Fill the queue.
msgmax buf{1, ""};
@@ -540,6 +533,8 @@ TEST(MsgqueueTest, MsgSndRmWhileBlocking) {
EXPECT_TRUE((errno == EIDRM || errno == EINVAL));
});
+ const DisableSave ds; // Too many syscalls.
+
// Similar to MsgSndBlocking, we do this to increase the chance of msgsnd
// blocking before removing the queue.
msgmax rcv;
@@ -627,6 +622,105 @@ TEST(MsgqueueTest, MsgOpGeneral) {
ScopedThread s10(sender(4));
}
+void empty_sighandler(int sig, siginfo_t* info, void* context) {}
+
+TEST(MsgqueueTest, InterruptRecv) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ char buf[64];
+
+ absl::Notification done, exit;
+
+ // Thread calling msgrcv with no corresponding send. It would block forever,
+ // but we'll interrupt with a signal below.
+ ScopedThread t([&] {
+ struct sigaction sa = {};
+ sa.sa_sigaction = empty_sighandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ auto cleanup_sigaction =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kInterruptSignal, sa));
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ScopedSignalMask(SIG_UNBLOCK, kInterruptSignal));
+
+ EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf), 0, 0),
+ SyscallFailsWithErrno(EINTR));
+
+ done.Notify();
+ exit.WaitForNotification();
+ });
+
+ const DisableSave ds; // Too many syscalls.
+
+ // We want the signal to arrive while msgrcv is blocking, but not after the
+ // thread has exited. Signals that arrive before msgrcv are no-ops.
+ do {
+ EXPECT_THAT(kill(getpid(), kInterruptSignal), SyscallSucceeds());
+ absl::SleepFor(absl::Milliseconds(100)); // Rate limit.
+ } while (!done.HasBeenNotified());
+
+ exit.Notify();
+ t.Join();
+}
+
+TEST(MsgqueueTest, InterruptSend) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ msgmax buf{1, ""};
+ // Number of messages that can be sent without blocking.
+ const size_t msgCount = msgMnb / msgMax;
+
+ // Fill the queue.
+ for (size_t i = 0; i < msgCount; i++) {
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+ }
+
+ absl::Notification done, exit;
+
+ // Thread calling msgsnd on a full queue. It would block forever, but we'll
+ // interrupt with a signal below.
+ ScopedThread t([&] {
+ struct sigaction sa = {};
+ sa.sa_sigaction = empty_sighandler;
+ sigfillset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ auto cleanup_sigaction =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kInterruptSignal, sa));
+ auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ScopedSignalMask(SIG_UNBLOCK, kInterruptSignal));
+
+ EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallFailsWithErrno(EINTR));
+
+ done.Notify();
+ exit.WaitForNotification();
+ });
+
+ const DisableSave ds; // Too many syscalls.
+
+ // We want the signal to arrive while msgsnd is blocking, but not after the
+ // thread has exited. Signals that arrive before msgsnd are no-ops.
+ do {
+ EXPECT_THAT(kill(getpid(), kInterruptSignal), SyscallSucceeds());
+ absl::SleepFor(absl::Milliseconds(100)); // Rate limit.
+ } while (!done.HasBeenNotified());
+
+ exit.Notify();
+ t.Join();
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // Some tests depend on delivering a signal to the main thread. Block the
+ // target signal so that any other threads created by TestInit will also have
+ // the signal blocked.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, gvisor::testing::kInterruptSignal);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+ return gvisor::testing::RunAllTests();
+}