diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/fs/host/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 145 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket_iovec.go | 113 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket_unsafe.go | 64 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 17 | ||||
-rw-r--r-- | pkg/syserr/netstack.go | 2 | ||||
-rw-r--r-- | pkg/syserror/syserror.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/errors.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/queue/queue.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/unix/connectionless.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/unix/unix.go | 49 |
14 files changed, 381 insertions, 98 deletions
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index c34f1c26b..6d5640f0a 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -15,6 +15,7 @@ go_library( "inode_state.go", "ioctl_unsafe.go", "socket.go", + "socket_iovec.go", "socket_state.go", "socket_unsafe.go", "tty.go", diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index e11772946..68ebf6402 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -19,6 +19,7 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/fd" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/refs" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -33,6 +34,11 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier" ) +// maxSendBufferSize is the maximum host send buffer size allowed for endpoint. +// +// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max). +const maxSendBufferSize = 8 << 20 + // endpoint encapsulates the state needed to represent a host Unix socket. // // TODO: Remove/merge with ConnectedEndpoint. @@ -41,15 +47,17 @@ import ( type endpoint struct { queue waiter.Queue `state:"zerovalue"` - // stype is the type of Unix socket. (Ex: unix.SockStream, - // unix.SockSeqpacket, unix.SockDgram) - stype unix.SockType `state:"nosave"` - // fd is the host fd backing this file. fd int `state:"nosave"` // If srfd >= 0, it is the host fd that fd was imported from. srfd int `state:"wait"` + + // stype is the type of Unix socket. + stype unix.SockType `state:"nosave"` + + // sndbuf is the size of the send buffer. + sndbuf int `state:"nosave"` } func (e *endpoint) init() error { @@ -67,12 +75,21 @@ func (e *endpoint) init() error { if err != nil { return err } + e.stype = unix.SockType(stype) + + e.sndbuf, err = syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + return err + } + if e.sndbuf > maxSendBufferSize { + log.Warningf("Socket send buffer too large: %d", e.sndbuf) + return syserror.EINVAL + } if err := syscall.SetNonblock(e.fd, true); err != nil { return err } - e.stype = unix.SockType(stype) return fdnotifier.AddFD(int32(e.fd), &e.queue) } @@ -189,13 +206,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 0 return nil case *tcpip.SendBufferSizeOption: - v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) - *o = tcpip.SendBufferSizeOption(v) - return translateError(err) + *o = tcpip.SendBufferSizeOption(e.sndbuf) + return nil case *tcpip.ReceiveBufferSizeOption: - v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF) - *o = tcpip.ReceiveBufferSizeOption(v) - return translateError(err) + // N.B. Unix sockets don't use the receive buffer. We'll claim it is + // the same size as the send buffer. + *o = tcpip.ReceiveBufferSizeOption(e.sndbuf) + return nil case *tcpip.ReuseAddressOption: v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR) *o = tcpip.ReuseAddressOption(v) @@ -240,33 +257,47 @@ func (e *endpoint) SendMsg(data [][]byte, controlMessages unix.ControlMessages, if to != nil { return 0, tcpip.ErrInvalidEndpointState } - return sendMsg(e.fd, data, controlMessages) + + // Since stream sockets don't preserve message boundaries, we can write + // only as much of the message as fits in the send buffer. + truncate := e.stype == unix.SockStream + + return sendMsg(e.fd, data, controlMessages, e.sndbuf, truncate) } -func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages) (uintptr, *tcpip.Error) { +func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages, maxlen int, truncate bool) (uintptr, *tcpip.Error) { if !controlMessages.Empty() { return 0, tcpip.ErrInvalidEndpointState } - n, err := fdWriteVec(fd, data) + n, totalLen, err := fdWriteVec(fd, data, maxlen, truncate) + if n < totalLen && err == nil { + // The host only returns a short write if it would otherwise + // block (and only for stream sockets). + err = syserror.EAGAIN + } return n, translateError(err) } // RecvMsg implements unix.Endpoint.RecvMsg. func (e *endpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) { - return recvMsg(e.fd, data, numRights, peek, addr) + // N.B. Unix sockets don't have a receive buffer, the send buffer + // serves both purposes. + rl, ml, cm, err := recvMsg(e.fd, data, numRights, peek, addr, e.sndbuf) + if rl > 0 && err == tcpip.ErrWouldBlock { + // Message did not fill buffer; that's fine, no need to block. + err = nil + } + return rl, ml, cm, err } -func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) { +func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress, maxlen int) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) { var cm unet.ControlMessage if numRights > 0 { cm.EnableFDs(int(numRights)) } - rl, ml, cl, err := fdReadVec(fd, data, []byte(cm), peek) - if err == syscall.EAGAIN { - return 0, 0, unix.ControlMessages{}, tcpip.ErrWouldBlock - } - if err != nil { - return 0, 0, unix.ControlMessages{}, translateError(err) + rl, ml, cl, rerr := fdReadVec(fd, data, []byte(cm), peek, maxlen) + if rl == 0 && rerr != nil { + return 0, 0, unix.ControlMessages{}, translateError(rerr) } // Trim the control data if we received less than the full amount. @@ -276,7 +307,7 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu // Avoid extra allocations in the case where there isn't any control data. if len(cm) == 0 { - return rl, ml, unix.ControlMessages{}, nil + return rl, ml, unix.ControlMessages{}, translateError(rerr) } fds, err := cm.ExtractFDs() @@ -285,9 +316,9 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu } if len(fds) == 0 { - return rl, ml, unix.ControlMessages{}, nil + return rl, ml, unix.ControlMessages{}, translateError(rerr) } - return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil + return rl, ml, control.New(nil, nil, newSCMRights(fds)), translateError(rerr) } // NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD @@ -307,7 +338,27 @@ func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*Conne return nil, tcpip.ErrInvalidEndpointState } - e := &ConnectedEndpoint{path: path, queue: queue, file: file} + stype, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE) + if err != nil { + return nil, translateError(err) + } + + sndbuf, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + return nil, translateError(err) + } + if sndbuf > maxSendBufferSize { + log.Warningf("Socket send buffer too large: %d", sndbuf) + return nil, tcpip.ErrInvalidEndpointState + } + + e := &ConnectedEndpoint{ + path: path, + queue: queue, + file: file, + stype: unix.SockType(stype), + sndbuf: sndbuf, + } // AtomicRefCounters start off with a single reference. We need two. e.ref.IncRef() @@ -346,6 +397,17 @@ type ConnectedEndpoint struct { // writeClosed is true if the FD has write shutdown or if it has been // closed. writeClosed bool + + // stype is the type of Unix socket. + stype unix.SockType + + // sndbuf is the size of the send buffer. + // + // N.B. When this is smaller than the host size, we present it via + // GetSockOpt and message splitting/rejection in SendMsg, but do not + // prevent lots of small messages from filling the real send buffer + // size on the host. + sndbuf int } // Send implements unix.ConnectedEndpoint.Send. @@ -355,7 +417,12 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMess if c.writeClosed { return 0, false, tcpip.ErrClosedForSend } - n, err := sendMsg(c.file.FD(), data, controlMessages) + + // Since stream sockets don't preserve message boundaries, we can write + // only as much of the message as fits in the send buffer. + truncate := c.stype == unix.SockStream + + n, err := sendMsg(c.file.FD(), data, controlMessages, c.sndbuf, truncate) // There is no need for the callee to call SendNotify because sendMsg uses // the host's sendmsg(2) and the host kernel's queue. return n, false, err @@ -411,7 +478,15 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p if c.readClosed { return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive } - rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil) + + // N.B. Unix sockets don't have a receive buffer, the send buffer + // serves both purposes. + rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil, c.sndbuf) + if rl > 0 && err == tcpip.ErrWouldBlock { + // Message did not fill buffer; that's fine, no need to block. + err = nil + } + // There is no need for the callee to call RecvNotify because recvMsg uses // the host's recvmsg(2) and the host kernel's queue. return rl, ml, cm, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, err @@ -460,20 +535,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize. func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { - v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - return -1 - } - return int64(v) + return int64(c.sndbuf) } // RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize. func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { - v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF) - if err != nil { - return -1 - } - return int64(v) + // N.B. Unix sockets don't use the receive buffer. We'll claim it is + // the same size as the send buffer. + return int64(c.sndbuf) } // Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release. diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go new file mode 100644 index 000000000..1a9587b90 --- /dev/null +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -0,0 +1,113 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/syserror" +) + +// maxIovs is the maximum number of iovecs to pass to the host. +var maxIovs = linux.UIO_MAXIOV + +// copyToMulti copies as many bytes from src to dst as possible. +func copyToMulti(dst [][]byte, src []byte) { + for _, d := range dst { + done := copy(d, src) + src = src[done:] + if len(src) == 0 { + break + } + } +} + +// copyFromMulti copies as many bytes from src to dst as possible. +func copyFromMulti(dst []byte, src [][]byte) { + for _, s := range src { + done := copy(dst, s) + dst = dst[done:] + if len(dst) == 0 { + break + } + } +} + +// buildIovec builds an iovec slice from the given []byte slice. +// +// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error. +// +// If length < the total length of bufs, err indicates why, even when returning +// a truncated iovec. +// +// If intermediate != nil, iovecs references intermediate rather than bufs and +// the caller must copy to/from bufs as necessary. +func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) { + var iovsRequired int + for _, b := range bufs { + length += uintptr(len(b)) + if len(b) > 0 { + iovsRequired++ + } + } + + stopLen := length + if length > uintptr(maxlen) { + if truncate { + stopLen = uintptr(maxlen) + err = syserror.EAGAIN + } else { + return 0, nil, nil, syserror.EMSGSIZE + } + } + + if iovsRequired > maxIovs { + // The kernel will reject our call if we pass this many iovs. + // Use a single intermediate buffer instead. + b := make([]byte, stopLen) + + return stopLen, []syscall.Iovec{{ + Base: &b[0], + Len: uint64(stopLen), + }}, b, err + } + + var total uintptr + iovecs = make([]syscall.Iovec, 0, iovsRequired) + for i := range bufs { + l := len(bufs[i]) + if l == 0 { + continue + } + + stop := l + if total+uintptr(stop) > stopLen { + stop = int(stopLen - total) + } + + iovecs = append(iovecs, syscall.Iovec{ + Base: &bufs[i][0], + Len: uint64(stop), + }) + + total += uintptr(stop) + if total >= stopLen { + break + } + } + + return total, iovecs, nil, err +} diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index bf8da6867..5e4c5feed 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -19,29 +19,23 @@ import ( "unsafe" ) -// buildIovec builds an iovec slice from the given []byte slice. -func buildIovec(bufs [][]byte) (uintptr, []syscall.Iovec) { - var length uintptr - iovecs := make([]syscall.Iovec, 0, 10) - for i := range bufs { - if l := len(bufs[i]); l > 0 { - length += uintptr(l) - iovecs = append(iovecs, syscall.Iovec{ - Base: &bufs[i][0], - Len: uint64(l), - }) - } - } - return length, iovecs -} - -func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { +// fdReadVec receives from fd to bufs. +// +// If the total length of bufs is > maxlen, fdReadVec will do a partial read +// and err will indicate why the message was truncated. +func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC) if peek { flags |= syscall.MSG_PEEK } - length, iovecs := buildIovec(bufs) + // Always truncate the receive buffer. All socket types will truncate + // received messages. + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, 0, 0, err + } var msg syscall.Msghdr if len(control) != 0 { @@ -53,30 +47,52 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool) (readLen uintpt msg.Iov = &iovecs[0] msg.Iovlen = uint64(len(iovecs)) } + n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) if e != 0 { + // N.B. prioritize the syscall error over the buildIovec error. return 0, 0, 0, e } + // Copy data back to bufs. + if intermediate != nil { + copyToMulti(bufs, intermediate) + } + if n > length { - return length, n, msg.Controllen, nil + return length, n, msg.Controllen, err } - return n, n, msg.Controllen, nil + return n, n, msg.Controllen, err } -func fdWriteVec(fd int, bufs [][]byte) (uintptr, error) { - _, iovecs := buildIovec(bufs) +// fdWriteVec sends from bufs to fd. +// +// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a +// partial write and err will indicate why the message was truncated. +func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) { + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, length, err + } + + // Copy data to intermediate buf. + if intermediate != nil { + copyFromMulti(intermediate, bufs) + } var msg syscall.Msghdr if len(iovecs) > 0 { msg.Iov = &iovecs[0] msg.Iovlen = uint64(len(iovecs)) } + n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL) if e != 0 { - return 0, e + // N.B. prioritize the syscall error over the buildIovec error. + return 0, length, e } - return n, nil + return n, length, err } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1c22e78b3..e30378e60 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -378,7 +378,8 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] w.To = ep } - if n, err := src.CopyInTo(t, &w); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + n, err := src.CopyInTo(t, &w) + if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { return int(n), syserr.FromError(err) } @@ -388,15 +389,23 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] s.EventRegister(&e, waiter.EventOut) defer s.EventUnregister(&e) + total := n for { - if n, err := src.CopyInTo(t, &w); err != syserror.ErrWouldBlock { - return int(n), syserr.FromError(err) + // Shorten src to reflect bytes previously written. + src = src.DropFirst64(n) + + n, err = src.CopyInTo(t, &w) + total += n + if err != syserror.ErrWouldBlock { + break } if err := t.Block(ch); err != nil { - return 0, syserr.FromError(err) + break } } + + return int(total), syserr.FromError(err) } // Passcred implements unix.Credentialer.Passcred. diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index c40fb7dbf..b9786b48f 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -78,6 +78,8 @@ var netstackErrorTranslations = map[*tcpip.Error]*Error{ tcpip.ErrNoLinkAddress: ErrHostDown, tcpip.ErrBadAddress: ErrBadAddress, tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, + tcpip.ErrMessageTooLong: ErrMessageTooLong, + tcpip.ErrNoBufferSpace: ErrNoBufferSpace, } // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 6f8a7a319..5bc74e65e 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -44,6 +44,7 @@ var ( ELIBBAD = error(syscall.ELIBBAD) ELOOP = error(syscall.ELOOP) EMFILE = error(syscall.EMFILE) + EMSGSIZE = error(syscall.EMSGSIZE) ENAMETOOLONG = error(syscall.ENAMETOOLONG) ENOATTR = ENODATA ENODATA = error(syscall.ENODATA) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index 7f213793e..de7593d9c 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -41,6 +41,8 @@ var translations = map[syscall.Errno]*tcpip.Error{ syscall.ENOTCONN: tcpip.ErrNotConnected, syscall.ECONNRESET: tcpip.ErrConnectionReset, syscall.ECONNABORTED: tcpip.ErrConnectionAborted, + syscall.EMSGSIZE: tcpip.ErrMessageTooLong, + syscall.ENOBUFS: tcpip.ErrNoBufferSpace, } // TranslateErrno translate an errno from the syscall package into a diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f5b5ec86b..cef27948c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -98,6 +98,8 @@ var ( ErrNoLinkAddress = &Error{msg: "no remote link address"} ErrBadAddress = &Error{msg: "bad address"} ErrNetworkUnreachable = &Error{msg: "network is unreachable"} + ErrMessageTooLong = &Error{msg: "message too long"} + ErrNoBufferSpace = &Error{msg: "no buffer space available"} ) // Errors related to Subnet diff --git a/pkg/tcpip/transport/queue/queue.go b/pkg/tcpip/transport/queue/queue.go index eb9ee8a3f..b3d2ea68b 100644 --- a/pkg/tcpip/transport/queue/queue.go +++ b/pkg/tcpip/transport/queue/queue.go @@ -24,12 +24,23 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter" ) -// Entry implements Linker interface and has both Length and Release methods. +// Entry implements Linker interface and has additional required methods. type Entry interface { ilist.Linker + + // Length returns the number of bytes stored in the entry. Length() int64 + + // Release releases any resources held by the entry. Release() + + // Peek returns a copy of the entry. It must be Released separately. Peek() Entry + + // Truncate reduces the number of bytes stored in the entry to n bytes. + // + // Preconditions: n <= Length(). + Truncate(n int64) } // Queue is a buffer queue. @@ -52,7 +63,7 @@ func New(ReaderQueue *waiter.Queue, WriterQueue *waiter.Queue, limit int64) *Que } // Close closes q for reading and writing. It is immediately not writable and -// will become unreadble will no more data is pending. +// will become unreadable when no more data is pending. // // Both the read and write queues must be notified after closing: // q.ReaderQueue.Notify(waiter.EventIn) @@ -86,38 +97,74 @@ func (q *Queue) IsReadable() bool { return q.closed || q.dataList.Front() != nil } +// bufWritable returns true if there is space for writing. +// +// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is +// free. +// +// See net/unix/af_unix.c:unix_writeable. +func (q *Queue) bufWritable() bool { + return 4*q.used < q.limit +} + // IsWritable determines if q is currently writable. func (q *Queue) IsWritable() bool { q.mu.Lock() defer q.mu.Unlock() - return q.closed || q.used < q.limit + return q.closed || q.bufWritable() } // Enqueue adds an entry to the data queue if room is available. // +// If truncate is true, Enqueue may truncate the message beforing enqueuing it. +// Otherwise, the entire message must fit. If n < e.Length(), err indicates why. +// // If notify is true, ReaderQueue.Notify must be called: // q.ReaderQueue.Notify(waiter.EventIn) -func (q *Queue) Enqueue(e Entry) (notify bool, err *tcpip.Error) { +func (q *Queue) Enqueue(e Entry, truncate bool) (l int64, notify bool, err *tcpip.Error) { q.mu.Lock() if q.closed { q.mu.Unlock() - return false, tcpip.ErrClosedForSend + return 0, false, tcpip.ErrClosedForSend + } + + free := q.limit - q.used + + l = e.Length() + + if l > free && truncate { + if free == 0 { + // Message can't fit right now. + q.mu.Unlock() + return 0, false, tcpip.ErrWouldBlock + } + + e.Truncate(free) + l = e.Length() + err = tcpip.ErrWouldBlock + } + + if l > q.limit { + // Message is too big to ever fit. + q.mu.Unlock() + return 0, false, tcpip.ErrMessageTooLong } - if q.used >= q.limit { + if l > free { + // Message can't fit right now. q.mu.Unlock() - return false, tcpip.ErrWouldBlock + return 0, false, tcpip.ErrWouldBlock } notify = q.dataList.Front() == nil - q.used += e.Length() + q.used += l q.dataList.PushBack(e) q.mu.Unlock() - return notify, nil + return l, notify, err } // Dequeue removes the first entry in the data queue, if one exists. @@ -137,13 +184,13 @@ func (q *Queue) Dequeue() (e Entry, notify bool, err *tcpip.Error) { return nil, false, err } - notify = q.used >= q.limit + notify = !q.bufWritable() e = q.dataList.Front().(Entry) q.dataList.Remove(e) q.used -= e.Length() - notify = notify && q.used < q.limit + notify = notify && q.bufWritable() q.mu.Unlock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6143390b3..bed7ec6a6 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -315,6 +315,8 @@ func loadError(s string) *tcpip.Error { tcpip.ErrNoLinkAddress, tcpip.ErrBadAddress, tcpip.ErrNetworkUnreachable, + tcpip.ErrMessageTooLong, + tcpip.ErrNoBufferSpace, } messageToError = make(map[string]*tcpip.Error) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6ed805357..840e95302 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,6 +15,7 @@ package udp import ( + "math" "sync" "gvisor.googlesource.com/gvisor/pkg/sleep" @@ -264,6 +265,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, tcpip.ErrInvalidOptionValue } + if p.Size() > math.MaxUint16 { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } + to := opts.To e.mu.RLock() diff --git a/pkg/tcpip/transport/unix/connectionless.go b/pkg/tcpip/transport/unix/connectionless.go index ebd4802b0..ae93c61d7 100644 --- a/pkg/tcpip/transport/unix/connectionless.go +++ b/pkg/tcpip/transport/unix/connectionless.go @@ -105,14 +105,12 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo e.Lock() n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() - if err != nil { - return 0, err - } + if notify { connected.SendNotify() } - return n, nil + return n, err } // Type implements Endpoint.Type. diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go index 0bb00df42..718606cd1 100644 --- a/pkg/tcpip/transport/unix/unix.go +++ b/pkg/tcpip/transport/unix/unix.go @@ -260,20 +260,28 @@ type message struct { Address tcpip.FullAddress } -// Length returns number of bytes stored in the Message. +// Length returns number of bytes stored in the message. func (m *message) Length() int64 { return int64(len(m.Data)) } -// Release releases any resources held by the Message. +// Release releases any resources held by the message. func (m *message) Release() { m.Control.Release() } +// Peek returns a copy of the message. func (m *message) Peek() queue.Entry { return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address} } +// Truncate reduces the length of the message payload to n bytes. +// +// Preconditions: n <= m.Length(). +func (m *message) Truncate(n int64) { + m.Data.CapLength(int(n)) +} + // A Receiver can be used to receive Messages. type Receiver interface { // Recv receives a single message. This method does not block. @@ -623,23 +631,33 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // Send implements ConnectedEndpoint.Send. func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { - var l int + var l int64 for _, d := range data { - l += len(d) - } - // Discard empty stream packets. Since stream sockets don't preserve - // message boundaries, sending zero bytes is a no-op. In Linux, the - // receiver actually uses a zero-length receive as an indication that the - // stream was closed. - if l == 0 && e.endpoint.Type() == SockStream { - controlMessages.Release() - return 0, false, nil + l += int64(len(d)) + } + + truncate := false + if e.endpoint.Type() == SockStream { + // Since stream sockets don't preserve message boundaries, we + // can write only as much of the message as fits in the queue. + truncate = true + + // Discard empty stream packets. Since stream sockets don't + // preserve message boundaries, sending zero bytes is a no-op. + // In Linux, the receiver actually uses a zero-length receive + // as an indication that the stream was closed. + if l == 0 { + controlMessages.Release() + return 0, false, nil + } } + v := make([]byte, 0, l) for _, d := range data { v = append(v, d...) } - notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}) + + l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate) return uintptr(l), notify, err } @@ -793,15 +811,12 @@ func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoin n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() - if err != nil { - return 0, err - } if notify { e.connected.SendNotify() } - return n, nil + return n, err } // SetSockOpt sets a socket option. Currently not supported. |