diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/control/control.go | 42 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/io.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 35 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 13 |
5 files changed, 72 insertions, 30 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index abda364c9..c0238691d 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -45,7 +45,10 @@ type SCMRights interface { transport.RightsControlMessage // Files returns up to max RightsFiles. - Files(ctx context.Context, max int) RightsFiles + // + // Returned files are consumed and ownership is transferred to the caller. + // Subsequent calls to Files will return the next files. + Files(ctx context.Context, max int) (rf RightsFiles, truncated bool) } // RightsFiles represents a SCM_RIGHTS socket control message. A reference is @@ -71,14 +74,17 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { } // Files implements SCMRights.Files. -func (fs *RightsFiles) Files(ctx context.Context, max int) RightsFiles { +func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) { n := max + var trunc bool if l := len(*fs); n > l { n = l + } else if n < l { + trunc = true } rf := (*fs)[:n] *fs = (*fs)[n:] - return rf + return rf, trunc } // Clone implements transport.RightsControlMessage.Clone. @@ -99,8 +105,8 @@ func (fs *RightsFiles) Release() { } // rightsFDs gets up to the specified maximum number of FDs. -func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) []int32 { - files := rights.Files(t, max) +func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) { + files, trunc := rights.Files(t, max) fds := make([]int32, 0, len(files)) for i := 0; i < max && len(files) > 0; i++ { fd, err := t.FDMap().NewFDFrom(0, files[0], kernel.FDFlags{cloexec}, t.ThreadGroup().Limits()) @@ -114,19 +120,23 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) []int32 fds = append(fds, int32(fd)) } - return fds + return fds, trunc } // PackRights packs as many FDs as will fit into the unused capacity of buf. -func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte) []byte { +func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) { maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4 // Linux does not return any FDs if none fit. if maxFDs <= 0 { - return buf + flags |= linux.MSG_CTRUNC + return buf, flags + } + fds, trunc := rightsFDs(t, rights, cloexec, maxFDs) + if trunc { + flags |= linux.MSG_CTRUNC } - fds := rightsFDs(t, rights, cloexec, maxFDs) align := t.Arch().Width() - return putCmsg(buf, linux.SCM_RIGHTS, align, fds) + return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds) } // scmCredentials represents an SCM_CREDENTIALS socket control message. @@ -176,7 +186,7 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. -func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte { +func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { space := AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align @@ -193,7 +203,8 @@ func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte { // a partial int32, so the length of the message will be // min(aligned length, header + datas). if space < linux.SizeOfControlMessageHeader { - return buf + flags |= linux.MSG_CTRUNC + return buf, flags } length := 4*len(data) + linux.SizeOfControlMessageHeader @@ -205,11 +216,12 @@ func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte { buf = putUint32(buf, msgType) for _, d := range data { if len(buf)+4 > cap(buf) { + flags |= linux.MSG_CTRUNC break } buf = putUint32(buf, uint32(d)) } - return alignSlice(buf, align) + return alignSlice(buf, align), flags } func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte { @@ -253,7 +265,7 @@ func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, // PackCredentials packs the credentials in the control message (or default // credentials if none) into a buffer. -func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte) []byte { +func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) { align := t.Arch().Width() // Default credentials if none are available. @@ -265,7 +277,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte) []byte { pid, uid, gid = creds.Credentials(t) } c := []int32{int32(pid), int32(uid), int32(gid)} - return putCmsg(buf, linux.SCM_CREDENTIALS, align, c) + return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c) } // AlignUp rounds a length up to an alignment. align must be a power of 2. diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 7e840b452..9393acd28 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -87,6 +87,11 @@ type Socket interface { // senderAddrLen is the address length to be returned to the application, // not necessarily the actual length of the address. // + // flags control how RecvMsg should be completed. msgFlags indicate how + // the RecvMsg call was completed. Note that control message truncation + // may still be required even if the MSG_CTRUNC bit is not set in + // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately. + // // If err != nil, the recv was not successful. RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 382911d51..5a1475ec2 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -72,13 +72,18 @@ type EndpointReader struct { // Control contains the received control messages. Control transport.ControlMessages + + // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on + // the value of NumRights. + ControlTrunc bool } // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { - n, ms, c, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From) + n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From) r.Control = c + r.ControlTrunc = ct r.MsgSize = ms if err != nil { return int64(n), err.ToError() diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d5f7f7aa8..b734b4c20 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -130,7 +130,11 @@ type Endpoint interface { // // msgLen is the length of the read message consumed for datagram Endpoints. // msgLen is always the same as recvLen for stream Endpoints. - RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *syserr.Error) + // + // CMTruncated indicates that the numRights hint was used to receive fewer + // than the total available SCM_RIGHTS FDs. Additional truncation may be + // required by the caller. + RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error) // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. @@ -288,7 +292,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -328,7 +332,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -338,7 +342,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek m, notify, err = q.readQueue.Dequeue() } if err != nil { - return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err } src := []byte(m.Data) var copied uintptr @@ -347,7 +351,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek copied += uintptr(n) src = src[n:] } - return copied, uintptr(len(m.Data)), m.Control, m.Address, notify, nil + return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil } // RecvNotify implements Receiver.RecvNotify. @@ -440,7 +444,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -453,7 +457,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint // the next time Recv() is called. m, n, err := q.readQueue.Dequeue() if err != nil { - return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err } notify = n q.buffer = []byte(m.Data) @@ -469,7 +473,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint // Don't consume data since we are peeking. copied, data, _ = vecCopy(data, q.buffer) - return copied, copied, c, q.addr, notify, nil + return copied, copied, c, false, q.addr, notify, nil } // Consume data and control message since we are not peeking. @@ -484,9 +488,11 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint c.Credentials = nil } + var cmTruncated bool if c.Rights != nil && numRights == 0 { c.Rights.Release() c.Rights = nil + cmTruncated = true } haveRights := c.Rights != nil @@ -538,6 +544,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint if q.control.Rights != nil { // Consume rights. if numRights == 0 { + cmTruncated = true q.control.Rights.Release() } else { c.Rights = q.control.Rights @@ -546,7 +553,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint q.control.Rights = nil } } - return copied, copied, c, q.addr, notify, nil + return copied, copied, c, cmTruncated, q.addr, notify, nil } // A ConnectedEndpoint is an Endpoint that can be used to send Messages. @@ -775,18 +782,18 @@ func (e *baseEndpoint) Connected() bool { } // RecvMsg reads data and a control message from the endpoint. -func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *syserr.Error) { +func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) { e.Lock() if e.receiver == nil { e.Unlock() - return 0, 0, ControlMessages{}, syserr.ErrNotConnected + return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) e.Unlock() if err != nil { - return 0, 0, ControlMessages{}, err + return 0, 0, ControlMessages{}, false, err } if notify { @@ -796,7 +803,7 @@ func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, pee if addr != nil { *addr = a } - return recvLen, msgLen, cms, nil + return recvLen, msgLen, cms, cmt, nil } // SendMsg writes data and a control message to the endpoint's peer. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index e9607aa01..26788ec31 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -490,6 +490,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if s.Passcred() { // Credentials take priority if they are enabled and there is space. wantCreds = rightsLen > 0 + if !wantCreds { + msgFlags |= linux.MSG_CTRUNC + } credLen := syscall.CmsgSpace(syscall.SizeofUcred) rightsLen -= credLen } @@ -516,6 +519,10 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } + if r.ControlTrunc { + msgFlags |= linux.MSG_CTRUNC + } + if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { if s.isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC @@ -546,12 +553,18 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } + + if r.ControlTrunc { + msgFlags |= linux.MSG_CTRUNC + } + if trunc { // n and r.MsgSize are the same for streams. total += int64(r.MsgSize) } else { total += n } + if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { if total > 0 { err = nil |