diff options
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_socket.go | 21 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 74 |
2 files changed, 95 insertions, 0 deletions
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 1513f28e7..564357bac 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -57,6 +57,10 @@ const nameLenOffset = 8 // to the ControlLen field. const controlLenOffset = 40 +// flagsOffset is the offset form the start of the MessageHeader64 struct +// to the Flags field. +const flagsOffset = 48 + // messageHeader64Len is the length of a MessageHeader64 struct. var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) @@ -743,6 +747,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) } cms.Unix.Release() + + if msg.Flags != 0 { + // Copy out the flags to the caller. + // + // TODO: Plumb through actual flags. + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(0)); err != nil { + return 0, err + } + } + return uintptr(n), nil } @@ -787,6 +801,13 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } } + // Copy out the flags to the caller. + // + // TODO: Plumb through actual flags. + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(0)); err != nil { + return 0, err + } + return uintptr(n), nil } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 974c0dd7b..c83fb82fe 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -183,6 +183,80 @@ TEST_P(AllSocketPairTest, SendmsgRecvmsg16KB) { memcmp(sent_data.data(), received_data.data(), sent_data.size())); } +TEST_P(AllSocketPairTest, RecvmsgMsghdrFlagsNotClearedOnFailure) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char received_data[10] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + + // Check that msghdr flags were not changed. + EXPECT_EQ(msg.msg_flags, -1); +} + +TEST_P(AllSocketPairTest, RecvmsgMsghdrFlagsCleared) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[sizeof(sent_data)] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data))); + + // Check that msghdr flags were cleared. + EXPECT_EQ(msg.msg_flags, 0); +} + +TEST_P(AllSocketPairTest, RecvmsgPeekMsghdrFlagsCleared) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[sizeof(sent_data)] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_PEEK), + SyscallSucceedsWithValue(sizeof(sent_data))); + EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(sent_data))); + + // Check that msghdr flags were cleared. + EXPECT_EQ(msg.msg_flags, 0); +} + TEST_P(AllSocketPairTest, RecvmmsgInvalidTimeout) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char buf[10]; |