summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/socket.go145
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go113
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go64
-rw-r--r--pkg/sentry/socket/unix/unix.go17
-rw-r--r--pkg/syserr/netstack.go2
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/link/rawfile/errors.go2
-rw-r--r--pkg/tcpip/tcpip.go2
-rw-r--r--pkg/tcpip/transport/queue/queue.go69
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/unix/connectionless.go6
-rw-r--r--pkg/tcpip/transport/unix/unix.go49
14 files changed, 381 insertions, 98 deletions
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index c34f1c26b..6d5640f0a 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -15,6 +15,7 @@ go_library(
"inode_state.go",
"ioctl_unsafe.go",
"socket.go",
+ "socket_iovec.go",
"socket_state.go",
"socket_unsafe.go",
"tty.go",
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index e11772946..68ebf6402 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -19,6 +19,7 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/refs"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
@@ -33,6 +34,11 @@ import (
"gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier"
)
+// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
+//
+// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
+const maxSendBufferSize = 8 << 20
+
// endpoint encapsulates the state needed to represent a host Unix socket.
//
// TODO: Remove/merge with ConnectedEndpoint.
@@ -41,15 +47,17 @@ import (
type endpoint struct {
queue waiter.Queue `state:"zerovalue"`
- // stype is the type of Unix socket. (Ex: unix.SockStream,
- // unix.SockSeqpacket, unix.SockDgram)
- stype unix.SockType `state:"nosave"`
-
// fd is the host fd backing this file.
fd int `state:"nosave"`
// If srfd >= 0, it is the host fd that fd was imported from.
srfd int `state:"wait"`
+
+ // stype is the type of Unix socket.
+ stype unix.SockType `state:"nosave"`
+
+ // sndbuf is the size of the send buffer.
+ sndbuf int `state:"nosave"`
}
func (e *endpoint) init() error {
@@ -67,12 +75,21 @@ func (e *endpoint) init() error {
if err != nil {
return err
}
+ e.stype = unix.SockType(stype)
+
+ e.sndbuf, err = syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return err
+ }
+ if e.sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", e.sndbuf)
+ return syserror.EINVAL
+ }
if err := syscall.SetNonblock(e.fd, true); err != nil {
return err
}
- e.stype = unix.SockType(stype)
return fdnotifier.AddFD(int32(e.fd), &e.queue)
}
@@ -189,13 +206,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 0
return nil
case *tcpip.SendBufferSizeOption:
- v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
- *o = tcpip.SendBufferSizeOption(v)
- return translateError(err)
+ *o = tcpip.SendBufferSizeOption(e.sndbuf)
+ return nil
case *tcpip.ReceiveBufferSizeOption:
- v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF)
- *o = tcpip.ReceiveBufferSizeOption(v)
- return translateError(err)
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ *o = tcpip.ReceiveBufferSizeOption(e.sndbuf)
+ return nil
case *tcpip.ReuseAddressOption:
v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
*o = tcpip.ReuseAddressOption(v)
@@ -240,33 +257,47 @@ func (e *endpoint) SendMsg(data [][]byte, controlMessages unix.ControlMessages,
if to != nil {
return 0, tcpip.ErrInvalidEndpointState
}
- return sendMsg(e.fd, data, controlMessages)
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := e.stype == unix.SockStream
+
+ return sendMsg(e.fd, data, controlMessages, e.sndbuf, truncate)
}
-func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages) (uintptr, *tcpip.Error) {
+func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages, maxlen int, truncate bool) (uintptr, *tcpip.Error) {
if !controlMessages.Empty() {
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := fdWriteVec(fd, data)
+ n, totalLen, err := fdWriteVec(fd, data, maxlen, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
return n, translateError(err)
}
// RecvMsg implements unix.Endpoint.RecvMsg.
func (e *endpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
- return recvMsg(e.fd, data, numRights, peek, addr)
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cm, err := recvMsg(e.fd, data, numRights, peek, addr, e.sndbuf)
+ if rl > 0 && err == tcpip.ErrWouldBlock {
+ // Message did not fill buffer; that's fine, no need to block.
+ err = nil
+ }
+ return rl, ml, cm, err
}
-func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
+func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress, maxlen int) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
var cm unet.ControlMessage
if numRights > 0 {
cm.EnableFDs(int(numRights))
}
- rl, ml, cl, err := fdReadVec(fd, data, []byte(cm), peek)
- if err == syscall.EAGAIN {
- return 0, 0, unix.ControlMessages{}, tcpip.ErrWouldBlock
- }
- if err != nil {
- return 0, 0, unix.ControlMessages{}, translateError(err)
+ rl, ml, cl, rerr := fdReadVec(fd, data, []byte(cm), peek, maxlen)
+ if rl == 0 && rerr != nil {
+ return 0, 0, unix.ControlMessages{}, translateError(rerr)
}
// Trim the control data if we received less than the full amount.
@@ -276,7 +307,7 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
// Avoid extra allocations in the case where there isn't any control data.
if len(cm) == 0 {
- return rl, ml, unix.ControlMessages{}, nil
+ return rl, ml, unix.ControlMessages{}, translateError(rerr)
}
fds, err := cm.ExtractFDs()
@@ -285,9 +316,9 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu
}
if len(fds) == 0 {
- return rl, ml, unix.ControlMessages{}, nil
+ return rl, ml, unix.ControlMessages{}, translateError(rerr)
}
- return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), translateError(rerr)
}
// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD
@@ -307,7 +338,27 @@ func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*Conne
return nil, tcpip.ErrInvalidEndpointState
}
- e := &ConnectedEndpoint{path: path, queue: queue, file: file}
+ stype, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return nil, translateError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return nil, translateError(err)
+ }
+ if sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", sndbuf)
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ e := &ConnectedEndpoint{
+ path: path,
+ queue: queue,
+ file: file,
+ stype: unix.SockType(stype),
+ sndbuf: sndbuf,
+ }
// AtomicRefCounters start off with a single reference. We need two.
e.ref.IncRef()
@@ -346,6 +397,17 @@ type ConnectedEndpoint struct {
// writeClosed is true if the FD has write shutdown or if it has been
// closed.
writeClosed bool
+
+ // stype is the type of Unix socket.
+ stype unix.SockType
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int
}
// Send implements unix.ConnectedEndpoint.Send.
@@ -355,7 +417,12 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess
if c.writeClosed {
return 0, false, tcpip.ErrClosedForSend
}
- n, err := sendMsg(c.file.FD(), data, controlMessages)
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == unix.SockStream
+
+ n, err := sendMsg(c.file.FD(), data, controlMessages, c.sndbuf, truncate)
// There is no need for the callee to call SendNotify because sendMsg uses
// the host's sendmsg(2) and the host kernel's queue.
return n, false, err
@@ -411,7 +478,15 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
if c.readClosed {
return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive
}
- rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil)
+
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil, c.sndbuf)
+ if rl > 0 && err == tcpip.ErrWouldBlock {
+ // Message did not fill buffer; that's fine, no need to block.
+ err = nil
+ }
+
// There is no need for the callee to call RecvNotify because recvMsg uses
// the host's recvmsg(2) and the host kernel's queue.
return rl, ml, cm, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, err
@@ -460,20 +535,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
- if err != nil {
- return -1
- }
- return int64(v)
+ return int64(c.sndbuf)
}
// RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
- v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
- if err != nil {
- return -1
- }
- return int64(v)
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
}
// Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release.
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
new file mode 100644
index 000000000..1a9587b90
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += uintptr(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > uintptr(maxlen) {
+ if truncate {
+ stopLen = uintptr(maxlen)
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total uintptr
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := l
+ if total+uintptr(stop) > stopLen {
+ stop = int(stopLen - total)
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += uintptr(stop)
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index bf8da6867..5e4c5feed 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,29 +19,23 @@ import (
"unsafe"
)
-// buildIovec builds an iovec slice from the given []byte slice.
-func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) {
- var length uintptr
- iovecs := make([]syscall.Iovec, 0, 10)
- for i := range bufs {
- if l := len(bufs[i]); l > 0 {
- length += uintptr(l)
- iovecs = append(iovecs, syscall.Iovec{
- Base: &bufs[i][0],
- Len: uint64(l),
- })
- }
- }
- return length, iovecs
-}
-
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
}
- length, iovecs := buildIovec(bufs)
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, err
+ }
var msg syscall.Msghdr
if len(control) != 0 {
@@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
return 0, 0, 0, e
}
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
if n > length {
- return length, n, msg.Controllen, nil
+ return length, n, msg.Controllen, err
}
- return n, n, msg.Controllen, nil
+ return n, n, msg.Controllen, err
}
-func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) {
- _, iovecs := buildIovec(bufs)
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
var msg syscall.Msghdr
if len(iovecs) > 0 {
msg.Iov = &iovecs[0]
msg.Iovlen = uint64(len(iovecs))
}
+
n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
if e != 0 {
- return 0, e
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
}
- return n, nil
+ return n, length, err
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 1c22e78b3..e30378e60 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -378,7 +378,8 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
w.To = ep
}
- if n, err := src.CopyInTo(t, &w); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ n, err := src.CopyInTo(t, &w)
+ if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return int(n), syserr.FromError(err)
}
@@ -388,15 +389,23 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
s.EventRegister(&e, waiter.EventOut)
defer s.EventUnregister(&e)
+ total := n
for {
- if n, err := src.CopyInTo(t, &w); err != syserror.ErrWouldBlock {
- return int(n), syserr.FromError(err)
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ n, err = src.CopyInTo(t, &w)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
}
if err := t.Block(ch); err != nil {
- return 0, syserr.FromError(err)
+ break
}
}
+
+ return int(total), syserr.FromError(err)
}
// Passcred implements unix.Credentialer.Passcred.
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index c40fb7dbf..b9786b48f 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -78,6 +78,8 @@ var netstackErrorTranslations = map[*tcpip.Error]*Error{
tcpip.ErrNoLinkAddress: ErrHostDown,
tcpip.ErrBadAddress: ErrBadAddress,
tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong: ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
}
// TranslateNetstackError converts an error from the tcpip package to a sentry
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 6f8a7a319..5bc74e65e 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -44,6 +44,7 @@ var (
ELIBBAD = error(syscall.ELIBBAD)
ELOOP = error(syscall.ELOOP)
EMFILE = error(syscall.EMFILE)
+ EMSGSIZE = error(syscall.EMSGSIZE)
ENAMETOOLONG = error(syscall.ENAMETOOLONG)
ENOATTR = ENODATA
ENODATA = error(syscall.ENODATA)
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index 7f213793e..de7593d9c 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -41,6 +41,8 @@ var translations = map[syscall.Errno]*tcpip.Error{
syscall.ENOTCONN: tcpip.ErrNotConnected,
syscall.ECONNRESET: tcpip.ErrConnectionReset,
syscall.ECONNABORTED: tcpip.ErrConnectionAborted,
+ syscall.EMSGSIZE: tcpip.ErrMessageTooLong,
+ syscall.ENOBUFS: tcpip.ErrNoBufferSpace,
}
// TranslateErrno translate an errno from the syscall package into a
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f5b5ec86b..cef27948c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -98,6 +98,8 @@ var (
ErrNoLinkAddress = &Error{msg: "no remote link address"}
ErrBadAddress = &Error{msg: "bad address"}
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
+ ErrMessageTooLong = &Error{msg: "message too long"}
+ ErrNoBufferSpace = &Error{msg: "no buffer space available"}
)
// Errors related to Subnet
diff --git a/pkg/tcpip/transport/queue/queue.go b/pkg/tcpip/transport/queue/queue.go
index eb9ee8a3f..b3d2ea68b 100644
--- a/pkg/tcpip/transport/queue/queue.go
+++ b/pkg/tcpip/transport/queue/queue.go
@@ -24,12 +24,23 @@ import (
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
-// Entry implements Linker interface and has both Length and Release methods.
+// Entry implements Linker interface and has additional required methods.
type Entry interface {
ilist.Linker
+
+ // Length returns the number of bytes stored in the entry.
Length() int64
+
+ // Release releases any resources held by the entry.
Release()
+
+ // Peek returns a copy of the entry. It must be Released separately.
Peek() Entry
+
+ // Truncate reduces the number of bytes stored in the entry to n bytes.
+ //
+ // Preconditions: n <= Length().
+ Truncate(n int64)
}
// Queue is a buffer queue.
@@ -52,7 +63,7 @@ func New(ReaderQueue *waiter.Queue, WriterQueue *waiter.Queue, limit int64) *Que
}
// Close closes q for reading and writing. It is immediately not writable and
-// will become unreadble will no more data is pending.
+// will become unreadable when no more data is pending.
//
// Both the read and write queues must be notified after closing:
// q.ReaderQueue.Notify(waiter.EventIn)
@@ -86,38 +97,74 @@ func (q *Queue) IsReadable() bool {
return q.closed || q.dataList.Front() != nil
}
+// bufWritable returns true if there is space for writing.
+//
+// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is
+// free.
+//
+// See net/unix/af_unix.c:unix_writeable.
+func (q *Queue) bufWritable() bool {
+ return 4*q.used < q.limit
+}
+
// IsWritable determines if q is currently writable.
func (q *Queue) IsWritable() bool {
q.mu.Lock()
defer q.mu.Unlock()
- return q.closed || q.used < q.limit
+ return q.closed || q.bufWritable()
}
// Enqueue adds an entry to the data queue if room is available.
//
+// If truncate is true, Enqueue may truncate the message beforing enqueuing it.
+// Otherwise, the entire message must fit. If n < e.Length(), err indicates why.
+//
// If notify is true, ReaderQueue.Notify must be called:
// q.ReaderQueue.Notify(waiter.EventIn)
-func (q *Queue) Enqueue(e Entry) (notify bool, err *tcpip.Error) {
+func (q *Queue) Enqueue(e Entry, truncate bool) (l int64, notify bool, err *tcpip.Error) {
q.mu.Lock()
if q.closed {
q.mu.Unlock()
- return false, tcpip.ErrClosedForSend
+ return 0, false, tcpip.ErrClosedForSend
+ }
+
+ free := q.limit - q.used
+
+ l = e.Length()
+
+ if l > free && truncate {
+ if free == 0 {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, tcpip.ErrWouldBlock
+ }
+
+ e.Truncate(free)
+ l = e.Length()
+ err = tcpip.ErrWouldBlock
+ }
+
+ if l > q.limit {
+ // Message is too big to ever fit.
+ q.mu.Unlock()
+ return 0, false, tcpip.ErrMessageTooLong
}
- if q.used >= q.limit {
+ if l > free {
+ // Message can't fit right now.
q.mu.Unlock()
- return false, tcpip.ErrWouldBlock
+ return 0, false, tcpip.ErrWouldBlock
}
notify = q.dataList.Front() == nil
- q.used += e.Length()
+ q.used += l
q.dataList.PushBack(e)
q.mu.Unlock()
- return notify, nil
+ return l, notify, err
}
// Dequeue removes the first entry in the data queue, if one exists.
@@ -137,13 +184,13 @@ func (q *Queue) Dequeue() (e Entry, notify bool, err *tcpip.Error) {
return nil, false, err
}
- notify = q.used >= q.limit
+ notify = !q.bufWritable()
e = q.dataList.Front().(Entry)
q.dataList.Remove(e)
q.used -= e.Length()
- notify = notify && q.used < q.limit
+ notify = notify && q.bufWritable()
q.mu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 6143390b3..bed7ec6a6 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -315,6 +315,8 @@ func loadError(s string) *tcpip.Error {
tcpip.ErrNoLinkAddress,
tcpip.ErrBadAddress,
tcpip.ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace,
}
messageToError = make(map[string]*tcpip.Error)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6ed805357..840e95302 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,7 @@
package udp
import (
+ "math"
"sync"
"gvisor.googlesource.com/gvisor/pkg/sleep"
@@ -264,6 +265,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, tcpip.ErrInvalidOptionValue
}
+ if p.Size() > math.MaxUint16 {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
to := opts.To
e.mu.RLock()
diff --git a/pkg/tcpip/transport/unix/connectionless.go b/pkg/tcpip/transport/unix/connectionless.go
index ebd4802b0..ae93c61d7 100644
--- a/pkg/tcpip/transport/unix/connectionless.go
+++ b/pkg/tcpip/transport/unix/connectionless.go
@@ -105,14 +105,12 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo
e.Lock()
n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
e.Unlock()
- if err != nil {
- return 0, err
- }
+
if notify {
connected.SendNotify()
}
- return n, nil
+ return n, err
}
// Type implements Endpoint.Type.
diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go
index 0bb00df42..718606cd1 100644
--- a/pkg/tcpip/transport/unix/unix.go
+++ b/pkg/tcpip/transport/unix/unix.go
@@ -260,20 +260,28 @@ type message struct {
Address tcpip.FullAddress
}
-// Length returns number of bytes stored in the Message.
+// Length returns number of bytes stored in the message.
func (m *message) Length() int64 {
return int64(len(m.Data))
}
-// Release releases any resources held by the Message.
+// Release releases any resources held by the message.
func (m *message) Release() {
m.Control.Release()
}
+// Peek returns a copy of the message.
func (m *message) Peek() queue.Entry {
return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address}
}
+// Truncate reduces the length of the message payload to n bytes.
+//
+// Preconditions: n <= m.Length().
+func (m *message) Truncate(n int64) {
+ m.Data.CapLength(int(n))
+}
+
// A Receiver can be used to receive Messages.
type Receiver interface {
// Recv receives a single message. This method does not block.
@@ -623,23 +631,33 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// Send implements ConnectedEndpoint.Send.
func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) {
- var l int
+ var l int64
for _, d := range data {
- l += len(d)
- }
- // Discard empty stream packets. Since stream sockets don't preserve
- // message boundaries, sending zero bytes is a no-op. In Linux, the
- // receiver actually uses a zero-length receive as an indication that the
- // stream was closed.
- if l == 0 && e.endpoint.Type() == SockStream {
- controlMessages.Release()
- return 0, false, nil
+ l += int64(len(d))
+ }
+
+ truncate := false
+ if e.endpoint.Type() == SockStream {
+ // Since stream sockets don't preserve message boundaries, we
+ // can write only as much of the message as fits in the queue.
+ truncate = true
+
+ // Discard empty stream packets. Since stream sockets don't
+ // preserve message boundaries, sending zero bytes is a no-op.
+ // In Linux, the receiver actually uses a zero-length receive
+ // as an indication that the stream was closed.
+ if l == 0 {
+ controlMessages.Release()
+ return 0, false, nil
+ }
}
+
v := make([]byte, 0, l)
for _, d := range data {
v = append(v, d...)
}
- notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from})
+
+ l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate)
return uintptr(l), notify, err
}
@@ -793,15 +811,12 @@ func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoin
n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
e.Unlock()
- if err != nil {
- return 0, err
- }
if notify {
e.connected.SendNotify()
}
- return n, nil
+ return n, err
}
// SetSockOpt sets a socket option. Currently not supported.