diff options
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 67 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 1 | ||||
-rw-r--r-- | test/syscalls/linux/udp_socket.cc | 89 |
3 files changed, 140 insertions, 17 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 3392ac645..a97db5348 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -157,15 +157,17 @@ type SocketOperations struct { // from Endpoint. readCM tcpip.ControlMessages sender tcpip.FullAddress + // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When // false, the same timestamp is instead stored and can be read via the - // SIOCGSTAMP ioctl. See socket(7). + // SIOCGSTAMP ioctl. It is protected by readMu. See socket(7). sockOptTimestamp bool - // timestampValid indicates whether timestamp has been set. + // timestampValid indicates whether timestamp for SIOCGSTAMP has been + // set. It is protected by readMu. timestampValid bool - // timestampNS holds the timestamp to use with SIOCGSTAMP. It is only - // valid when timestampValid is true. + // timestampNS holds the timestamp to use with SIOCGSTAMP. It is only + // valid when timestampValid is true. It is protected by readMu. timestampNS int64 } @@ -266,7 +268,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { } func (s *SocketOperations) isPacketBased() bool { - return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM + return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } // fetchReadView updates the readView field of the socket if it's currently @@ -1480,6 +1482,8 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy // coalescingRead is the fast path for non-blocking, non-peek, stream-based // case. It coalesces as many packets as possible before returning to the // caller. +// +// Precondition: s.readMu must be locked. 
func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { var err *syserr.Error var copied int @@ -1501,9 +1505,8 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } else { n, e = dst.CopyOut(ctx, s.readView) // Set the control message, even if 0 bytes were read. - if e == nil && s.readCM.HasTimestamp && s.sockOptTimestamp { - s.timestampNS = s.readCM.Timestamp - s.timestampValid = true + if e == nil { + s.updateTimestamp() } } copied += n @@ -1569,9 +1572,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe n, err := dst.CopyOut(ctx, s.readView) // Set the control message, even if 0 bytes were read. - if err == nil && s.readCM.HasTimestamp && s.sockOptTimestamp { - s.timestampNS = s.readCM.Timestamp - s.timestampValid = true + if err == nil { + s.updateTimestamp() } var addr interface{} var addrLen uint32 @@ -1582,11 +1584,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if peek { if l := len(s.readView); trunc && l > n { // isPacket must be true. - return l, addr, addrLen, socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.timestampValid, Timestamp: s.timestampNS}}, syserr.FromError(err) + return l, addr, addrLen, s.controlMessages(), syserr.FromError(err) } if isPacket || err != nil { - return int(n), addr, addrLen, socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.timestampValid, Timestamp: s.timestampNS}}, syserr.FromError(err) + return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err) } // We need to peek beyond the first message. @@ -1604,7 +1606,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // We got some data, so no need to return an error. 
err = nil } - return int(n), nil, 0, socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.timestampValid, Timestamp: s.timestampNS}}, syserr.FromError(err) + return int(n), nil, 0, s.controlMessages(), syserr.FromError(err) } var msgLen int @@ -1617,10 +1619,26 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe } if trunc { - return msgLen, addr, addrLen, socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.timestampValid, Timestamp: s.timestampNS}}, syserr.FromError(err) + return msgLen, addr, addrLen, s.controlMessages(), syserr.FromError(err) } - return int(n), addr, addrLen, socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.timestampValid, Timestamp: s.timestampNS}}, syserr.FromError(err) + return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err) +} + +func (s *SocketOperations) controlMessages() socket.ControlMessages { + return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}} +} + +// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after +// successfully writing packet data out to userspace. +// +// Precondition: s.readMu must be locked. +func (s *SocketOperations) updateTimestamp() { + // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. + if !s.sockOptTimestamp { + s.timestampValid = true + s.timestampNS = s.readCM.Timestamp + } } // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by @@ -1771,6 +1789,23 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint + // sockets. + // TODO: Add a commonEndpoint method to support SIOCGSTAMP. 
+ if int(args[1].Int()) == syscall.SIOCGSTAMP { + s.readMu.Lock() + defer s.readMu.Unlock() + if !s.timestampValid { + return 0, syserror.ENOENT + } + + tv := linux.NsecToTimeval(s.timestampNS) + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + } + return Ioctl(ctx, s.Endpoint, io, args) } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index e7f5ea998..9da5204c1 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2907,6 +2907,7 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + ":unix_domain_socket_test_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index a02b418a3..38dfd0ad0 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -24,6 +24,7 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -892,12 +893,21 @@ TEST_P(UdpSocketTest, ErrorQueue) { SyscallFailsWithErrno(EAGAIN)); } +TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + TEST_P(UdpSocketTest, SoTimestamp) { ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); int v = 1; - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), SyscallSucceeds()); char buf[3]; @@ -926,12 +936,89 @@ TEST_P(UdpSocketTest, SoTimestamp) { memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); 
ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // There should be nothing to get via ioctl. + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); } TEST_P(UdpSocketTest, WriteShutdownNotConnected) { EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); } +TEST_P(UdpSocketTest, TimestampIoctl) { + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from t_ to s_. + ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); +} + +TEST_P(UdpSocketTest, TimetstampIoctlNothingRead) { + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + struct timeval tv = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); +} + +// Test that the timestamp accessed via SIOCGSTAMP is still accessible after +// SO_TIMESTAMP is enabled and used to retrieve a timestamp. +TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from t_ to s_. + ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. 
+ struct timeval tv = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // Enable SO_TIMESTAMP and send a message. + int v = 1; + EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + + // There should be a message for SO_TIMESTAMP. + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg = {}; + iovec iov = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + + // The ioctl should return the exact same values as before. + struct timeval tv2 = {}; + ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_EQ(tv.tv_sec, tv2.tv_sec); + ASSERT_EQ(tv.tv_usec, tv2.tv_usec); +} + INSTANTIATE_TEST_CASE_P(AllInetTests, UdpSocketTest, ::testing::Values(AF_INET, AF_INET6)); |