summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs/host/socket_unsafe.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/fs/host/socket_unsafe.go')
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go64
1 files changed, 40 insertions, 24 deletions
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index bf8da6867..5e4c5feed 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,29 +19,23 @@ import (
"unsafe"
)
-// buildIovec builds an iovec slice from the given []byte slice.
-func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) {
- var length uintptr
- iovecs := make([]syscall.Iovec, 0, 10)
- for i := range bufs {
- if l := len(bufs[i]); l > 0 {
- length += uintptr(l)
- iovecs = append(iovecs, syscall.Iovec{
- Base: &bufs[i][0],
- Len: uint64(l),
- })
- }
- }
- return length, iovecs
-}
-
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
}
- length, iovecs := buildIovec(bufs)
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, err
+ }
var msg syscall.Msghdr
if len(control) != 0 {
@@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
return 0, 0, 0, e
}
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
if n > length {
- return length, n, msg.Controllen, nil
+ return length, n, msg.Controllen, err
}
- return n, n, msg.Controllen, nil
+ return n, n, msg.Controllen, err
}
-func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) {
- _, iovecs := buildIovec(bufs)
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
var msg syscall.Msghdr
if len(iovecs) > 0 {
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
if e != 0 {
- return 0, e
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
}
- return n, nil
+ return n, length, err
}