summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry
diff options
context:
space:
mode:
authorMichael Pratt <mpratt@google.com>2018-10-10 14:09:24 -0700
committerShentubot <shentubot@google.com>2018-10-10 14:10:17 -0700
commitddb34b3690c07f6c8efe2b96f89166145c4a7d3c (patch)
tree781361c955c356d26b484f572bc4ad41a250ab72 /pkg/sentry
parentb78552d30e0af4122710e01bc86cbde6bb412686 (diff)
Enforce message size limits and avoid host calls with too many iovecs
Currently, in the face of FileMem fragmentation and a large sendmsg or recvmsg call, host sockets may pass > 1024 iovecs to the host, which will immediately cause the host to return EMSGSIZE. When we detect this case, use a single intermediate buffer to pass to the kernel, copying to/from the src/dst buffer. To avoid creating unbounded intermediate buffers, enforce message size checks and truncation w.r.t. the send buffer size. The same functionality is added to netstack unix sockets for feature parity. PiperOrigin-RevId: 216590198 Change-Id: I719a32e71c7b1098d5097f35e6daf7dd5190eff7
Diffstat (limited to 'pkg/sentry')
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/socket.go145
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go113
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go64
-rw-r--r--pkg/sentry/socket/unix/unix.go17
5 files changed, 274 insertions, 66 deletions
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index c34f1c26b..6d5640f0a 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -15,6 +15,7 @@ go_library(
"inode_state.go",
"ioctl_unsafe.go",
"socket.go",
+ "socket_iovec.go",
"socket_state.go",
"socket_unsafe.go",
"tty.go",
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index e11772946..68ebf6402 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -19,6 +19,7 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/refs"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
@@ -33,6 +34,11 @@ import (
"gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier"
)
+// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
+//
+// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
+const maxSendBufferSize = 8 << 20
+
// endpoint encapsulates the state needed to represent a host Unix socket.
//
// TODO: Remove/merge with ConnectedEndpoint.
@@ -41,15 +47,17 @@ import (
type endpoint struct {
queue waiter.Queue `state:"zerovalue"`
- // stype is the type of Unix socket. (Ex: unix.SockStream,
- // unix.SockSeqpacket, unix.SockDgram)
- stype unix.SockType `state:"nosave"`
-
// fd is the host fd backing this file.
fd int `state:"nosave"`
// If srfd >= 0, it is the host fd that fd was imported from.
srfd int `state:"wait"`
+
+ // stype is the type of Unix socket.
+ stype unix.SockType `state:"nosave"`
+
+ // sndbuf is the size of the send buffer.
+ sndbuf int `state:"nosave"`
}
func (e *endpoint) init() error {
@@ -67,12 +75,21 @@ func (e *endpoint) init() error {
if err != nil {
return err
}
+ e.stype = unix.SockType(stype)
+
+ e.sndbuf, err = syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return err
+ }
+ if e.sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", e.sndbuf)
+ return syserror.EINVAL
+ }
if err := syscall.SetNonblock(e.fd, true); err != nil {
return err
}
- e.stype = unix.SockType(stype)
return fdnotifier.AddFD(int32(e.fd), &e.queue)
}
@@ -189,13 +206,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 0
return nil
case *tcpip.SendBufferSizeOption:
- v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
- *o = tcpip.SendBufferSizeOption(v)
- return translateError(err)
+ *o = tcpip.SendBufferSizeOption(e.sndbuf)
+ return nil
case *tcpip.ReceiveBufferSizeOption:
- v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF)
- *o = tcpip.ReceiveBufferSizeOption(v)
- return translateError(err)
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ *o = tcpip.ReceiveBufferSizeOption(e.sndbuf)
+ return nil
case *tcpip.ReuseAddressOption:
v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
*o = tcpip.ReuseAddressOption(v)
@@ -240,33 +257,47 @@ func (e *endpoint) SendMsg(data [][]byte, controlMessages unix.ControlMessages,
if to != nil {
return 0, tcpip.ErrInvalidEndpointState
}
- return sendMsg(e.fd, data, controlMessages)
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := e.stype == unix.SockStream
+
+ return sendMsg(e.fd, data, controlMessages, e.sndbuf, truncate)
}
-func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages) (uintptr, *tcpip.Error) {
+func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages, maxlen int, truncate bool) (uintptr, *tcpip.Error) {
if !controlMessages.Empty() {
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := fdWriteVec(fd, data)
+ n, totalLen, err := fdWriteVec(fd, data, maxlen, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
return n, translateError(err)
}
// RecvMsg implements unix.Endpoint.RecvMsg.
func (e *endpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
- return recvMsg(e.fd, data, numRights, peek, addr)
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cm, err := recvMsg(e.fd, data, numRights, peek, addr, e.sndbuf)
+ if rl > 0 && err == tcpip.ErrWouldBlock {
+ // Message did not fill buffer; that's fine, no need to block.
+ err = nil
+ }
+ return rl, ml, cm, err
}
-func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
+func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress, maxlen int) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
var cm unet.ControlMessage
if numRights > 0 {
cm.EnableFDs(int(numRights))
}
- rl, ml, cl, err := fdReadVec(fd, data, []byte(cm), peek)
- if err == syscall.EAGAIN {
- return 0, 0, unix.ControlMessages{}, tcpip.ErrWouldBlock
- }
- if err != nil {
- return 0, 0, unix.ControlMessages{}, translateError(err)
+ rl, ml, cl, rerr := fdReadVec(fd, data, []byte(cm), peek, maxlen)
+ if rl == 0 && rerr != nil {
+ return 0, 0, unix.ControlMessages{}, translateError(rerr)
}
// Trim the control data if we received less than the full amount.
@@ -276,7 +307,7 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
// Avoid extra allocations in the case where there isn't any control data.
if len(cm) == 0 {
- return rl, ml, unix.ControlMessages{}, nil
+ return rl, ml, unix.ControlMessages{}, translateError(rerr)
}
fds, err := cm.ExtractFDs()
@@ -285,9 +316,9 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
}
if len(fds) == 0 {
- return rl, ml, unix.ControlMessages{}, nil
+ return rl, ml, unix.ControlMessages{}, translateError(rerr)
}
- return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), translateError(rerr)
}
// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD
@@ -307,7 +338,27 @@ func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*Conne
return nil, tcpip.ErrInvalidEndpointState
}
- e := &ConnectedEndpoint{path: path, queue: queue, file: file}
+ stype, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return nil, translateError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return nil, translateError(err)
+ }
+ if sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", sndbuf)
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ e := &ConnectedEndpoint{
+ path: path,
+ queue: queue,
+ file: file,
+ stype: unix.SockType(stype),
+ sndbuf: sndbuf,
+ }
// AtomicRefCounters start off with a single reference. We need two.
e.ref.IncRef()
@@ -346,6 +397,17 @@ type ConnectedEndpoint struct {
// writeClosed is true if the FD has write shutdown or if it has been
// closed.
writeClosed bool
+
+ // stype is the type of Unix socket.
+ stype unix.SockType
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int
}
// Send implements unix.ConnectedEndpoint.Send.
@@ -355,7 +417,12 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess
if c.writeClosed {
return 0, false, tcpip.ErrClosedForSend
}
- n, err := sendMsg(c.file.FD(), data, controlMessages)
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == unix.SockStream
+
+ n, err := sendMsg(c.file.FD(), data, controlMessages, c.sndbuf, truncate)
// There is no need for the callee to call SendNotify because sendMsg uses
// the host's sendmsg(2) and the host kernel's queue.
return n, false, err
@@ -411,7 +478,15 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
if c.readClosed {
return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive
}
- rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil)
+
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil, c.sndbuf)
+ if rl > 0 && err == tcpip.ErrWouldBlock {
+ // Message did not fill buffer; that's fine, no need to block.
+ err = nil
+ }
+
// There is no need for the callee to call RecvNotify because recvMsg uses
// the host's recvmsg(2) and the host kernel's queue.
return rl, ml, cm, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, err
@@ -460,20 +535,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
- if err != nil {
- return -1
- }
- return int64(v)
+ return int64(c.sndbuf)
}
// RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
- v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
- if err != nil {
- return -1
- }
- return int64(v)
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
}
// Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release.
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
new file mode 100644
index 000000000..1a9587b90
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += uintptr(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > uintptr(maxlen) {
+ if truncate {
+ stopLen = uintptr(maxlen)
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total uintptr
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := l
+ if total+uintptr(stop) > stopLen {
+ stop = int(stopLen - total)
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += uintptr(stop)
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index bf8da6867..5e4c5feed 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,29 +19,23 @@ import (
"unsafe"
)
-// buildIovec builds an iovec slice from the given []byte slice.
-func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) {
- var length uintptr
- iovecs := make([]syscall.Iovec, 0, 10)
- for i := range bufs {
- if l := len(bufs[i]); l > 0 {
- length += uintptr(l)
- iovecs = append(iovecs, syscall.Iovec{
- Base: &bufs[i][0],
- Len: uint64(l),
- })
- }
- }
- return length, iovecs
-}
-
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
}
- length, iovecs := buildIovec(bufs)
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, err
+ }
var msg syscall.Msghdr
if len(control) != 0 {
@@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
return 0, 0, 0, e
}
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
if n > length {
- return length, n, msg.Controllen, nil
+ return length, n, msg.Controllen, err
}
- return n, n, msg.Controllen, nil
+ return n, n, msg.Controllen, err
}
-func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) {
- _, iovecs := buildIovec(bufs)
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
var msg syscall.Msghdr
if len(iovecs) > 0 {
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
if e != 0 {
- return 0, e
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
}
- return n, nil
+ return n, length, err
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 1c22e78b3..e30378e60 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -378,7 +378,8 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
w.To = ep
}
- if n, err := src.CopyInTo(t, &w); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ n, err := src.CopyInTo(t, &w)
+ if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return int(n), syserr.FromError(err)
}
@@ -388,15 +389,23 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
s.EventRegister(&e, waiter.EventOut)
defer s.EventUnregister(&e)
+ total := n
for {
- if n, err := src.CopyInTo(t, &w); err != syserror.ErrWouldBlock {
- return int(n), syserr.FromError(err)
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ n, err = src.CopyInTo(t, &w)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
}
if err := t.Block(ch); err != nil {
- return 0, syserr.FromError(err)
+ break
}
}
+
+ return int(total), syserr.FromError(err)
}
// Passcred implements unix.Credentialer.Passcred.