diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket/epsocket.go')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 108 |
1 files changed, 87 insertions, 21 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 8e65e1b3f..e57aed927 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -40,7 +40,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -208,6 +207,10 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and + // transport.Endpoint.GetSockOpt. + GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack @@ -250,6 +253,10 @@ type SocketOperations struct { // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. timestampNS int64 + + // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket + // level, because it takes into account data from readView. + sockOptInq bool } // New creates a new endpoint socket. @@ -286,14 +293,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // GetAddress reads an sockaddr struct from the given address and converts it // to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 // addresses. -func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { +func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) { + if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -318,7 +325,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -331,7 +338,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -344,6 +351,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { } return out, nil + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, nil + default: return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -466,7 +476,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, false /* strict */) if err != nil { return err } @@ -499,7 +509,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -537,7 +547,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -575,10 +585,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - fdFlags := kernel.FDFlags{ + fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, - } - fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) + }) t.Kernel().RecordSocket(ns) @@ -633,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( } return val, nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptInq { + val = 1 + } + return val, nil + } return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) } @@ -866,6 +887,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(v), nil + case linux.TCP_MAXSEG: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.MaxSegOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1035,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 return nil } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } @@ -1218,6 +1260,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + case linux.TCP_MAXSEG: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -1246,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.TranslateNetstackError(err) } return nil + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1471,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_INQ, linux.TCP_KEEPCNT, linux.TCP_KEEPIDLE, linux.TCP_KEEPINTVL, @@ -1726,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } +func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { + if !s.sockOptInq { + return + } + rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { + return + } + cmsg.IP.HasInq = true + cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -1745,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) s.readMu.Unlock() - return n, 0, nil, 0, socket.ControlMessages{}, err + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, 0, nil, 0, cmsg, err } s.readMu.Lock() @@ -1758,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if !isPacket && peek && trunc { // MSG_TRUNC with MSG_PEEK on a TCP socket returns the // amount that could be read. - var rql tcpip.ReceiveQueueSizeOption - if err := s.Endpoint.GetSockOpt(&rql); err != nil { + rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } available := len(s.readView) + int(rql) @@ -1827,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe n = msgLen } - return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err) + cmsg := s.controlMessages() + s.fillCmsgInq(&cmsg) + return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } func (s *SocketOperations) controlMessages() socket.ControlMessages { @@ -1924,7 +1990,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to) + addrBuf, err := GetAddress(s.family, to, true /* strict */) if err != nil { return 0, err } @@ -1993,7 +2059,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } // Ioctl implements fs.FileOperations.Ioctl. -func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. @@ -2065,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCINQ: - var v tcpip.ReceiveQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { |