From eefa817cfdb04ff07e7069396f21bd6ba2c89957 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Thu, 18 Jul 2019 15:39:47 -0700 Subject: net/tcp/setsockopt: implement setsockopt(fd, SOL_TCP, TCP_INQ) PiperOrigin-RevId: 258859507 --- pkg/sentry/socket/control/control.go | 16 +++- pkg/sentry/socket/epsocket/epsocket.go | 61 ++++++++++++-- pkg/sentry/socket/unix/transport/unix.go | 38 +++++---- pkg/sentry/syscalls/linux/sys_socket.go | 5 ++ pkg/tcpip/stack/transport_test.go | 5 ++ pkg/tcpip/tcpip.go | 26 +++++- pkg/tcpip/transport/icmp/endpoint.go | 27 +++--- pkg/tcpip/transport/raw/endpoint.go | 28 ++++--- pkg/tcpip/transport/tcp/endpoint.go | 18 ++-- pkg/tcpip/transport/udp/endpoint.go | 28 ++++--- test/syscalls/linux/tcp_socket.cc | 138 +++++++++++++++++++++++++++++++ 11 files changed, 319 insertions(+), 71 deletions(-) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4f4a20dfe..4e95101b7 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -225,14 +225,14 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ return alignSlice(buf, align), flags } -func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte { +func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte { if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { return buf } ob := buf buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader)) - buf = putUint32(buf, linux.SOL_SOCKET) + buf = putUint32(buf, msgLevel) buf = putUint32(buf, msgType) hdrBuf := buf @@ -307,12 +307,24 @@ func alignSlice(buf []byte, align uint) []byte { func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { return putCmsgStruct( buf, + linux.SOL_SOCKET, linux.SO_TIMESTAMP, t.Arch().Width(), linux.NsecToTimeval(timestamp), ) } +// PackInq packs a TCP_INQ socket control message. 
+func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_TCP, + linux.TCP_INQ, + 4, + inq, + ) +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { var ( diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 69eff7373..e57aed927 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -207,6 +207,10 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and + // transport.Endpoint.GetSockOptInt. + GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack @@ -249,6 +253,10 @@ type SocketOperations struct { // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. timestampNS int64 + + // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket + // level, because it takes into account data from readView. + sockOptInq bool } // New creates a new endpoint socket. 
@@ -634,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( } return val, nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptInq { + val = 1 + } + return val, nil + } return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) } @@ -1048,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 return nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } @@ -1267,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.TranslateNetstackError(err) } return nil + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1492,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_INQ, linux.TCP_KEEPCNT, linux.TCP_KEEPIDLE, linux.TCP_KEEPINTVL, @@ -1747,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } +func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { + if !s.sockOptInq { + return + } + rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { + return + } + cmsg.IP.HasInq = true + cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. 
@@ -1766,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) s.readMu.Unlock() - return n, 0, nil, 0, socket.ControlMessages{}, err + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, 0, nil, 0, cmsg, err } s.readMu.Lock() @@ -1779,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if !isPacket && peek && trunc { // MSG_TRUNC with MSG_PEEK on a TCP socket returns the // amount that could be read. - var rql tcpip.ReceiveQueueSizeOption - if err := s.Endpoint.GetSockOpt(&rql); err != nil { + rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } available := len(s.readView) + int(rql) @@ -1848,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe n = msgLen } - return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err) + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } func (s *SocketOperations) controlMessages() socket.ControlMessages { @@ -2086,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCINQ: - var v tcpip.ReceiveQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index b0765ba55..7fb9cb1e0 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -179,6 +179,10 @@ type Endpoint interface { // tcpip.*Option types. 
GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptInt gets a socket option for simple cases when a return + // value has the int type. + GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 @@ -834,33 +838,39 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.SendQueueSizeOption: +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v = int(e.receiver.RecvQueuedSize()) e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs + return v, nil + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: return nil - case *tcpip.ReceiveQueueSizeOption: + case *tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() return tcpip.ErrNotConnected } - qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize()) + qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) e.Unlock() if qs < 0 { return tcpip.ErrQueueSizeNotSupported diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 195734257..fa568a660 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -802,6 +802,11 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData) } + if cms.IP.HasInq { + // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. + controlData = control.PackInq(t, cms.IP.Inq, controlData) + } + if cms.Unix.Rights != nil { controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 788ffcc8c..b418db046 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -90,6 +90,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + return -1, tcpip.ErrUnknownProtocolOption +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch opt.(type) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c4076666a..c5d79da5e 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -288,6 +288,12 @@ type ControlMessages struct { // Timestamp is the time (in ns) that the last packed used to create // the read data was received. Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -383,6 +389,10 @@ type Endpoint interface { // *Option types. GetSockOpt(opt interface{}) *Error + // GetSockOptInt gets a socket option for simple cases where a return + // value has the int type. + GetSockOptInt(SockOpt) (int, *Error) + // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. State() uint32 @@ -408,6 +418,18 @@ type WriteOptions struct { EndOfRecord bool } +// SockOpt represents socket options whose values have the int type. +type SockOpt int + +const ( + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the number of + // unread bytes in the input buffer should be returned. + ReceiveQueueSizeOption SockOpt = iota + + // TODO(b/137664753): convert all int socket options to be handled via + // GetSockOptInt. +) + // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} @@ -424,10 +446,6 @@ type ReceiveBufferSizeOption int // unread bytes in the output buffer should be returned. type SendQueueSizeOption int -// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the input buffer should be returned. 
-type ReceiveQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a80ceafd0..ba6671c26 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -314,6 +314,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() + v = p.data.Size() + } + e.rcvMu.Unlock() + return v, nil + } + return -1, tcpip.ErrUnknownProtocolOption +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { @@ -332,17 +348,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.rcvMu.Unlock() return nil - case *tcpip.ReceiveQueueSizeOption: - e.rcvMu.Lock() - if e.rcvList.Empty() { - *o = 0 - } else { - p := e.rcvList.Front() - *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) - } - e.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index a29587658..b633cd9d8 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -487,6 +487,23 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + } + + return -1, tcpip.ErrUnknownProtocolOption +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { @@ -505,17 +522,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { ep.rcvMu.Unlock() return nil - case *tcpip.ReceiveQueueSizeOption: - ep.rcvMu.Lock() - if ep.rcvList.Empty() { - *o = 0 - } else { - p := ep.rcvList.Front() - *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) - } - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index beb90afb5..89154391b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1100,6 +1100,15 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + return e.readyReceiveSize() + } + return -1, tcpip.ErrUnknownProtocolOption +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { @@ -1130,15 +1139,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.rcvListMu.Unlock() return nil - case *tcpip.ReceiveQueueSizeOption: - v, err := e.readyReceiveSize() - if err != nil { - return err - } - - *o = tcpip.ReceiveQueueSizeOption(v) - return nil - case *tcpip.DelayOption: *o = 0 if v := atomic.LoadUint32(&e.delay); v != 0 { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cb0ea83a6..70f4a2b8c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -189,7 +189,6 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess p := e.rcvList.Front() e.rcvList.Remove(p) e.rcvBufSize -= p.data.Size() - e.rcvMu.Unlock() if addr != nil { @@ -539,6 +538,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() + v = p.data.Size() + } + e.rcvMu.Unlock() + return v, nil + } + return -1, tcpip.ErrUnknownProtocolOption +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { @@ -573,17 +588,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil - case *tcpip.ReceiveQueueSizeOption: - e.rcvMu.Lock() - if e.rcvList.Empty() { - *o = 0 - } else { - p := e.rcvList.Front() - *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) - } - e.rcvMu.Unlock() - return nil - case *tcpip.MulticastTTLOption: e.mu.Lock() *o = tcpip.MulticastTTLOption(e.multicastTTL) diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 8d77431f2..77aab1e7d 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -520,6 +521,143 @@ TEST_P(TcpSocketTest, SetNoDelay) { EXPECT_EQ(get, kSockOptOff); } +#ifndef TCP_INQ +#define TCP_INQ 36 +#endif + +TEST_P(TcpSocketTest, TcpInqSetSockOpt) { + char buf[1024]; + ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + // TCP_INQ is disabled by default. + int val = -1; + socklen_t slen = sizeof(val); + EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen), + SyscallSucceedsWithValue(0)); + ASSERT_EQ(val, 0); + + // Try to set TCP_INQ. + val = 1; + EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), + SyscallSucceedsWithValue(0)); + val = -1; + slen = sizeof(val); + EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen), + SyscallSucceedsWithValue(0)); + ASSERT_EQ(val, 1); + + // Try to unset TCP_INQ. + val = 0; + EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), + SyscallSucceedsWithValue(0)); + val = -1; + slen = sizeof(val); + EXPECT_THAT(getsockopt(t_, SOL_TCP, TCP_INQ, &val, &slen), + SyscallSucceedsWithValue(0)); + ASSERT_EQ(val, 0); +} + +TEST_P(TcpSocketTest, TcpInq) { + char buf[1024]; + // Write more than one TCP segment. 
+ int size = sizeof(buf); + int kChunk = sizeof(buf) / 4; + for (int i = 0; i < size; i += kChunk) { + ASSERT_THAT(RetryEINTR(write)(s_, buf, kChunk), + SyscallSucceedsWithValue(kChunk)); + } + + int val = 1; + kChunk = sizeof(buf) / 2; + EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), + SyscallSucceedsWithValue(0)); + + // Wait until all the data is in the receive queue. + while (true) { + ASSERT_THAT(ioctl(t_, TIOCINQ, &size), SyscallSucceeds()); + if (size == sizeof(buf)) { + break; + } + usleep(10000); + } + + struct msghdr msg = {}; + std::vector<char> control(CMSG_SPACE(sizeof(int))); + size = sizeof(buf); + struct iovec iov; + for (int i = 0; size != 0; i += kChunk) { + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + iov.iov_base = buf; + iov.iov_len = kChunk; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + ASSERT_THAT(RetryEINTR(recvmsg)(t_, &msg, 0), + SyscallSucceedsWithValue(kChunk)); + size -= kChunk; + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(cmsg->cmsg_level, SOL_TCP); + ASSERT_EQ(cmsg->cmsg_type, TCP_INQ); + + int inq = 0; + memcpy(&inq, CMSG_DATA(cmsg), sizeof(int)); + ASSERT_EQ(inq, size); + } +} + +TEST_P(TcpSocketTest, TcpSCMPriority) { + char buf[1024]; + ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + int val = 1; + EXPECT_THAT(setsockopt(t_, SOL_TCP, TCP_INQ, &val, sizeof(val)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_TIMESTAMP, &val, sizeof(val)), + SyscallSucceedsWithValue(0)); + + struct msghdr msg = {}; + std::vector<char> control( + CMSG_SPACE(sizeof(struct timeval) + CMSG_SPACE(sizeof(int)))); + struct iovec iov; + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + ASSERT_THAT(RetryEINTR(recvmsg)(t_, &msg, 0), + 
SyscallSucceedsWithValue(sizeof(buf))); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + // TODO(b/78348848): SO_TIMESTAMP isn't implemented for TCP sockets. + if (!IsRunningOnGvisor() || cmsg->cmsg_level == SOL_SOCKET) { + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); + + cmsg = CMSG_NXTHDR(&msg, cmsg); + ASSERT_NE(cmsg, nullptr); + } + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(cmsg->cmsg_level, SOL_TCP); + ASSERT_EQ(cmsg->cmsg_type, TCP_INQ); + + int inq = 0; + memcpy(&inq, CMSG_DATA(cmsg), sizeof(int)); + ASSERT_EQ(inq, 0); + + cmsg = CMSG_NXTHDR(&msg, cmsg); + ASSERT_EQ(cmsg, nullptr); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3