diff options
author | Ian Lewis <ianlewis@google.com> | 2020-01-08 16:35:43 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-01-08 17:24:05 -0800 |
commit | fbb2c008e26a7e9d860f6cbf796ea7c375858502 (patch) | |
tree | 698e17f595ea9a689579cf009011171c4bbdcf9a /pkg/sentry/socket | |
parent | 565b64148314018e1234196182b55c4f01772e77 (diff) |
Return correct length with MSG_TRUNC for unix sockets.
This change calls a new Truncate method on the EndpointReader in RecvMsg for
both netlink and unix sockets. This allows readers such as sockets to peek at
the length of data without actually reading it to a buffer.
Fixes #993 #1240
PiperOrigin-RevId: 288800167
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/netlink/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 29 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/io.go | 13 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 23 |
4 files changed, 48 insertions, 18 deletions
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 79589e3c8..136821963 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 4a1b87a9a..d2e3644a6 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -500,29 +499,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have trunc := flags&linux.MSG_TRUNC != 0 r := unix.EndpointReader{ + Ctx: t, Endpoint: s.ep, Peek: flags&linux.MSG_PEEK != 0, } + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + // If MSG_TRUNC is set with a zero byte destination then we still need // to read the message and discard it, or in the case where MSG_PEEK is // set, leave it be. In both cases the full message length must be - // returned. However, the memory manager for the destination will not read - // the endpoint if the destination is zero length. - // - // In order for the endpoint to be read when the destination size is zero, - // we must cause a read of the endpoint by using a separate fake zero - // length block sequence and calling the EndpointReader directly. + // returned. if trunc && dst.Addrs.NumBytes() == 0 { - // Perform a read to a zero byte block sequence. We can ignore the - // original destination since it was zero bytes. The length returned by - // ReadToBlocks is ignored and we return the full message length to comply - // with MSG_TRUNC. - _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) - return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } } - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC @@ -540,7 +539,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 2ec1a662d..2447f24ef 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -83,6 +83,19 @@ type EndpointReader struct { ControlTrunc bool } +// Truncate calls RecvMsg on the endpoint without writing to a destination. +func (r *EndpointReader) Truncate() error { + // Ignore bytes read since it will always be zero. + _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From) + r.Control = c + r.ControlTrunc = ct + r.MsgSize = ms + if err != nil { + return err.ToError() + } + return nil +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 885758054..91effe89a 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -544,8 +544,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if senderRequested { r.From = &tcpip.FullAddress{} } + + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. + if trunc && dst.Addrs.NumBytes() == 0 { + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } + + } + var total int64 - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -580,7 +599,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { |