summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go42
-rw-r--r--pkg/sentry/socket/socket.go5
-rw-r--r--pkg/sentry/socket/unix/io.go7
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go35
-rw-r--r--pkg/sentry/socket/unix/unix.go13
5 files changed, 72 insertions, 30 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index abda364c9..c0238691d 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -45,7 +45,10 @@ type SCMRights interface {
transport.RightsControlMessage
// Files returns up to max RightsFiles.
- Files(ctx context.Context, max int) RightsFiles
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFiles, truncated bool)
}
// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
@@ -71,14 +74,17 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
}
// Files implements SCMRights.Files.
-func (fs *RightsFiles) Files(ctx context.Context, max int) RightsFiles {
+func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
n := max
+ var trunc bool
if l := len(*fs); n > l {
n = l
+ } else if n < l {
+ trunc = true
}
rf := (*fs)[:n]
*fs = (*fs)[n:]
- return rf
+ return rf, trunc
}
// Clone implements transport.RightsControlMessage.Clone.
@@ -99,8 +105,8 @@ func (fs *RightsFiles) Release() {
}
// rightsFDs gets up to the specified maximum number of FDs.
-func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) []int32 {
- files := rights.Files(t, max)
+func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
fds := make([]int32, 0, len(files))
for i := 0; i < max && len(files) > 0; i++ {
fd, err := t.FDMap().NewFDFrom(0, files[0], kernel.FDFlags{cloexec}, t.ThreadGroup().Limits())
@@ -114,19 +120,23 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) []int32
fds = append(fds, int32(fd))
}
- return fds
+ return fds, trunc
}
// PackRights packs as many FDs as will fit into the unused capacity of buf.
-func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte) []byte {
+func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) {
maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
// Linux does not return any FDs if none fit.
if maxFDs <= 0 {
- return buf
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDs(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
}
- fds := rightsFDs(t, rights, cloexec, maxFDs)
align := t.Arch().Width()
- return putCmsg(buf, linux.SCM_RIGHTS, align, fds)
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
}
// scmCredentials represents an SCM_CREDENTIALS socket control message.
@@ -176,7 +186,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
-func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte {
+func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
space := AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
@@ -193,7 +203,8 @@ func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte {
// a partial int32, so the length of the message will be
// min(aligned length, header + datas).
if space < linux.SizeOfControlMessageHeader {
- return buf
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
}
length := 4*len(data) + linux.SizeOfControlMessageHeader
@@ -205,11 +216,12 @@ func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte {
buf = putUint32(buf, msgType)
for _, d := range data {
if len(buf)+4 > cap(buf) {
+ flags |= linux.MSG_CTRUNC
break
}
buf = putUint32(buf, uint32(d))
}
- return alignSlice(buf, align)
+ return alignSlice(buf, align), flags
}
func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte {
@@ -253,7 +265,7 @@ func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID,
// PackCredentials packs the credentials in the control message (or default
// credentials if none) into a buffer.
-func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte) []byte {
+func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) {
align := t.Arch().Width()
// Default credentials if none are available.
@@ -265,7 +277,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte) []byte {
pid, uid, gid = creds.Credentials(t)
}
c := []int32{int32(pid), int32(uid), int32(gid)}
- return putCmsg(buf, linux.SCM_CREDENTIALS, align, c)
+ return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
// AlignUp rounds a length up to an alignment. align must be a power of 2.
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 7e840b452..9393acd28 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -87,6 +87,11 @@ type Socket interface {
// senderAddrLen is the address length to be returned to the application,
// not necessarily the actual length of the address.
//
+ // flags control how RecvMsg should be completed. msgFlags indicate how
+ // the RecvMsg call was completed. Note that control message truncation
+ // may still be required even if the MSG_CTRUNC bit is not set in
+ // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately.
+ //
// If err != nil, the recv was not successful.
RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 382911d51..5a1475ec2 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -72,13 +72,18 @@ type EndpointReader struct {
// Control contains the received control messages.
Control transport.ControlMessages
+
+ // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on
+ // the value of NumRights.
+ ControlTrunc bool
}
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
- n, ms, c, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
r.Control = c
+ r.ControlTrunc = ct
r.MsgSize = ms
if err != nil {
return int64(n), err.ToError()
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d5f7f7aa8..b734b4c20 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -130,7 +130,11 @@ type Endpoint interface {
//
// msgLen is the length of the read message consumed for datagram Endpoints.
// msgLen is always the same as recvLen for stream Endpoints.
- RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *syserr.Error)
+ //
+ // CMTruncated indicates that the numRights hint was used to receive fewer
+ // than the total available SCM_RIGHTS FDs. Additional truncation may be
+ // required by the caller.
+ RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
@@ -288,7 +292,7 @@ type Receiver interface {
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *syserr.Error)
+ Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -328,7 +332,7 @@ type queueReceiver struct {
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
var m *message
var notify bool
var err *syserr.Error
@@ -338,7 +342,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek
m, notify, err = q.readQueue.Dequeue()
}
if err != nil {
- return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
}
src := []byte(m.Data)
var copied uintptr
@@ -347,7 +351,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek
copied += uintptr(n)
src = src[n:]
}
- return copied, uintptr(len(m.Data)), m.Control, m.Address, notify, nil
+ return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil
}
// RecvNotify implements Receiver.RecvNotify.
@@ -440,7 +444,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
}
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -453,7 +457,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
// the next time Recv() is called.
m, n, err := q.readQueue.Dequeue()
if err != nil {
- return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
}
notify = n
q.buffer = []byte(m.Data)
@@ -469,7 +473,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
// Don't consume data since we are peeking.
copied, data, _ = vecCopy(data, q.buffer)
- return copied, copied, c, q.addr, notify, nil
+ return copied, copied, c, false, q.addr, notify, nil
}
// Consume data and control message since we are not peeking.
@@ -484,9 +488,11 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
c.Credentials = nil
}
+ var cmTruncated bool
if c.Rights != nil && numRights == 0 {
c.Rights.Release()
c.Rights = nil
+ cmTruncated = true
}
haveRights := c.Rights != nil
@@ -538,6 +544,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
if q.control.Rights != nil {
// Consume rights.
if numRights == 0 {
+ cmTruncated = true
q.control.Rights.Release()
} else {
c.Rights = q.control.Rights
@@ -546,7 +553,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
q.control.Rights = nil
}
}
- return copied, copied, c, q.addr, notify, nil
+ return copied, copied, c, cmTruncated, q.addr, notify, nil
}
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
@@ -775,18 +782,18 @@ func (e *baseEndpoint) Connected() bool {
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *syserr.Error) {
+func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
e.Lock()
if e.receiver == nil {
e.Unlock()
- return 0, 0, ControlMessages{}, syserr.ErrNotConnected
+ return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
}
- recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
e.Unlock()
if err != nil {
- return 0, 0, ControlMessages{}, err
+ return 0, 0, ControlMessages{}, false, err
}
if notify {
@@ -796,7 +803,7 @@ func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, pee
if addr != nil {
*addr = a
}
- return recvLen, msgLen, cms, nil
+ return recvLen, msgLen, cms, cmt, nil
}
// SendMsg writes data and a control message to the endpoint's peer.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index e9607aa01..26788ec31 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -490,6 +490,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if s.Passcred() {
// Credentials take priority if they are enabled and there is space.
wantCreds = rightsLen > 0
+ if !wantCreds {
+ msgFlags |= linux.MSG_CTRUNC
+ }
credLen := syscall.CmsgSpace(syscall.SizeofUcred)
rightsLen -= credLen
}
@@ -516,6 +519,10 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
if s.isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
@@ -546,12 +553,18 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if r.From != nil {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
if trunc {
// n and r.MsgSize are the same for streams.
total += int64(r.MsgSize)
} else {
total += n
}
+
if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil