From eefa817cfdb04ff07e7069396f21bd6ba2c89957 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Thu, 18 Jul 2019 15:39:47 -0700 Subject: net/tcp/setockopt: impelment setsockopt(fd, SOL_TCP, TCP_INQ) PiperOrigin-RevId: 258859507 --- pkg/sentry/socket/control/control.go | 16 +++++++-- pkg/sentry/socket/epsocket/epsocket.go | 61 +++++++++++++++++++++++++++----- pkg/sentry/socket/unix/transport/unix.go | 38 ++++++++++++-------- 3 files changed, 91 insertions(+), 24 deletions(-) (limited to 'pkg/sentry/socket') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4f4a20dfe..4e95101b7 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -225,14 +225,14 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ return alignSlice(buf, align), flags } -func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte { +func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte { if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { return buf } ob := buf buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader)) - buf = putUint32(buf, linux.SOL_SOCKET) + buf = putUint32(buf, msgLevel) buf = putUint32(buf, msgType) hdrBuf := buf @@ -307,12 +307,24 @@ func alignSlice(buf []byte, align uint) []byte { func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { return putCmsgStruct( buf, + linux.SOL_SOCKET, linux.SO_TIMESTAMP, t.Arch().Width(), linux.NsecToTimeval(timestamp), ) } +// PackInq packs a TCP_INQ socket control message. +func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_TCP, + linux.TCP_INQ, + 4, + inq, + ) +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { var ( diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 69eff7373..e57aed927 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -207,6 +207,10 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and + // transport.Endpoint.GetSockOpt. + GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack @@ -249,6 +253,10 @@ type SocketOperations struct { // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. timestampNS int64 + + // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket + // level, because it takes into account data from readView. + sockOptInq bool } // New creates a new endpoint socket. @@ -634,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( } return val, nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptInq { + val = 1 + } + return val, nil + } return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) } @@ -1048,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 return nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } @@ -1267,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.TranslateNetstackError(err) } return nil + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1492,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_INQ, linux.TCP_KEEPCNT, linux.TCP_KEEPIDLE, linux.TCP_KEEPINTVL, @@ -1747,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } +func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { + if !s.sockOptInq { + return + } + rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { + return + } + cmsg.IP.HasInq = true + cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -1766,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) s.readMu.Unlock() - return n, 0, nil, 0, socket.ControlMessages{}, err + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, 0, nil, 0, cmsg, err } s.readMu.Lock() @@ -1779,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if !isPacket && peek && trunc { // MSG_TRUNC with MSG_PEEK on a TCP socket returns the // amount that could be read. - var rql tcpip.ReceiveQueueSizeOption - if err := s.Endpoint.GetSockOpt(&rql); err != nil { + rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } available := len(s.readView) + int(rql) @@ -1848,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe n = msgLen } - return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err) + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } func (s *SocketOperations) controlMessages() socket.ControlMessages { @@ -2086,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCINQ: - var v tcpip.ReceiveQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index b0765ba55..7fb9cb1e0 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -179,6 +179,10 @@ type Endpoint interface { // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptInt gets a socket option for simple cases when a return + // value has the int type. + GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 @@ -834,33 +838,39 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.SendQueueSizeOption: +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v = int(e.receiver.RecvQueuedSize()) e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs + return v, nil + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: return nil - case *tcpip.ReceiveQueueSizeOption: + case *tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() return tcpip.ErrNotConnected } - qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize()) + qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) e.Unlock() if qs < 0 { return tcpip.ErrQueueSizeNotSupported -- cgit v1.2.3