summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs
diff options
context:
space:
mode:
authorIan Gudger <igudger@google.com>2019-04-29 21:20:05 -0700
committerShentubot <shentubot@google.com>2019-04-29 21:21:08 -0700
commit81ecd8b6eab7457b331762626f8c210fec3504e6 (patch)
tree4c5e5aaf3ac8ff475657c66671dd6828938ae45e /pkg/sentry/fs
parent2843f2a956f5ef23e621f571f5c3e6a1e4a8223a (diff)
Implement the MSG_CTRUNC msghdr flag for Unix sockets.
Updates google/gvisor#206 PiperOrigin-RevId: 245880573 Change-Id: Ifa715e98d47f64b8a32b04ae9378d6cd6bd4025e
Diffstat (limited to 'pkg/sentry/fs')
-rw-r--r--pkg/sentry/fs/host/control.go7
-rw-r--r--pkg/sentry/fs/host/socket.go16
-rw-r--r--pkg/sentry/fs/host/socket_test.go2
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go12
4 files changed, 21 insertions, 16 deletions
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
index 480f0c8f4..9ebb9bbb3 100644
--- a/pkg/sentry/fs/host/control.go
+++ b/pkg/sentry/fs/host/control.go
@@ -32,17 +32,20 @@ func newSCMRights(fds []int) control.SCMRights {
}
// Files implements control.SCMRights.Files.
-func (c *scmRights) Files(ctx context.Context, max int) control.RightsFiles {
+func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) {
n := max
+ var trunc bool
if l := len(c.fds); n > l {
n = l
+ } else if n < l {
+ trunc = true
}
rf := control.RightsFiles(fdsToFiles(ctx, c.fds[:n]))
// Only consume converted FDs (fdsToFiles may convert fewer than n FDs).
c.fds = c.fds[len(rf):]
- return rf
+ return rf, trunc
}
// Clone implements transport.RightsControlMessage.Clone.
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 3034e9441..3ed137006 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -282,11 +282,11 @@ func (c *ConnectedEndpoint) EventUpdate() {
}
// Recv implements transport.Receiver.Recv.
-func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, tcpip.FullAddress, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.readClosed {
- return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive
}
var cm unet.ControlMessage
@@ -296,7 +296,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
// N.B. Unix sockets don't have a receive buffer, the send buffer
// serves both purposes.
- rl, ml, cl, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf)
+ rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf)
if rl > 0 && err != nil {
// We got some data, so all we need to do on error is return
// the data that we got. Short reads are fine, no need to
@@ -304,7 +304,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
err = nil
}
if err != nil {
- return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, syserr.FromError(err)
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
}
// There is no need for the callee to call RecvNotify because fdReadVec uses
@@ -317,18 +317,18 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p
// Avoid extra allocations in the case where there isn't any control data.
if len(cm) == 0 {
- return rl, ml, transport.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
}
fds, err := cm.ExtractFDs()
if err != nil {
- return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, syserr.FromError(err)
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
}
if len(fds) == 0 {
- return rl, ml, transport.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
}
- return rl, ml, control.New(nil, nil, newSCMRights(fds)), tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
}
// close releases all resources related to the endpoint.
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index cc760a7e1..06392a65a 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -207,7 +207,7 @@ func TestSend(t *testing.T) {
func TestRecv(t *testing.T) {
e := ConnectedEndpoint{readClosed: true}
- if _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive {
+ if _, _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive {
t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, syserr.ErrClosedForReceive)
}
}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index 8873705c0..e57be0506 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -23,7 +23,7 @@ import (
//
// If the total length of bufs is > maxlen, fdReadVec will do a partial read
// and err will indicate why the message was truncated.
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) {
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, controlTrunc bool, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
@@ -34,7 +34,7 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re
length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
if err != nil && len(iovecs) == 0 {
// No partial write to do, return error immediately.
- return 0, 0, 0, err
+ return 0, 0, 0, false, err
}
var msg syscall.Msghdr
@@ -51,7 +51,7 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re
n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
// N.B. prioritize the syscall error over the buildIovec error.
- return 0, 0, 0, e
+ return 0, 0, 0, false, e
}
// Copy data back to bufs.
@@ -59,11 +59,13 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re
copyToMulti(bufs, intermediate)
}
+ controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
if n > length {
- return length, n, msg.Controllen, err
+ return length, n, msg.Controllen, controlTrunc, err
}
- return n, n, msg.Controllen, err
+ return n, n, msg.Controllen, controlTrunc, err
}
// fdWriteVec sends from bufs to fd.