From ab2c2575d61266725ce13dff570663464a171342 Mon Sep 17 00:00:00 2001 From: Brian Geffon Date: Mon, 11 Jun 2018 16:39:39 -0700 Subject: Rpcinet is incorrectly handling MSG_TRUNC with SOCK_STREAM SOCK_STREAM has special behavior with respect to MSG_TRUNC. Specifically, the data isn't actually copied back out to userspace when MSG_TRUNC is provided on a SOCK_STREAM. According to tcp(7): "Since version 2.4, Linux supports the use of MSG_TRUNC in the flags argument of recv(2) (and recvmsg(2)). This flag causes the received bytes of data to be discarded, rather than passed back in a caller-supplied buffer." PiperOrigin-RevId: 200134860 Change-Id: I70f17a5f60ffe7794c3f0cfafd131c069202e90d --- pkg/sentry/socket/rpcinet/socket.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index ffe947500..6f1a4fe01 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -465,9 +465,13 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags res, err := rpcRecvMsg(t, req) if err == nil { - n, e := dst.CopyOut(t, res.Data) - if e == nil && n != len(res.Data) { - panic("CopyOut failed to copy full buffer") + var e error + var n int + if len(res.Data) > 0 { + n, e = dst.CopyOut(t, res.Data) + if e == nil && n != len(res.Data) { + panic("CopyOut failed to copy full buffer") + } } return int(res.Length), res.Address.GetAddress(), res.Address.GetLength(), socket.ControlMessages{}, syserr.FromError(e) } @@ -484,9 +488,13 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags for { res, err := rpcRecvMsg(t, req) if err == nil { - n, e := dst.CopyOut(t, res.Data) - if e == nil && n != len(res.Data) { - panic("CopyOut failed to copy full buffer") + var e error + var n int + if len(res.Data) > 0 { + n, e = dst.CopyOut(t, res.Data) + if e == nil && n != len(res.Data) { + panic("CopyOut failed to copy full buffer") + } } return int(res.Length), res.Address.GetAddress(), res.Address.GetLength(), socket.ControlMessages{}, syserr.FromError(e) } -- cgit v1.2.3