diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 61 |
1 files changed, 53 insertions, 8 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 69eff7373..e57aed927 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -207,6 +207,10 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and + // transport.Endpoint.GetSockOpt. + GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack @@ -249,6 +253,10 @@ type SocketOperations struct { // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. timestampNS int64 + + // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket + // level, because it takes into account data from readView. + sockOptInq bool } // New creates a new endpoint socket. @@ -634,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( } return val, nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptInq { + val = 1 + } + return val, nil + } return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) } @@ -1048,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 return nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } @@ -1267,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.TranslateNetstackError(err) } return nil + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1492,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_INQ, linux.TCP_KEEPCNT, linux.TCP_KEEPIDLE, linux.TCP_KEEPINTVL, @@ -1747,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } +func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { + if !s.sockOptInq { + return + } + rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { + return + } + cmsg.IP.HasInq = true + cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -1766,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) s.readMu.Unlock() - return n, 0, nil, 0, socket.ControlMessages{}, err + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, 0, nil, 0, cmsg, err } s.readMu.Lock() @@ -1779,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if !isPacket && peek && trunc { // MSG_TRUNC with MSG_PEEK on a TCP socket returns the // amount that could be read. - var rql tcpip.ReceiveQueueSizeOption - if err := s.Endpoint.GetSockOpt(&rql); err != nil { + rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } available := len(s.readView) + int(rql) @@ -1848,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe n = msgLen } - return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err) + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } func (s *SocketOperations) controlMessages() socket.ControlMessages { @@ -2086,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCINQ: - var v tcpip.ReceiveQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { |