diff options
author | Ian Gudger <igudger@google.com> | 2019-04-29 21:20:05 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-29 21:21:08 -0700 |
commit | 81ecd8b6eab7457b331762626f8c210fec3504e6 (patch) | |
tree | 4c5e5aaf3ac8ff475657c66671dd6828938ae45e /pkg/sentry/socket/unix | |
parent | 2843f2a956f5ef23e621f571f5c3e6a1e4a8223a (diff) |
Implement the MSG_CTRUNC msghdr flag for Unix sockets.
Updates google/gvisor#206
PiperOrigin-RevId: 245880573
Change-Id: Ifa715e98d47f64b8a32b04ae9378d6cd6bd4025e
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r-- | pkg/sentry/socket/unix/io.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 35 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 13 |
3 files changed, 40 insertions, 15 deletions
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 382911d51..5a1475ec2 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -72,13 +72,18 @@ type EndpointReader struct { // Control contains the received control messages. Control transport.ControlMessages + + // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on + // the value of NumRights. + ControlTrunc bool } // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { - n, ms, c, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From) + n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From) r.Control = c + r.ControlTrunc = ct r.MsgSize = ms if err != nil { return int64(n), err.ToError() diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d5f7f7aa8..b734b4c20 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -130,7 +130,11 @@ type Endpoint interface { // // msgLen is the length of the read message consumed for datagram Endpoints. // msgLen is always the same as recvLen for stream Endpoints. - RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *syserr.Error) + // + // CMTruncated indicates that the numRights hint was used to receive fewer + // than the total available SCM_RIGHTS FDs. Additional truncation may be + // required by the caller. + RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error) // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. @@ -288,7 +292,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -328,7 +332,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -338,7 +342,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek m, notify, err = q.readQueue.Dequeue() } if err != nil { - return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err } src := []byte(m.Data) var copied uintptr @@ -347,7 +351,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek copied += uintptr(n) src = src[n:] } - return copied, uintptr(len(m.Data)), m.Control, m.Address, notify, nil + return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil } // RecvNotify implements Receiver.RecvNotify. @@ -440,7 +444,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -453,7 +457,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint // the next time Recv() is called. m, n, err := q.readQueue.Dequeue() if err != nil { - return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err } notify = n q.buffer = []byte(m.Data) @@ -469,7 +473,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint // Don't consume data since we are peeking. copied, data, _ = vecCopy(data, q.buffer) - return copied, copied, c, q.addr, notify, nil + return copied, copied, c, false, q.addr, notify, nil } // Consume data and control message since we are not peeking. @@ -484,9 +488,11 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint c.Credentials = nil } + var cmTruncated bool if c.Rights != nil && numRights == 0 { c.Rights.Release() c.Rights = nil + cmTruncated = true } haveRights := c.Rights != nil @@ -538,6 +544,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint if q.control.Rights != nil { // Consume rights. if numRights == 0 { + cmTruncated = true q.control.Rights.Release() } else { c.Rights = q.control.Rights @@ -546,7 +553,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint q.control.Rights = nil } } - return copied, copied, c, q.addr, notify, nil + return copied, copied, c, cmTruncated, q.addr, notify, nil } // A ConnectedEndpoint is an Endpoint that can be used to send Messages. @@ -775,18 +782,18 @@ func (e *baseEndpoint) Connected() bool { } // RecvMsg reads data and a control message from the endpoint. -func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *syserr.Error) { +func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) { e.Lock() if e.receiver == nil { e.Unlock() - return 0, 0, ControlMessages{}, syserr.ErrNotConnected + return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) e.Unlock() if err != nil { - return 0, 0, ControlMessages{}, err + return 0, 0, ControlMessages{}, false, err } if notify { @@ -796,7 +803,7 @@ func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, pee if addr != nil { *addr = a } - return recvLen, msgLen, cms, nil + return recvLen, msgLen, cms, cmt, nil } // SendMsg writes data and a control message to the endpoint's peer. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index e9607aa01..26788ec31 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -490,6 +490,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if s.Passcred() { // Credentials take priority if they are enabled and there is space. wantCreds = rightsLen > 0 + if !wantCreds { + msgFlags |= linux.MSG_CTRUNC + } credLen := syscall.CmsgSpace(syscall.SizeofUcred) rightsLen -= credLen } @@ -516,6 +519,10 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } + if r.ControlTrunc { + msgFlags |= linux.MSG_CTRUNC + } + if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { if s.isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC @@ -546,12 +553,18 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } + + if r.ControlTrunc { + msgFlags |= linux.MSG_CTRUNC + } + if trunc { // n and r.MsgSize are the same for streams. total += int64(r.MsgSize) } else { total += n } + if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { if total > 0 { err = nil |