summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go16
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go61
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go38
3 files changed, 91 insertions, 24 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4f4a20dfe..4e95101b7 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -225,14 +225,14 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([
return alignSlice(buf, align), flags
}
-func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte {
+func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte {
if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
return buf
}
ob := buf
buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader))
- buf = putUint32(buf, linux.SOL_SOCKET)
+ buf = putUint32(buf, msgLevel)
buf = putUint32(buf, msgType)
hdrBuf := buf
@@ -307,12 +307,24 @@ func alignSlice(buf []byte, align uint) []byte {
func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
return putCmsgStruct(
buf,
+ linux.SOL_SOCKET,
linux.SO_TIMESTAMP,
t.Arch().Width(),
linux.NsecToTimeval(timestamp),
)
}
+// PackInq packs a TCP_INQ socket control message.
+func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_TCP,
+ linux.TCP_INQ,
+ 4,
+ inq,
+ )
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
var (
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 69eff7373..e57aed927 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -207,6 +207,10 @@ type commonEndpoint interface {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -249,6 +253,10 @@ type SocketOperations struct {
// timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
+
+ // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket
+ // level, because it takes into account data from readView.
+ sockOptInq bool
}
// New creates a new endpoint socket.
@@ -634,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
}
return val, nil
}
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
}
@@ -1048,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
return nil
}
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
@@ -1267,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
return syserr.TranslateNetstackError(err)
}
return nil
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1492,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_INQ,
linux.TCP_KEEPCNT,
linux.TCP_KEEPIDLE,
linux.TCP_KEEPINTVL,
@@ -1747,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
return 0, err
}
+func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
+ if !s.sockOptInq {
+ return
+ }
+ rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return
+ }
+ cmsg.IP.HasInq = true
+ cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -1766,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
s.readMu.Unlock()
- return n, 0, nil, 0, socket.ControlMessages{}, err
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, 0, nil, 0, cmsg, err
}
s.readMu.Lock()
@@ -1779,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if !isPacket && peek && trunc {
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// amount that could be read.
- var rql tcpip.ReceiveQueueSizeOption
- if err := s.Endpoint.GetSockOpt(&rql); err != nil {
+ rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := len(s.readView) + int(rql)
@@ -1848,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
n = msgLen
}
- return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
@@ -2086,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCINQ:
- var v tcpip.ReceiveQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index b0765ba55..7fb9cb1e0 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -179,6 +179,10 @@ type Endpoint interface {
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOptInt gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+
// State returns the current state of the socket, as represented by Linux in
// procfs.
State() uint32
@@ -834,33 +838,39 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- case *tcpip.SendQueueSizeOption:
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v = int(e.receiver.RecvQueuedSize())
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
+ return v, nil
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
return nil
- case *tcpip.ReceiveQueueSizeOption:
+ case *tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
return tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize())
+ qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
e.Unlock()
if qs < 0 {
return tcpip.ErrQueueSizeNotSupported