summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go61
1 files changed, 53 insertions, 8 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 69eff7373..e57aed927 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -207,6 +207,10 @@ type commonEndpoint interface {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -249,6 +253,10 @@ type SocketOperations struct {
// timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
+
+ // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket
+ // level, because it takes into account data from readView.
+ sockOptInq bool
}
// New creates a new endpoint socket.
@@ -634,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
}
return val, nil
}
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
}
@@ -1048,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
return nil
}
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
@@ -1267,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
return syserr.TranslateNetstackError(err)
}
return nil
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1492,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_INQ,
linux.TCP_KEEPCNT,
linux.TCP_KEEPIDLE,
linux.TCP_KEEPINTVL,
@@ -1747,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
return 0, err
}
+func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
+ if !s.sockOptInq {
+ return
+ }
+ rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return
+ }
+ cmsg.IP.HasInq = true
+ cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -1766,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
s.readMu.Unlock()
- return n, 0, nil, 0, socket.ControlMessages{}, err
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, 0, nil, 0, cmsg, err
}
s.readMu.Lock()
@@ -1779,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if !isPacket && peek && trunc {
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// amount that could be read.
- var rql tcpip.ReceiveQueueSizeOption
- if err := s.Endpoint.GetSockOpt(&rql); err != nil {
+ rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := len(s.readView) + int(rql)
@@ -1848,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
n = msgLen
}
- return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
@@ -2086,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCINQ:
- var v tcpip.ReceiveQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {