summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
authorIan Gudger <igudger@google.com>2019-04-29 21:20:05 -0700
committerShentubot <shentubot@google.com>2019-04-29 21:21:08 -0700
commit81ecd8b6eab7457b331762626f8c210fec3504e6 (patch)
tree4c5e5aaf3ac8ff475657c66671dd6828938ae45e /pkg/sentry/socket/unix
parent2843f2a956f5ef23e621f571f5c3e6a1e4a8223a (diff)
Implement the MSG_CTRUNC msghdr flag for Unix sockets.
Updates google/gvisor#206 PiperOrigin-RevId: 245880573 Change-Id: Ifa715e98d47f64b8a32b04ae9378d6cd6bd4025e
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/io.go7
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go35
-rw-r--r--pkg/sentry/socket/unix/unix.go13
3 files changed, 40 insertions, 15 deletions
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 382911d51..5a1475ec2 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -72,13 +72,18 @@ type EndpointReader struct {
// Control contains the received control messages.
Control transport.ControlMessages
+
+ // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on
+ // the value of NumRights.
+ ControlTrunc bool
}
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
- n, ms, c, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
r.Control = c
+ r.ControlTrunc = ct
r.MsgSize = ms
if err != nil {
return int64(n), err.ToError()
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d5f7f7aa8..b734b4c20 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -130,7 +130,11 @@ type Endpoint interface {
//
// msgLen is the length of the read message consumed for datagram Endpoints.
// msgLen is always the same as recvLen for stream Endpoints.
- RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *syserr.Error)
+ //
+ // CMTruncated indicates that the numRights hint was used to receive fewer
+ // than the total available SCM_RIGHTS FDs. Additional truncation may be
+ // required by the caller.
+ RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
@@ -288,7 +292,7 @@ type Receiver interface {
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *syserr.Error)
+ Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -328,7 +332,7 @@ type queueReceiver struct {
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
var m *message
var notify bool
var err *syserr.Error
@@ -338,7 +342,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek
m, notify, err = q.readQueue.Dequeue()
}
if err != nil {
- return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
}
src := []byte(m.Data)
var copied uintptr
@@ -347,7 +351,7 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek
copied += uintptr(n)
src = src[n:]
}
- return copied, uintptr(len(m.Data)), m.Control, m.Address, notify, nil
+ return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil
}
// RecvNotify implements Receiver.RecvNotify.
@@ -440,7 +444,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
}
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -453,7 +457,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
// the next time Recv() is called.
m, n, err := q.readQueue.Dequeue()
if err != nil {
- return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
}
notify = n
q.buffer = []byte(m.Data)
@@ -469,7 +473,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
// Don't consume data since we are peeking.
copied, data, _ = vecCopy(data, q.buffer)
- return copied, copied, c, q.addr, notify, nil
+ return copied, copied, c, false, q.addr, notify, nil
}
// Consume data and control message since we are not peeking.
@@ -484,9 +488,11 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
c.Credentials = nil
}
+ var cmTruncated bool
if c.Rights != nil && numRights == 0 {
c.Rights.Release()
c.Rights = nil
+ cmTruncated = true
}
haveRights := c.Rights != nil
@@ -538,6 +544,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
if q.control.Rights != nil {
// Consume rights.
if numRights == 0 {
+ cmTruncated = true
q.control.Rights.Release()
} else {
c.Rights = q.control.Rights
@@ -546,7 +553,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
q.control.Rights = nil
}
}
- return copied, copied, c, q.addr, notify, nil
+ return copied, copied, c, cmTruncated, q.addr, notify, nil
}
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
@@ -775,18 +782,18 @@ func (e *baseEndpoint) Connected() bool {
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *syserr.Error) {
+func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
e.Lock()
if e.receiver == nil {
e.Unlock()
- return 0, 0, ControlMessages{}, syserr.ErrNotConnected
+ return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
}
- recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
e.Unlock()
if err != nil {
- return 0, 0, ControlMessages{}, err
+ return 0, 0, ControlMessages{}, false, err
}
if notify {
@@ -796,7 +803,7 @@ func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, pee
if addr != nil {
*addr = a
}
- return recvLen, msgLen, cms, nil
+ return recvLen, msgLen, cms, cmt, nil
}
// SendMsg writes data and a control message to the endpoint's peer.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index e9607aa01..26788ec31 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -490,6 +490,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if s.Passcred() {
// Credentials take priority if they are enabled and there is space.
wantCreds = rightsLen > 0
+ if !wantCreds {
+ msgFlags |= linux.MSG_CTRUNC
+ }
credLen := syscall.CmsgSpace(syscall.SizeofUcred)
rightsLen -= credLen
}
@@ -516,6 +519,10 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
if s.isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
@@ -546,12 +553,18 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if r.From != nil {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
if trunc {
// n and r.MsgSize are the same for streams.
total += int64(r.MsgSize)
} else {
total += n
}
+
if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil