summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs/host/socket_unsafe.go
diff options
context:
space:
mode:
authorMichael Pratt <mpratt@google.com>2018-10-10 14:09:24 -0700
committerShentubot <shentubot@google.com>2018-10-10 14:10:17 -0700
commitddb34b3690c07f6c8efe2b96f89166145c4a7d3c (patch)
tree781361c955c356d26b484f572bc4ad41a250ab72 /pkg/sentry/fs/host/socket_unsafe.go
parentb78552d30e0af4122710e01bc86cbde6bb412686 (diff)
Enforce message size limits and avoid host calls with too many iovecs
Currently, in the face of FileMem fragmentation and a large sendmsg or recvmsg call, host sockets may pass > 1024 iovecs to the host, which will immediately cause the host to return EMSGSIZE. When we detect this case, use a single intermediate buffer to pass to the kernel, copying to/from the src/dst buffer. To avoid creating unbounded intermediate buffers, enforce message size checks and truncation w.r.t. the send buffer size. The same functionality is added to netstack unix sockets for feature parity. PiperOrigin-RevId: 216590198 Change-Id: I719a32e71c7b1098d5097f35e6daf7dd5190eff7
Diffstat (limited to 'pkg/sentry/fs/host/socket_unsafe.go')
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go64
1 files changed, 40 insertions, 24 deletions
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index bf8da6867..5e4c5feed 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,29 +19,23 @@ import (
"unsafe"
)
-// buildIovec builds an iovec slice from the given []byte slice.
-func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) {
- var length uintptr
- iovecs := make([]syscall.Iovec, 0, 10)
- for i := range bufs {
- if l := len(bufs[i]); l > 0 {
- length += uintptr(l)
- iovecs = append(iovecs, syscall.Iovec{
- Base: &bufs[i][0],
- Len: uint64(l),
- })
- }
- }
- return length, iovecs
-}
-
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
}
- length, iovecs := buildIovec(bufs)
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, err
+ }
var msg syscall.Msghdr
if len(control) != 0 {
@@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
return 0, 0, 0, e
}
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
if n > length {
- return length, n, msg.Controllen, nil
+ return length, n, msg.Controllen, err
}
- return n, n, msg.Controllen, nil
+ return n, n, msg.Controllen, err
}
-func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) {
- _, iovecs := buildIovec(bufs)
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
var msg syscall.Msghdr
if len(iovecs) > 0 {
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
if e != 0 {
- return 0, e
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
}
- return n, nil
+ return n, length, err
}