diff options
author | Michael Pratt <mpratt@google.com> | 2018-10-10 14:09:24 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-10-10 14:10:17 -0700 |
commit | ddb34b3690c07f6c8efe2b96f89166145c4a7d3c (patch) | |
tree | 781361c955c356d26b484f572bc4ad41a250ab72 /pkg/sentry/fs/host/socket_unsafe.go | |
parent | b78552d30e0af4122710e01bc86cbde6bb412686 (diff) |
Enforce message size limits and avoid host calls with too many iovecs
Currently, in the face of FileMem fragmentation and a large sendmsg or
recvmsg call, host sockets may pass > 1024 iovecs to the host, which
will immediately cause the host to return EMSGSIZE.
When we detect this case, use a single intermediate buffer to pass to
the kernel, copying to/from the src/dst buffer.
To avoid creating unbounded intermediate buffers, enforce message size
checks and truncation w.r.t. the send buffer size. The same
functionality is added to netstack unix sockets for feature parity.
PiperOrigin-RevId: 216590198
Change-Id: I719a32e71c7b1098d5097f35e6daf7dd5190eff7
Diffstat (limited to 'pkg/sentry/fs/host/socket_unsafe.go')
-rw-r--r-- | pkg/sentry/fs/host/socket_unsafe.go | 64 |
1 files changed, 40 insertions, 24 deletions
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index bf8da6867..5e4c5feed 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -19,29 +19,23 @@ import ( "unsafe" ) -// buildIovec builds an iovec slice from the given []byte slice. -func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) { - var length uintptr - iovecs := make([]syscall.Iovec, 0, 10) - for i := range bufs { - if l := len(bufs[i]); l > 0 { - length += uintptr(l) - iovecs = append(iovecs, syscall.Iovec{ - Base: &bufs[i][0], - Len: uint64(l), - }) - } - } - return length, iovecs -} - -func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { +// fdReadVec receives from fd to bufs. +// +// If the total length of bufs is > maxlen, fdReadVec will do a partial read +// and err will indicate why the message was truncated. +func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC) if peek { flags |= syscall.MSG_PEEK } - length, iovecs := buildIovec(bufs) + // Always truncate the receive buffer. All socket types will truncate + // received messages. + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, 0, 0, err + } var msg syscall.Msghdr if len(control) != 0 { @@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt msg.Iov = &iovecs[0] msg.Iovlen = uint64(len(iovecs)) } + n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) if e != 0 { + // N.B. prioritize the syscall error over the buildIovec error. return 0, 0, 0, e } + // Copy data back to bufs. + if intermediate != nil { + copyToMulti(bufs, intermediate) + } + if n > length { - return length, n, msg.Controllen, nil + return length, n, msg.Controllen, err } - return n, n, msg.Controllen, nil + return n, n, msg.Controllen, err } -func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) { - _, iovecs := buildIovec(bufs) +// fdWriteVec sends from bufs to fd. +// +// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a +// partial write and err will indicate why the message was truncated. +func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) { + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, length, err + } + + // Copy data to intermediate buf. + if intermediate != nil { + copyFromMulti(intermediate, bufs) + } var msg syscall.Msghdr if len(iovecs) > 0 { msg.Iov = &iovecs[0] msg.Iovlen = uint64(len(iovecs)) } + n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL) if e != 0 { - return 0, e + // N.B. prioritize the syscall error over the buildIovec error. + return 0, length, e } - return n, nil + return n, length, err } |