diff options
Diffstat (limited to 'pkg/sentry/socket/control/control.go')
-rw-r--r-- | pkg/sentry/socket/control/control.go | 63 |
1 files changed, 28 insertions, 35 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6077b2150..4b036b323 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -17,6 +17,8 @@ package control import ( + "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" @@ -29,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "time" ) // SCMCredentials represents a SCM_CREDENTIALS socket control message. @@ -63,10 +64,10 @@ type RightsFiles []*fs.File // NewSCMRights creates a new SCM_RIGHTS socket control message representation // using local sentry FDs. -func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { +func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) { files := make(RightsFiles, 0, len(fds)) for _, fd := range fds { - file := t.GetFile(fd) + file := t.GetFile(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF @@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) { var ( cmsgs socket.ControlMessages - fds linux.ControlMessageRights + fds []primitive.Int32 ) - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { return cmsgs, linuxerr.EINVAL } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, linuxerr.EINVAL } - if h.Length > uint64(len(buf)-i) { - return socket.ControlMessages{}, linuxerr.EINVAL - } - i += linux.SizeOfControlMessageHeader length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { + return socket.ControlMessages{}, linuxerr.EINVAL + } switch h.Level { case linux.SOL_SOCKET: @@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) - } - - i += bits.AlignUp(length, width) + curFDs := make([]primitive.Int32, numRights) + primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize]) + fds = append(fds, curFDs...) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + creds.UnmarshalUnsafe(buf) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, linuxerr.EINVAL } var ts linux.Timeval - ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(buf) cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true - i += bits.AlignUp(length, width) default: // Unknown message type. @@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + tos.UnmarshalUnsafe(buf) cmsgs.IP.TOS = uint8(tos) - i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) - + packetInfo.UnmarshalUnsafe(buf) cmsgs.IP.PacketInfo = packetInfo - i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + tclass.UnmarshalUnsafe(buf) cmsgs.IP.TClass = uint32(tclass) - i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) default: return socket.ControlMessages{}, linuxerr.EINVAL } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] + } } if cmsgs.Unix.Credentials == nil { |