diff options
Diffstat (limited to 'pkg/sentry/fs/host/socket_unsafe.go')
-rw-r--r-- | pkg/sentry/fs/host/socket_unsafe.go | 64 |
1 files changed, 40 insertions, 24 deletions
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index bf8da6867..5e4c5feed 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -19,29 +19,23 @@ import ( "unsafe" ) -// buildIovec builds an iovec slice from the given []byte slice. -func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) { - var length uintptr - iovecs := make([]syscall.Iovec, 0, 10) - for i := range bufs { - if l := len(bufs[i]); l > 0 { - length += uintptr(l) - iovecs = append(iovecs, syscall.Iovec{ - Base: &bufs[i][0], - Len: uint64(l), - }) - } - } - return length, iovecs -} - -func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { +// fdReadVec receives from fd to bufs. +// +// If the total length of bufs is > maxlen, fdReadVec will do a partial read +// and err will indicate why the message was truncated. +func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC) if peek { flags |= syscall.MSG_PEEK } - length, iovecs := buildIovec(bufs) + // Always truncate the receive buffer. All socket types will truncate + // received messages. + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, 0, 0, err + } var msg syscall.Msghdr if len(control) != 0 { @@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt msg.Iov = &iovecs[0] msg.Iovlen = uint64(len(iovecs)) } + n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) if e != 0 { + // N.B. prioritize the syscall error over the buildIovec error. return 0, 0, 0, e } + // Copy data back to bufs. + if intermediate != nil { + copyToMulti(bufs, intermediate) + } + if n > length { - return length, n, msg.Controllen, nil + return length, n, msg.Controllen, err } - return n, n, msg.Controllen, nil + return n, n, msg.Controllen, err } -func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) { - _, iovecs := buildIovec(bufs) +// fdWriteVec sends from bufs to fd. +// +// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a +// partial write and err will indicate why the message was truncated. +func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) { + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, length, err + } + + // Copy data to intermediate buf. + if intermediate != nil { + copyFromMulti(intermediate, bufs) + } var msg syscall.Msghdr if len(iovecs) > 0 { msg.Iov = &iovecs[0] msg.Iovlen = uint64(len(iovecs)) } + n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL) if e != 0 { - return 0, e + // N.B. prioritize the syscall error over the buildIovec error. + return 0, length, e } - return n, nil + return n, length, err } |