summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs/host
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/fs/host')
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/socket.go145
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go113
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go64
4 files changed, 261 insertions, 62 deletions
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index c34f1c26b..6d5640f0a 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -15,6 +15,7 @@ go_library(
"inode_state.go",
"ioctl_unsafe.go",
"socket.go",
+ "socket_iovec.go",
"socket_state.go",
"socket_unsafe.go",
"tty.go",
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index e11772946..68ebf6402 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -19,6 +19,7 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/refs"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
@@ -33,6 +34,11 @@ import (
"gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier"
)
+// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
+//
+// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
+const maxSendBufferSize = 8 << 20
+
// endpoint encapsulates the state needed to represent a host Unix socket.
//
// TODO: Remove/merge with ConnectedEndpoint.
@@ -41,15 +47,17 @@ import (
type endpoint struct {
queue waiter.Queue `state:"zerovalue"`
- // stype is the type of Unix socket. (Ex: unix.SockStream,
- // unix.SockSeqpacket, unix.SockDgram)
- stype unix.SockType `state:"nosave"`
-
// fd is the host fd backing this file.
fd int `state:"nosave"`
// If srfd >= 0, it is the host fd that fd was imported from.
srfd int `state:"wait"`
+
+ // stype is the type of Unix socket.
+ stype unix.SockType `state:"nosave"`
+
+ // sndbuf is the size of the send buffer.
+ sndbuf int `state:"nosave"`
}
func (e *endpoint) init() error {
@@ -67,12 +75,21 @@ func (e *endpoint) init() error {
if err != nil {
return err
}
+ e.stype = unix.SockType(stype)
+
+ e.sndbuf, err = syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return err
+ }
+ if e.sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", e.sndbuf)
+ return syserror.EINVAL
+ }
if err := syscall.SetNonblock(e.fd, true); err != nil {
return err
}
- e.stype = unix.SockType(stype)
return fdnotifier.AddFD(int32(e.fd), &e.queue)
}
@@ -189,13 +206,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 0
return nil
case *tcpip.SendBufferSizeOption:
- v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
- *o = tcpip.SendBufferSizeOption(v)
- return translateError(err)
+ *o = tcpip.SendBufferSizeOption(e.sndbuf)
+ return nil
case *tcpip.ReceiveBufferSizeOption:
- v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF)
- *o = tcpip.ReceiveBufferSizeOption(v)
- return translateError(err)
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ *o = tcpip.ReceiveBufferSizeOption(e.sndbuf)
+ return nil
case *tcpip.ReuseAddressOption:
v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
*o = tcpip.ReuseAddressOption(v)
@@ -240,33 +257,47 @@ func (e *endpoint) SendMsg(data [][]byte, controlMessages unix.ControlMessages,
if to != nil {
return 0, tcpip.ErrInvalidEndpointState
}
- return sendMsg(e.fd, data, controlMessages)
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := e.stype == unix.SockStream
+
+ return sendMsg(e.fd, data, controlMessages, e.sndbuf, truncate)
}
-func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages) (uintptr, *tcpip.Error) {
+func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages, maxlen int, truncate bool) (uintptr, *tcpip.Error) {
if !controlMessages.Empty() {
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := fdWriteVec(fd, data)
+ n, totalLen, err := fdWriteVec(fd, data, maxlen, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
return n, translateError(err)
}
// RecvMsg implements unix.Endpoint.RecvMsg.
func (e *endpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
- return recvMsg(e.fd, data, numRights, peek, addr)
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cm, err := recvMsg(e.fd, data, numRights, peek, addr, e.sndbuf)
+ if rl > 0 && err == tcpip.ErrWouldBlock {
+ // Message did not fill buffer; that's fine, no need to block.
+ err = nil
+ }
+ return rl, ml, cm, err
}
-func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
+func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress, maxlen int) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
var cm unet.ControlMessage
if numRights > 0 {
cm.EnableFDs(int(numRights))
}
- rl, ml, cl, err := fdReadVec(fd, data, []byte(cm), peek)
- if err == syscall.EAGAIN {
- return 0, 0, unix.ControlMessages{}, tcpip.ErrWouldBlock
- }
- if err != nil {
- return 0, 0, unix.ControlMessages{}, translateError(err)
+ rl, ml, cl, rerr := fdReadVec(fd, data, []byte(cm), peek, maxlen)
+ if rl == 0 && rerr != nil {
+ return 0, 0, unix.ControlMessages{}, translateError(rerr)
}
// Trim the control data if we received less than the full amount.
@@ -276,7 +307,7 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
// Avoid extra allocations in the case where there isn't any control data.
if len(cm) == 0 {
- return rl, ml, unix.ControlMessages{}, nil
+ return rl, ml, unix.ControlMessages{}, translateError(rerr)
}
fds, err := cm.ExtractFDs()
@@ -285,9 +316,9 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
}
if len(fds) == 0 {
- return rl, ml, unix.ControlMessages{}, nil
+ return rl, ml, unix.ControlMessages{}, translateError(rerr)
}
- return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), translateError(rerr)
}
// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD
@@ -307,7 +338,27 @@ func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*Conne
return nil, tcpip.ErrInvalidEndpointState
}
- e := &ConnectedEndpoint{path: path, queue: queue, file: file}
+ stype, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return nil, translateError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return nil, translateError(err)
+ }
+ if sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", sndbuf)
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ e := &ConnectedEndpoint{
+ path: path,
+ queue: queue,
+ file: file,
+ stype: unix.SockType(stype),
+ sndbuf: sndbuf,
+ }
// AtomicRefCounters start off with a single reference. We need two.
e.ref.IncRef()
@@ -346,6 +397,17 @@ type ConnectedEndpoint struct {
// writeClosed is true if the FD has write shutdown or if it has been
// closed.
writeClosed bool
+
+ // stype is the type of Unix socket.
+ stype unix.SockType
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int
}
// Send implements unix.ConnectedEndpoint.Send.
@@ -355,7 +417,12 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess
if c.writeClosed {
return 0, false, tcpip.ErrClosedForSend
}
- n, err := sendMsg(c.file.FD(), data, controlMessages)
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == unix.SockStream
+
+ n, err := sendMsg(c.file.FD(), data, controlMessages, c.sndbuf, truncate)
// There is no need for the callee to call SendNotify because sendMsg uses
// the host's sendmsg(2) and the host kernel's queue.
return n, false, err
@@ -411,7 +478,15 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
if c.readClosed {
return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive
}
- rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil)
+
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil, c.sndbuf)
+ if rl > 0 && err == tcpip.ErrWouldBlock {
+ // Message did not fill buffer; that's fine, no need to block.
+ err = nil
+ }
+
// There is no need for the callee to call RecvNotify because recvMsg uses
// the host's recvmsg(2) and the host kernel's queue.
return rl, ml, cm, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, err
@@ -460,20 +535,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
- if err != nil {
- return -1
- }
- return int64(v)
+ return int64(c.sndbuf)
}
// RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
- v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
- if err != nil {
- return -1
- }
- return int64(v)
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
}
// Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release.
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
new file mode 100644
index 000000000..1a9587b90
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += uintptr(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > uintptr(maxlen) {
+ if truncate {
+ stopLen = uintptr(maxlen)
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total uintptr
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := l
+ if total+uintptr(stop) > stopLen {
+ stop = int(stopLen - total)
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += uintptr(stop)
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index bf8da6867..5e4c5feed 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,29 +19,23 @@ import (
"unsafe"
)
-// buildIovec builds an iovec slice from the given []byte slice.
-func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) {
- var length uintptr
- iovecs := make([]syscall.Iovec, 0, 10)
- for i := range bufs {
- if l := len(bufs[i]); l > 0 {
- length += uintptr(l)
- iovecs = append(iovecs, syscall.Iovec{
- Base: &bufs[i][0],
- Len: uint64(l),
- })
- }
- }
- return length, iovecs
-}
-
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
}
- length, iovecs := buildIovec(bufs)
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, err
+ }
var msg syscall.Msghdr
if len(control) != 0 {
@@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
return 0, 0, 0, e
}
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
if n > length {
- return length, n, msg.Controllen, nil
+ return length, n, msg.Controllen, err
}
- return n, n, msg.Controllen, nil
+ return n, n, msg.Controllen, err
}
-func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) {
- _, iovecs := buildIovec(bufs)
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
var msg syscall.Msghdr
if len(iovecs) > 0 {
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
if e != 0 {
- return 0, e
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
}
- return n, nil
+ return n, length, err
}